Skip to content

Commit dcd7b22

Browse files
committed
CLI: drop PT dependecny, add stdout output flag
Signed-off-by: Alexander Bezzubov <[email protected]>
1 parent f4bca02 commit dcd7b22

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ env.sh
22
.mypy_cache
33
notebooks/output
44
notebooks/repos
5+
.venv/
6+
.vscode/

notebooks/codesearchnet-opennmt.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
1+
"""
2+
CLI tool for converting CodeSearchNet dataset to OpenNMT format for
3+
function name suggestion task.
4+
5+
Usage example:
6+
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
7+
unzip java.zip
8+
python notebooks/codesearchnet-opennmt.py \
9+
--data_dir='java/final/jsonl/valid' \
10+
--newline='\\n'
11+
"""
112
from argparse import ArgumentParser, Namespace
213
import logging
314
from pathlib import Path
415
from time import time
516
from typing import List, Tuple
617

718
import pandas as pd
8-
from torch.utils.data import Dataset
919

1020

1121
logging.basicConfig(level=logging.INFO)
1222

1323

14-
class CodeSearchNetRAM(Dataset):
15-
"""Stores one split of CodeSearchNet data in memory
16-
17-
Usage example:
18-
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
19-
unzip java.zip
20-
python notebooks/codesearchnet-opennmt.py \
21-
--data_dir='java/final/jsonl/valid' \
22-
--newline='\\n'
23-
"""
24+
class CodeSearchNetRAM(object):
25+
"""Stores one split of CodeSearchNet data in memory"""
2426

2527
def __init__(self, split_path: Path, newline_repl: str):
2628
super().__init__()
@@ -79,8 +81,12 @@ def main(args: Namespace) -> None:
7981
for fn_name, fn_body in dataset:
8082
if not fn_name or not fn_body:
8183
continue
82-
print(fn_body, file=s)
83-
print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
84+
tgt = fn_name if args.word_level_targets else " ".join(fn_name)
85+
if args.print:
86+
print(f"'{fn_name[:40]:40}' - '{tgt[:40]:40}'")
87+
else:
88+
print(fn_body, file=s)
89+
print(tgt, file=t)
8490

8591

8692
if __name__ == "__main__":
@@ -110,5 +116,9 @@ def main(args: Namespace) -> None:
110116
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
111117
)
112118

119+
parser.add_argument(
120+
"--print", action="store_true", help="Print data preview to the STDOUT"
121+
)
122+
113123
args = parser.parse_args()
114124
main(args)

0 commit comments

Comments
 (0)