
Commit 14a07a9 (parent: c385724)
Author: egor

Update name notebook

Signed-off-by: egor <[email protected]>

File tree: 1 file changed (+70, -64 lines)


notebooks/Name suggestion.ipynb (70 additions, 64 deletions)
```diff
@@ -27,7 +27,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
```
```diff
@@ -58,10 +58,14 @@
     "    ENC_VAL_NAMES = [\"val.bpe.tgt\"]\n",
     "    TGT_VOCABULARY = [\"tgt.vocab\"]\n",
     "    SRC_VOCABULARY = [\"src.vocab\"]\n",
+    "    MODEL_CONFIG = [\"model\", \"config.yml\"] \n",
+    "    MODEL_PRETRAINED = [\"pretrained\", \"ckpt-25000\"]\n",
+    "    ENC_VAL_NAMES_PRED = [\"val.bpe.pred.tgt\"]\n",
     "\n",
     " \n",
     "class Dirs(DirsABC, Enum):\n",
     "    TF_MODELS = [\"tf\", \"models\"]\n",
+    "    MODEL_RUN = [\"model\", \"run\"]\n",
     "\n",
     "run = Run(\"name-suggestion\", \"java-full\")\n",
     "\n",
```
```diff
@@ -263,13 +267,6 @@
     "highlight_function_name_and_identifiers(run.path(Files.FUNCTIONS), 3)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
```
```diff
@@ -281,6 +278,13 @@
     " - Y lable, a name of the function.\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": null,
```
```diff
@@ -341,13 +345,6 @@
     "extract_functions_parallel(run.path(Files.FUNCTIONS))"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
```
```diff
@@ -381,6 +378,13 @@
     "We are going to use a sing vocabulary for both, identifiers and function names. In order to do so, we will need to train BPE tokenizer on a file that contains all identifiers and function names in plain text."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": null,
```
```diff
@@ -562,14 +566,17 @@
     " save_vocab_loc,\n",
     " input_text)\n",
     "\n",
-    "src_vocab_loc = os.path.join(bpe_base_dir, \"src.vocab\")\n",
-    "print(generate_build_vocab(save_vocab_loc=src_vocab_loc,\n",
-    "                           input_text=train_bodies_bpe_loc,\n",
-    "                           vocab_size=vocab_size + 10))\n",
-    "tgt_vocab_loc = os.path.join(bpe_base_dir, \"tgt.vocab\")\n",
-    "print(generate_build_vocab(save_vocab_loc=tgt_vocab_loc,\n",
-    "                           input_text=train_names_bpe_loc,\n",
-    "                           vocab_size=vocab_size + 10))"
+    "if not os.path.exists(run.path(Files.SRC_VOCABULARY)):\n",
+    "    # in case of pretrained model we reuse vocabulary\n",
+    "    cmd = generate_build_vocab(save_vocab_loc=run.path(Files.SRC_VOCABULARY),\n",
+    "                               input_text=run.path(Files.ENC_TRAIN_BODIES),\n",
+    "                               vocab_size=vocab_size + 10)\n",
+    "    ! {cmd}\n",
+    "\n",
+    "    cmd = generate_build_vocab(save_vocab_loc=run.path(Files.TGT_VOCABULARY),\n",
+    "                               input_text=run.path(Files.ENC_TRAIN_NAMES),\n",
+    "                               vocab_size=vocab_size + 10)\n",
+    "    ! {cmd}"
    ]
   },
   {
```
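The new guard only builds vocabularies when `src.vocab` is missing, so a pretrained model's vocabulary is reused. `generate_build_vocab` is defined earlier in the notebook; a hedged sketch of what it plausibly assembles, assuming it wraps OpenNMT-tf's `onmt-build-vocab` CLI:

```python
# Sketch only: the notebook's real generate_build_vocab may differ.
def generate_build_vocab(save_vocab_loc: str, input_text: str, vocab_size: int) -> str:
    # onmt-build-vocab ships with OpenNMT-tf; it writes one token per line
    # to save_vocab_loc, keeping the vocab_size most frequent tokens.
    return "onmt-build-vocab --size %d --save_vocab %s %s" % (
        vocab_size, save_vocab_loc, input_text)

cmd = generate_build_vocab(save_vocab_loc="src.vocab",   # illustrative paths
                           input_text="train.bpe.src",
                           vocab_size=16384 + 10)
print(cmd)  # the cell above executes this string with `! {cmd}`
```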
```diff
@@ -578,12 +585,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "base_train_dir = os.path.join(bpe_base_dir, \"seq2seq\")\n",
-    "os.makedirs(base_train_dir, exist_ok=True)\n",
-    "model_dir = os.path.join(base_train_dir, \"run/\")\n",
+    "\n",
+    "model_dir = run.path(Dirs.MODEL_RUN)\n",
     "\n",
     "# prepare config file for model\n",
-    "config_yaml = os.path.join(base_train_dir, \"config.yml\")\n",
+    "config_yaml = run.path(Files.MODEL_CONFIG)\n",
     "# this directory will contain evaluation results of the model, checkpoints and so on\n",
     "yaml_content = \"model_dir: %s \\n\" % model_dir\n",
     "\n",
```
```diff
@@ -596,9 +602,12 @@
     "  eval_labels_file: %s\n",
     "  source_vocabulary: %s\n",
     "  target_vocabulary: %s\n",
-    "\"\"\" % (train_bodies_bpe_loc, train_names_bpe_loc,\n",
-    "       val_bodies_bpe_loc, val_names_bpe_loc,\n",
-    "       src_vocab_loc, tgt_vocab_loc)\n",
+    "\"\"\" % (run.path(Files.ENC_TRAIN_BODIES), \n",
+    "       run.path(Files.ENC_TRAIN_NAMES),\n",
+    "       run.path(Files.ENC_VAL_BODIES), \n",
+    "       run.path(Files.ENC_VAL_NAMES),\n",
+    "       run.path(Files.SRC_VOCABULARY), \n",
+    "       run.path(Files.TGT_VOCABULARY))\n",
     "\n",
     "# other useful configurations\n",
     "yaml_content += \"\"\"\n",
```
```diff
@@ -645,37 +654,19 @@
    "outputs": [],
    "source": [
     "# how to launch training\n",
-    "train_cmd = \"\"\"\n",
-    "onmt-main --model_type LuongAttention \\\n",
-    "--config %s --auto_config train --with_eval\"\"\" % config_yaml\n",
-    "print(train_cmd)\n",
+    "GPU_USE = False\n",
+    "if not GPU_USE:\n",
+    "    train_cmd = \"\"\"\n",
+    "    onmt-main --model_type LuongAttention \\\n",
+    "    --config %s --auto_config train --with_eval\"\"\" % config_yaml\n",
+    "    ! {train_cmd}\n",
     "\n",
     "# in case of GPU you can specify CUDA_VISIBLE_DEVICES & number of GPUs to use\n",
-    "cmd_gpu = \"\"\"\n",
-    "CUDA_VISIBLE_DEVICES=%s onmt-main --model_type LuongAttention \\\n",
-    "--config %s --auto_config train --with_eval --num_gpus %s\"\"\" % (\"0,1\", config_yaml, 2)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "alex.ipynb\t\t jupyter-server-config.json requirements-tf.txt\r\n",
-      "base_egor.ipynb\t\t Makefile\t\t\t requirements.txt\r\n",
-      "Dockerfile\t\t notebooks\t\t\t src.vocab\r\n",
-      "docs\t\t\t pretrained.zip\t\t sshuttle.pid\r\n",
-      "images\t\t\t README.md\t\t\t tgt.vocab\r\n",
-      "jupyter-notebook-config.json requirements-bigartm.txt\r\n"
-     ]
-    }
-   ],
-   "source": [
-    "!ls"
+    "if GPU_USE:\n",
+    "    cmd_gpu = \"\"\"\n",
+    "    CUDA_VISIBLE_DEVICES=%s onmt-main --model_type LuongAttention \\\n",
+    "    --config %s --auto_config train --with_eval --num_gpus %s\"\"\" % (\"0,1\", config_yaml, 2)\n",
+    "    ! {cmd_gpu}"
    ]
   },
   {
```
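The rewritten training cell gates the CPU and GPU variants behind a `GPU_USE` flag and launches them with Jupyter's `!` magic. An equivalent sketch outside the notebook, using `subprocess` (the config path is illustrative):

```python
import subprocess

GPU_USE = False
config_yaml = "config.yml"  # illustrative; the notebook uses run.path(Files.MODEL_CONFIG)

if not GPU_USE:
    cmd = ("onmt-main --model_type LuongAttention "
           "--config %s --auto_config train --with_eval" % config_yaml)
else:
    # pin the visible GPUs and tell OpenNMT-tf how many to use
    cmd = ("CUDA_VISIBLE_DEVICES=0,1 onmt-main --model_type LuongAttention "
           "--config %s --auto_config train --with_eval --num_gpus 2" % config_yaml)

subprocess.run(cmd, shell=True, check=True)
```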
```diff
@@ -694,14 +685,29 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bpe_val_predictions = os.path.join(base_dir, \"val.pred.tgt\")\n",
-    "pretrained_model = os.path.join(base_dir, \"pretrained/model\")\n",
+    "# you have to specify location of pretrained model\n",
+    "pretrained_model = None\n",
+    "if pretrained_model is None:\n",
+    "    pretrained_model = run.path(Files.MODEL_PRETRAINED)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bpe_val_predictions = \"val.pred.tgt\"\n",
+    "\n",
     "predict_cmd = \"\"\"onmt-main \\\n",
     "--config %s --auto_config \\\n",
     "infer \\\n",
     "--features_file %s \\\n",
     "--predictions_file %s \\\n",
-    "--checkpoint_path run/baseline/avg/ckpt-5000\"\"\" % (config_yaml, val_bodies_bpe_loc, bpe_val_predictions)"
+    "--checkpoint_path %s\"\"\" % (config_yaml, \n",
+    "                           run.path(Files.ENC_VAL_BODIES), \n",
+    "                           run.path(Files.ENC_VAL_NAMES_PRED),\n",
+    "                           pretrained_model)"
    ]
   },
   {
```
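With the checkpoint resolved (falling back to the bundled `pretrained/ckpt-25000` when none is given), the inference command expands to roughly the following; the paths are placeholders for the `run.path(...)` values:

```python
# Illustrative expansion of predict_cmd; the notebook executes it via `! {predict_cmd}`.
predict_cmd = """onmt-main \
--config %s --auto_config \
infer \
--features_file %s \
--predictions_file %s \
--checkpoint_path %s""" % ("config.yml",
                           "val.bpe.src",
                           "val.bpe.pred.tgt",
                           "pretrained/ckpt-25000")
print(predict_cmd)
```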
```diff
@@ -711,7 +717,7 @@
    "outputs": [],
    "source": [
     "pred_ids = []\n",
-    "with open(bpe_val_predictions, \"r\") as f:\n",
+    "with open(run.path(Files.ENC_VAL_NAMES_PRED), \"r\") as f:\n",
     "    for line in f.readlines():\n",
     "        pred_ids.append(list(map(int, line.split())))\n",
     "\n",
```
```diff
@@ -725,7 +731,7 @@
    "outputs": [],
    "source": [
     "gt_ids = []\n",
-    "with open(val_names_bpe_loc, \"r\") as f:\n",
+    "with open(run.path(Files.ENC_VAL_NAMES), \"r\") as f:\n",
     "    for i, line in enumerate(f.readlines()):\n",
     "        gt_ids.append(list(map(int, line.split())))\n",
     "gt_val_function_names = bpe.decode(gt_ids)"
```
