Skip to content

Commit c1af389

Browse files
committed
CLI helper for CodeSearchNet to OpenNMT
1 parent f47ed95 commit c1af389

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

notebooks/codesearchnet-opennmt.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from argparse import ArgumentParser
2+
import os
3+
from pathlib import Path
4+
import time
5+
from typing import Dict, List, Tuple
6+
import logging
7+
8+
import pandas as pd
9+
10+
# Set the root logger to INFO so the logging.info() progress messages are emitted.
logging.basicConfig(level=logging.INFO)
11+
12+
class CodeSearchNetRAM(object):
    """Stores one split of CodeSearchNet data in memory

    Every ``*.gz`` file found (recursively) under the split directory is
    loaded into one pandas DataFrame holding the ``code`` and ``func_name``
    columns. Indexing an instance returns a ``(function name, function
    body)`` pair of strings.

    Usage example:
        wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
        unzip java.zip
        python notebooks/codesearchnet-opennmt.py --data_dir='java/final/jsonl/valid' --newline='\\n'
    """

    def __init__(self, split_path: Path, newline_repl: str):
        """Load every ``*.gz`` file under *split_path* into memory.

        Args:
            split_path: directory searched recursively for ``*.gz`` files.
            newline_repl: text substituted for every newline character in
                the function bodies returned by ``__getitem__``.

        Raises:
            AssertionError: if no ``*.gz`` file is found under *split_path*.
        """
        super().__init__()
        # Remember the replacement so __getitem__ honours it instead of a
        # hard-coded value.
        self.newline_repl = newline_repl

        files = sorted(split_path.glob('**/*.gz'))
        logging.info(f'Total number of files: {len(files):,}')
        assert len(files) != 0, "could not find files under %s" % split_path

        columns_list = ['code', 'func_name']

        start = time.time()
        self.pd = self._jsonl_list_to_dataframe(files, columns_list)
        logging.info(f"Loading took {time.time() - start:.2f}s for {len(self)} rows")

    @staticmethod
    def _jsonl_list_to_dataframe(file_list: List[Path],
                                 columns: List[str]) -> pd.DataFrame:
        """Load a list of jsonl.gz files into a pandas DataFrame."""
        return pd.concat([pd.read_json(f,
                                       orient='records',
                                       compression='gzip',
                                       lines=True)[columns]
                          for f in file_list], sort=False)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        """Return the ``(function name, function body)`` pair at row *idx*.

        The name is the last dot-separated component of ``func_name``. The
        body is the text between the first ``{`` and the last ``}`` of
        ``code``, stripped of surrounding whitespace, with every newline
        replaced by ``self.newline_repl``.

        Raises:
            IndexError: if *idx* is out of range (raised by ``DataFrame.iloc``);
                this is also what terminates ``for ... in`` iteration.
        """
        row = self.pd.iloc[idx]

        # drop the class name
        fn_name = row["func_name"].split('.')[-1]

        # drop fn signature
        code = row["code"]
        body_start = code.find("{") + 1
        # Use the LAST closing brace: the first one usually closes a nested
        # block (if/for/try ...) and would truncate the function body.
        body_end = code.rfind("}")
        fn_body = code[body_start:body_end].strip()
        fn_body = fn_body.replace("\n", self.newline_repl)
        return (fn_name, fn_body)

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.pd)
63+
64+
65+
def main(args):
    """Export one CodeSearchNet split as two line-aligned text files.

    Loads the split found under ``args.data_dir`` and, for every function,
    writes its name as one line of ``args.src_file`` and its body as one
    line of ``args.tgt_file``. A 40-character preview of each pair is also
    printed to stdout.

    NOTE(review): the CLI help describes ``--src_file`` as holding function
    bodies, which is the opposite of what is written here - confirm which
    direction is intended.
    """
    dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
    src_out = open(args.src_file, mode="w", encoding="utf8")
    with src_out:
        tgt_out = open(args.tgt_file, mode="w", encoding="utf8")
        with tgt_out:
            for pair in dataset:
                name, body = pair
                # Fixed-width preview so progress is readable on the console.
                print(f"'{name[:40]:40}' - '{body[:40]:40}'")
                print(name, file=src_out)
                print(body, file=tgt_out)
72+
73+
74+
75+
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the export.
    # The automatic -h/--help option is left enabled (the argparse default)
    # so the help texts below can actually be displayed.
    parser = ArgumentParser(
        description="CLI helper for CodeSearchNet to OpenNMT")
    parser.add_argument('--data_dir',
                        type=str,
                        default="java/final/jsonl/test",
                        help="Path to the unzipped input data (CodeSearchNet)")

    parser.add_argument('--newline',
                        type=str,
                        default="\\n",
                        help="Replace newline with this")

    # NOTE(review): "trian" in the two defaults below looks like a typo for
    # "train"; the defaults are kept as-is so existing output paths do not
    # change.
    # The help texts describe what main() writes: names go to --src_file,
    # bodies go to --tgt_file.
    parser.add_argument('--src_file',
                        type=str,
                        default="src-trian.txt",
                        help="Output file for function names, one per line")

    parser.add_argument('--tgt_file',
                        type=str,
                        default="tgt-trian.txt",
                        help="Output file for function bodies, one per line")

    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)