Skip to content

Commit 394388a

Browse files
authored
Merge pull request #9 from mloncode/add-codesearchnet-preproc
CLI helper for CodeSearchNet to OpenNMT
2 parents f47ed95 + f2521a7 commit 394388a

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

notebooks/codesearchnet-opennmt.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from argparse import ArgumentParser
2+
import logging
3+
from pathlib import Path
4+
from time import time
5+
from typing import List
6+
7+
import pandas as pd
8+
9+
10+
logging.basicConfig(level=logging.INFO)
11+
12+
13+
class CodeSearchNetRAM(object):
14+
"""Stores one split of CodeSearchNet data in memory
15+
16+
Usage example:
17+
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
18+
unzip java.zip
19+
python notebooks/codesearchnet-opennmt.py \
20+
--data_dir='java/final/jsonl/valid' \
21+
--newline='\\n'
22+
"""
23+
24+
def __init__(self, split_path: Path, newline_repl: str):
25+
super().__init__()
26+
self.pd = pd
27+
28+
files = sorted(split_path.glob("**/*.gz"))
29+
logging.info(f"Total number of files: {len(files):,}")
30+
assert files, "could not find files under %s" % split_path
31+
32+
columns_list = ["code", "func_name"]
33+
34+
start = time()
35+
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
36+
logging.info(f"Loading took {time() - start:.2f}s for {len(self)} rows")
37+
38+
@staticmethod
39+
def _jsonl_list_to_dataframe(
40+
file_list: List[Path], columns: List[str]
41+
) -> pd.DataFrame:
42+
"""Load a list of jsonl.gz files into a pandas DataFrame."""
43+
return pd.concat(
44+
[
45+
pd.read_json(f, orient="records", compression="gzip", lines=True)[
46+
columns
47+
]
48+
for f in file_list
49+
],
50+
sort=False,
51+
)
52+
53+
def __getitem__(self, idx: int):
54+
row = self.pd.iloc[idx]
55+
56+
# drop class name
57+
fn_name = row["func_name"]
58+
fn_name = fn_name.split(".")[-1] # drop the class name
59+
# fn_name_enc = self.enc.encode(fn_name)
60+
61+
# drop fn signature
62+
code = row["code"]
63+
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
64+
fn_body = fn_body.replace("\n", "\\n")
65+
# fn_body_enc = self.enc.encode(fn_body)
66+
return (fn_name, fn_body)
67+
68+
def __len__(self):
69+
return len(self.pd)
70+
71+
72+
def main(args):
73+
dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
74+
split_name = Path(args.data_dir).name
75+
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
76+
args.tgt_file % split_name, mode="w", encoding="utf8"
77+
) as t:
78+
for fn_name, fn_body in dataset:
79+
if not fn_name or not fn_body:
80+
continue
81+
print(fn_body, file=s)
82+
print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
83+
84+
85+
if __name__ == "__main__":
86+
parser = ArgumentParser(add_help=False)
87+
parser.add_argument(
88+
"--data_dir",
89+
type=str,
90+
default="java/final/jsonl/test",
91+
help="Path to the unziped input data (CodeSearchNet)",
92+
)
93+
94+
parser.add_argument(
95+
"--newline", type=str, default="\\n", help="Replace newline with this"
96+
)
97+
98+
parser.add_argument(
99+
"--word-level-targets",
100+
action="store_true",
101+
help="Use word level targets instead of char level ones",
102+
)
103+
104+
parser.add_argument(
105+
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
106+
)
107+
108+
parser.add_argument(
109+
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
110+
)
111+
112+
args = parser.parse_args()
113+
main(args)

0 commit comments

Comments
 (0)