Skip to content

Commit 446323c

Browse files
author
egor
committed
Checked seq2seq part of pipeline locally
Signed-off-by: egor <[email protected]>
1 parent 14a07a9 commit 446323c

File tree

1 file changed

+52
-45
lines changed

1 file changed

+52
-45
lines changed

notebooks/Name suggestion.ipynb

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
},
2828
{
2929
"cell_type": "code",
30-
"execution_count": 8,
30+
"execution_count": null,
3131
"metadata": {},
3232
"outputs": [],
3333
"source": [
@@ -61,6 +61,8 @@
6161
" MODEL_CONFIG = [\"model\", \"config.yml\"] \n",
6262
" MODEL_PRETRAINED = [\"pretrained\", \"ckpt-25000\"]\n",
6363
" ENC_VAL_NAMES_PRED = [\"val.bpe.pred.tgt\"]\n",
64+
" SAMPLE_ENC_VAL_BODIES = [\"sample_val.bpe.src\"]\n",
65+
" SAMPLE_ENC_VAL_NAMES = [\"sample_val.bpe.tgt\"]\n",
6466
"\n",
6567
" \n",
6668
"class Dirs(DirsABC, Enum):\n",
@@ -278,13 +280,6 @@
278280
" - Y label, a name of the function.\n"
279281
]
280282
},
281-
{
282-
"cell_type": "code",
283-
"execution_count": null,
284-
"metadata": {},
285-
"outputs": [],
286-
"source": []
287-
},
288283
{
289284
"cell_type": "code",
290285
"execution_count": null,
@@ -383,7 +378,9 @@
383378
"execution_count": null,
384379
"metadata": {},
385380
"outputs": [],
386-
"source": []
381+
"source": [
382+
"import pandas as pd"
383+
]
387384
},
388385
{
389386
"cell_type": "code",
@@ -491,26 +488,6 @@
491488
"Get vector representation using the vocabulary from the trained BPE tokenizer, in the format compatible with [OpenNMT](http://opennmt.net/OpenNMT-tf/data.html#vocabulary)."
492489
]
493490
},
494-
{
495-
"cell_type": "markdown",
496-
"metadata": {},
497-
"source": [
498-
"## Save the vocabulary on disk\n",
499-
"\n",
500-
"We'll need only one file, as the same vocabulary will be used for both, identifiers and function names. Different vocabularies can be used without any change to the model e.g the sub-words (BPE) only for identifers and char for the function names."
501-
]
502-
},
503-
{
504-
"cell_type": "code",
505-
"execution_count": null,
506-
"metadata": {},
507-
"outputs": [],
508-
"source": [
509-
"with open(run.path(Files.VOCABULARY), \"w\") as vocab_fd:\n",
510-
" for i in range(vocab_size + 5):\n",
511-
" vocab_fd.write(str(i) + \"\\n\")"
512-
]
513-
},
514491
{
515492
"cell_type": "markdown",
516493
"metadata": {},
@@ -537,7 +514,7 @@
537514
"bpe_encode(run.path(Files.TRAIN_BODIES), run.path(Files.ENC_TRAIN_BODIES))\n",
538515
"bpe_encode(run.path(Files.TRAIN_NAMES), run.path(Files.ENC_TRAIN_NAMES))\n",
539516
"bpe_encode(run.path(Files.VAL_BODIES), run.path(Files.ENC_VAL_BODIES))\n",
540-
"bpe_encode(run.path(Files.VAL_BODIES), run.path(Files.ENC_VAL_NAMES))"
517+
"bpe_encode(run.path(Files.VAL_NAMES), run.path(Files.ENC_VAL_NAMES))"
541518
]
542519
},
543520
{
@@ -557,7 +534,7 @@
557534
"metadata": {},
558535
"outputs": [],
559536
"source": [
560-
"# TODO: src_vocab_loc, tgt_vocab_loc\n",
537+
"import os\n",
561538
"\n",
562539
"# approach requires to provide vocabularies\n",
563540
"# so launch these commands\n",
@@ -567,6 +544,7 @@
567544
" input_text)\n",
568545
"\n",
569546
"if not os.path.exists(run.path(Files.SRC_VOCABULARY)):\n",
547+
" print(\"Generating vocabularies\")\n",
570548
" # in case of pretrained model we reuse vocabulary\n",
571549
" cmd = generate_build_vocab(save_vocab_loc=run.path(Files.SRC_VOCABULARY),\n",
572550
" input_text=run.path(Files.ENC_TRAIN_BODIES),\n",
@@ -585,7 +563,6 @@
585563
"metadata": {},
586564
"outputs": [],
587565
"source": [
588-
"\n",
589566
"model_dir = run.path(Dirs.MODEL_RUN)\n",
590567
"\n",
591568
"# prepare config file for model\n",
@@ -614,7 +591,7 @@
614591
"train:\n",
615592
" # (optional when batch_type=tokens) If not set, the training will search the largest\n",
616593
" # possible batch size.\n",
617-
" batch_size: 256\n",
594+
" batch_size: 32\n",
618595
"\n",
619596
"eval:\n",
620597
" # (optional) The batch size to use (default: 32).\n",
@@ -647,6 +624,15 @@
647624
" f.write(yaml_content)"
648625
]
649626
},
627+
{
628+
"cell_type": "markdown",
629+
"metadata": {},
630+
"source": [
631+
"### small GPU vs CPU comparison:\n",
632+
"* CPU with 4 cores: `source words/s = 104, target words/s = 34`\n",
633+
"* 1080 GPU: `source words/s = 6959, target words/s = 1434`"
634+
]
635+
},
650636
{
651637
"cell_type": "code",
652638
"execution_count": null,
@@ -697,17 +683,28 @@
697683
"metadata": {},
698684
"outputs": [],
699685
"source": [
700-
"bpe_val_predictions = \"val.pred.tgt\"\n",
701-
"\n",
686+
"# limit number of samples to process\n",
687+
"!head -50 {run.path(Files.ENC_VAL_BODIES)} > {run.path(Files.SAMPLE_ENC_VAL_BODIES)}\n",
688+
"!head -50 {run.path(Files.ENC_VAL_NAMES)} > {run.path(Files.SAMPLE_ENC_VAL_NAMES)}"
689+
]
690+
},
691+
{
692+
"cell_type": "code",
693+
"execution_count": null,
694+
"metadata": {},
695+
"outputs": [],
696+
"source": [
702697
"predict_cmd = \"\"\"onmt-main \\\n",
703-
"--config %s --auto_config \\\n",
698+
"--config %s --auto_config --model_type LuongAttention \\\n",
699+
"--checkpoint_path %s \\\n",
704700
"infer \\\n",
705701
"--features_file %s \\\n",
706-
"--predictions_file %s \\\n",
707-
"--checkpoint_path %s\"\"\" % (config_yaml, \n",
708-
" run.path(Files.ENC_VAL_BODIES), \n",
702+
"--predictions_file %s\n",
703+
"\"\"\" % (config_yaml, pretrained_model,\n",
704+
" run.path(Files.SAMPLE_ENC_VAL_BODIES), \n",
709705
" run.path(Files.ENC_VAL_NAMES_PRED),\n",
710-
" pretrained_model)"
706+
" )\n",
707+
"! {predict_cmd}"
711708
]
712709
},
713710
{
@@ -731,7 +728,7 @@
731728
"outputs": [],
732729
"source": [
733730
"gt_ids = []\n",
734-
"with open(run.path(Files.ENC_VAL_NAMES), \"r\") as f:\n",
731+
"with open(run.path(Files.SAMPLE_ENC_VAL_NAMES), \"r\") as f:\n",
735732
" for i, line in enumerate(f.readlines()):\n",
736733
" gt_ids.append(list(map(int, line.split())))\n",
737734
"gt_val_function_names = bpe.decode(gt_ids)"
@@ -750,10 +747,20 @@
750747
"metadata": {},
751748
"outputs": [],
752749
"source": [
753-
"for i, (a, b) in enumerate(zip(gt_function_names, predicted_function_names)):\n",
754-
" if i == 100:\n",
755-
" break\n",
756-
" print(\"%s | %s\" % (a, b)) "
750+
"for gt_name, pred_name in zip(gt_val_function_names, pred_val_function_names):\n",
751+
" print(\"%s | %s\" % (gt_name, pred_name)) "
752+
]
753+
},
754+
{
755+
"cell_type": "markdown",
756+
"metadata": {},
757+
"source": [
758+
"# Results may be not so good because a lot of context information is missing\n",
759+
"* roles of identifiers\n",
760+
"* structural information was removed\n",
761+
"* arguments to function\n",
762+
"\n",
763+
"and so on. There are a bunch of improvements possible, like [code2vec](https://github.com/tech-srl/code2vec) and many more."
757764
]
758765
}
759766
],

0 commit comments

Comments
 (0)