
Commit c385724

Author: egor
Commit message: Almost final version
Signed-off-by: egor <[email protected]>
1 parent: bdf18b3

File tree: 1 file changed (+109, -86 lines)

notebooks/Name suggestion.ipynb

Lines changed: 109 additions & 86 deletions
@@ -56,7 +56,8 @@
 " ENC_TRAIN_NAMES = [\"train.bpe.tgt\"]\n",
 " ENC_VAL_BODIES = [\"val.bpe.src\"]\n",
 " ENC_VAL_NAMES = [\"val.bpe.tgt\"]\n",
-" VOCABULARY = [\"vocab.txt\"]\n",
+" TGT_VOCABULARY = [\"tgt.vocab\"]\n",
+" SRC_VOCABULARY = [\"src.vocab\"]\n",
 "\n",
 " \n",
 "class Dirs(DirsABC, Enum):\n",
@@ -535,63 +536,15 @@
 "bpe_encode(run.path(Files.VAL_BODIES), run.path(Files.ENC_VAL_NAMES))"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"# Train seq2seq model"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"\"\"\"-data /ssd/devfest2019-workshop/opennmt_format_input -save_model /ssd/devfest2019-workshop/transformer_bpe \\\n",
-"    -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \\\n",
-"    -encoder_type transformer -decoder_type transformer -position_encoding \\\n",
-"    -train_steps 200000 -max_generator_batches 2 -dropout 0.1 \\\n",
-"    -batch_size 4096 -batch_type tokens -normalization tokens -accum_count 2 \\\n",
-"    -optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 \\\n",
-"    -max_grad_norm 0 -param_init 0 -param_init_glorot \\\n",
-"    -label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 \\\n",
-"    -world_size 4 -gpu_ranks 0 1 2 3 \"\"\""
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# preprocess\n",
-"!echo -train_src {train_bodies_bpe_loc} \\\n",
-"    -train_tgt {train_names_bpe_loc} \\\n",
-"    -valid_src {val_bodies_bpe_loc} \\\n",
-"    -valid_tgt {val_names_bpe_loc} \\\n",
-"    -save_data /ssd/devfest2019-workshop/opennmt_format_input"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# Train transformer with openNMT-tf"
+"# Train seq2seq model\n",
+"\n",
+"* we will use `openNMT-tf`\n",
+"* prepare vocabularies (we reuse the functionality for training a translation model: from identifiers to function names)\n",
+"* train the model"
 ]
 },
 {
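
Note for readers following the diff: the new cells assume the OpenNMT-tf command-line tools are available. A minimal setup sketch, assuming a pip-based environment:

    pip install OpenNMT-tf    # provides the onmt-main and onmt-build-vocab CLIs used below
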
@@ -600,16 +553,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# TODO: src_vocab_loc, tgt_vocab_loc\n",
+"\n",
+"# this approach requires us to provide vocabularies,\n",
+"# so launch these commands\n",
 "def generate_build_vocab(save_vocab_loc, input_text, vocab_size=vocab_size):\n",
 "    return \"onmt-build-vocab --size %s --save_vocab %s %s\" % (vocab_size,\n",
 "                                                              save_vocab_loc,\n",
 "                                                              input_text)\n",
 "\n",
-"print(generate_build_vocab(save_vocab_loc=\"bpe_input/src.vocab\",\n",
-"                           input_text=\"bpe_input/train.src\",\n",
+"src_vocab_loc = os.path.join(bpe_base_dir, \"src.vocab\")\n",
+"print(generate_build_vocab(save_vocab_loc=src_vocab_loc,\n",
+"                           input_text=train_bodies_bpe_loc,\n",
 "                           vocab_size=vocab_size + 10))\n",
-"print(generate_build_vocab(save_vocab_loc=\"bpe_input/tgt.vocab\",\n",
-"                           input_text=\"bpe_input/train.tgt\",\n",
+"tgt_vocab_loc = os.path.join(bpe_base_dir, \"tgt.vocab\")\n",
+"print(generate_build_vocab(save_vocab_loc=tgt_vocab_loc,\n",
+"                           input_text=train_names_bpe_loc,\n",
 "                           vocab_size=vocab_size + 10))"
 ]
 },
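
For illustration, with hypothetical paths and vocab_size = 16000, the two printed commands would look roughly like:

    onmt-build-vocab --size 16010 --save_vocab /data/bpe/src.vocab /data/bpe/train.bpe.src
    onmt-build-vocab --size 16010 --save_vocab /data/bpe/tgt.vocab /data/bpe/train.bpe.tgt

The extra 10 entries presumably leave headroom above the BPE vocabulary size for the special tokens OpenNMT-tf adds to its vocabulary files.
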
@@ -619,18 +578,30 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"yaml_content = \"\"\"\n",
-"model_dir: run/\n",
-"\n",
-"data:\n",
-"  train_features_file: bpe_input/train.src\n",
-"  train_labels_file: bpe_input/train.tgt\n",
-"  eval_features_file: bpe_input/val.src\n",
-"  eval_labels_file: bpe_input/val.tgt\n",
-"  source_vocabulary: bpe_input/src.vocab\n",
-"  target_vocabulary: bpe_input/tgt.vocab\n",
+"base_train_dir = os.path.join(bpe_base_dir, \"seq2seq\")\n",
+"os.makedirs(base_train_dir, exist_ok=True)\n",
+"model_dir = os.path.join(base_train_dir, \"run/\")\n",
 "\n",
+"# prepare the config file for the model\n",
+"config_yaml = os.path.join(base_train_dir, \"config.yml\")\n",
+"# model_dir will contain evaluation results of the model, checkpoints and so on\n",
+"yaml_content = \"model_dir: %s\\n\" % model_dir\n",
 "\n",
+"# describe where the data is located\n",
+"yaml_content += \"\"\"\n",
+"data:\n",
+"  train_features_file: %s\n",
+"  train_labels_file: %s\n",
+"  eval_features_file: %s\n",
+"  eval_labels_file: %s\n",
+"  source_vocabulary: %s\n",
+"  target_vocabulary: %s\n",
+"\"\"\" % (train_bodies_bpe_loc, train_names_bpe_loc,\n",
+"       val_bodies_bpe_loc, val_names_bpe_loc,\n",
+"       src_vocab_loc, tgt_vocab_loc)\n",
+"\n",
+"# other useful configuration\n",
+"yaml_content += \"\"\"\n",
 "train:\n",
 "  # (optional when batch_type=tokens) If not set, the training will search the largest\n",
 "  # possible batch size.\n",
@@ -662,7 +633,7 @@
 "    min_improvement: 0.01\n",
 "    steps: 2\n",
 "\"\"\"\n",
-"config_yaml = \"openNMT_tf_train_data.yml\"\n",
+"\n",
 "with open(config_yaml, \"w\") as f:\n",
 "    f.write(yaml_content)"
@@ -673,24 +644,48 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"!cp openNMT_tf_train_data.yml /ssd/devfest2019-workshop/\n"
+"# how to launch training\n",
+"train_cmd = \"\"\"\n",
+"onmt-main --model_type LuongAttention \\\n",
+"--config %s --auto_config train --with_eval\"\"\" % config_yaml\n",
+"print(train_cmd)\n",
+"\n",
+"# when running on GPUs you can specify CUDA_VISIBLE_DEVICES & the number of GPUs to use\n",
+"cmd_gpu = \"\"\"\n",
+"CUDA_VISIBLE_DEVICES=%s onmt-main --model_type LuongAttention \\\n",
+"--config %s --auto_config train --with_eval --num_gpus %s\"\"\" % (\"0,1\", config_yaml, 2)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"alex.ipynb\t\t jupyter-server-config.json requirements-tf.txt\r\n",
+"base_egor.ipynb\t\t Makefile\t\t\t requirements.txt\r\n",
+"Dockerfile\t\t notebooks\t\t\t src.vocab\r\n",
+"docs\t\t\t pretrained.zip\t\t sshuttle.pid\r\n",
+"images\t\t\t README.md\t\t\t tgt.vocab\r\n",
+"jupyter-notebook-config.json requirements-bigartm.txt\r\n"
+]
+}
+],
 "source": [
-"cmd = \"CUDA_VISIBLE_DEVICES=0,1,2,3 onmt-main --model_type Transformer --config openNMT_tf_train_data.yml \\\n",
-"--auto_config train --with_eval --num_gpus 4\""
+"!ls"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# Predict"
+"# Predict\n",
+"* to save time, we will use a model pretrained on several GPUs\n",
+"* predictions will be saved to a file\n",
+"* predicted BPE ids will be converted back to text"
 ]
 },
 {
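
The printed command can be launched from the notebook itself; a sketch, assuming the cell above has run (IPython expands {train_cmd} inside a shell escape):

    !{train_cmd}

With --auto_config, onmt-main applies the model type's recommended hyperparameters, and anything set in config.yml overrides them.
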
@@ -699,11 +694,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"\"\"\"onmt-main \\\n",
-"    --config openNMT_tf_train_data.yml --auto_config \\\n",
-"    average_checkpoints \\\n",
-"    --output_dir run/baseline/avg \\\n",
-"    --max_count 5\"\"\""
+"bpe_val_predictions = os.path.join(base_dir, \"val.pred.tgt\")\n",
+"pretrained_model = os.path.join(base_dir, \"pretrained/model\")\n",
+"predict_cmd = \"\"\"onmt-main \\\n",
+"--config %s --auto_config \\\n",
+"infer \\\n",
+"--features_file %s \\\n",
+"--predictions_file %s \\\n",
+"--checkpoint_path run/baseline/avg/ckpt-5000\"\"\" % (config_yaml, val_bodies_bpe_loc, bpe_val_predictions)"
 ]
 },
 {
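
As with training, the generated inference command can be run in place; a minimal sketch, assuming predict_cmd from the cell above:

    import subprocess
    subprocess.run(predict_cmd, shell=True, check=True)  # writes one line per input function

In this setup each line of the predictions file holds the BPE token ids of one suggested function name, which the next cells decode back to text.
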
@@ -712,7 +710,32 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"!onmt-main --config openNMT_tf_train_data.yml --auto_config infer --features_file bpe_input/val.src --predictions_file predictions/val.pred.tgt"
+"pred_ids = []\n",
+"with open(bpe_val_predictions, \"r\") as f:\n",
+"    for line in f.readlines():\n",
+"        pred_ids.append(list(map(int, line.split())))\n",
+"\n",
+"pred_val_function_names = bpe.decode(pred_ids)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"gt_ids = []\n",
+"with open(val_names_bpe_loc, \"r\") as f:\n",
+"    for line in f.readlines():\n",
+"        gt_ids.append(list(map(int, line.split())))\n",
+"gt_val_function_names = bpe.decode(gt_ids)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# And finally let's see the results!"
 ]
 },
 {
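
To make the decoding step concrete, a tiny sketch (the ids and the decoded string are made up; bpe is the BPE model trained earlier in the notebook):

    line = "101 2045 3301"              # one line of the predictions file
    ids = list(map(int, line.split()))  # -> [101, 2045, 3301]
    bpe.decode([ids])                   # -> e.g. ["get file name"]

bpe.decode takes a batch of id lists and returns the corresponding strings, so all predictions can be decoded in a single call, as the cells above do.
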
@@ -721,10 +744,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"\"\"\"onmt-main \\\n",
-"    --config openNMT_tf_train_data.yml --auto_config \\\n",
-"    --checkpoint_path run/baseline/avg/ckpt-5000 \\\n",
-"    infer --features_file text_input/val.src --predictions_file predictions/val.pred.tgt\"\"\""
+"for i, (a, b) in enumerate(zip(gt_val_function_names, pred_val_function_names)):\n",
+"    if i == 100:\n",
+"        break\n",
+"    print(\"%s | %s\" % (a, b))"
 ]
 }
 ],
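
Beyond eyeballing pairs, a quick quantitative check is possible from here; a sketch using the two lists built above:

    matches = sum(gt == pred for gt, pred in zip(gt_val_function_names, pred_val_function_names))
    print("exact-match accuracy: %.3f" % (matches / len(gt_val_function_names)))

Exact match is a harsh metric for name suggestion, so this is only the simplest sanity check.
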
@@ -744,7 +767,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.8"
+"version": "3.6.7"
 }
 },
 "nbformat": 4,
