Skip to content

Commit f2521a7

Browse files
committed
Tweak the codesearchnet opennmt helper a bit
Signed-off-by: m09 <[email protected]>
1 parent c1af389 commit f2521a7

File tree

1 file changed

+62
-47
lines changed

1 file changed

+62
-47
lines changed

notebooks/codesearchnet-opennmt.py

Lines changed: 62 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,59 +1,66 @@
11
from argparse import ArgumentParser
2-
import os
3-
from pathlib import Path
4-
import time
5-
from typing import Dict, List, Tuple
62
import logging
3+
from pathlib import Path
4+
from time import time
5+
from typing import List
76

87
import pandas as pd
98

9+
1010
logging.basicConfig(level=logging.INFO)
1111

12+
1213
class CodeSearchNetRAM(object):
1314
"""Stores one split of CodeSearchNet data in memory
1415
1516
Usage example:
1617
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
1718
unzip java.zip
18-
python notebooks/codesearchnet-opennmt.py --data_dir='java/final/jsonl/valid' --newline='\\n'
19+
python notebooks/codesearchnet-opennmt.py \
20+
--data_dir='java/final/jsonl/valid' \
21+
--newline='\\n'
1922
"""
2023

2124
def __init__(self, split_path: Path, newline_repl: str):
2225
super().__init__()
2326
self.pd = pd
2427

25-
files = sorted(split_path.glob('**/*.gz'))
26-
logging.info(f'Total number of files: {len(files):,}')
27-
assert len(files) != 0, "could not find files under %s" % split_path
28+
files = sorted(split_path.glob("**/*.gz"))
29+
logging.info(f"Total number of files: {len(files):,}")
30+
assert files, "could not find files under %s" % split_path
2831

29-
columns_list = ['code', 'func_name']
32+
columns_list = ["code", "func_name"]
3033

31-
start = time.time()
34+
start = time()
3235
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
33-
logging.info(f"Loading took {time.time() - start:.2f}s for {len(self)} rows")
36+
logging.info(f"Loading took {time() - start:.2f}s for {len(self)} rows")
3437

3538
@staticmethod
36-
def _jsonl_list_to_dataframe(file_list: List[Path],
37-
columns: List[str]) -> pd.DataFrame:
39+
def _jsonl_list_to_dataframe(
40+
file_list: List[Path], columns: List[str]
41+
) -> pd.DataFrame:
3842
"""Load a list of jsonl.gz files into a pandas DataFrame."""
39-
return pd.concat([pd.read_json(f,
40-
orient='records',
41-
compression='gzip',
42-
lines=True)[columns]
43-
for f in file_list], sort=False)
44-
43+
return pd.concat(
44+
[
45+
pd.read_json(f, orient="records", compression="gzip", lines=True)[
46+
columns
47+
]
48+
for f in file_list
49+
],
50+
sort=False,
51+
)
4552

4653
def __getitem__(self, idx: int):
4754
row = self.pd.iloc[idx]
4855

4956
# drop class name
5057
fn_name = row["func_name"]
51-
fn_name = fn_name.split('.')[-1] # drop the class name
58+
fn_name = fn_name.split(".")[-1] # drop the class name
5259
# fn_name_enc = self.enc.encode(fn_name)
5360

5461
# drop fn signature
5562
code = row["code"]
56-
fn_body = code[code.find("{") + 1:code.find("}")].lstrip().rstrip()
63+
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
5764
fn_body = fn_body.replace("\n", "\\n")
5865
# fn_body_enc = self.enc.encode(fn_body)
5966
return (fn_name, fn_body)
@@ -63,36 +70,44 @@ def __len__(self):
6370

6471

6572
def main(args):
66-
test_set = CodeSearchNetRAM(Path(args.data_dir), args.newline)
67-
with open(args.src_file, mode="w", encoding="utf8") as s, open(args.tgt_file, mode="w", encoding="utf8") as t:
68-
for fn_name, fn_body in test_set:
69-
print(f"'{fn_name[:40]:40}' - '{fn_body[:40]:40}'")
70-
print(fn_name, file=s)
71-
print(fn_body, file=t)
72-
73+
dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
74+
split_name = Path(args.data_dir).name
75+
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
76+
args.tgt_file % split_name, mode="w", encoding="utf8"
77+
) as t:
78+
for fn_name, fn_body in dataset:
79+
if not fn_name or not fn_body:
80+
continue
81+
print(fn_body, file=s)
82+
print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
7383

7484

7585
if __name__ == "__main__":
7686
parser = ArgumentParser(add_help=False)
77-
parser.add_argument('--data_dir',
78-
type=str,
79-
default="java/final/jsonl/test",
80-
help="Path to the unziped input data (CodeSearchNet)")
81-
82-
parser.add_argument('--newline',
83-
type=str,
84-
default="\\n",
85-
help="Replace newline with this")
86-
87-
parser.add_argument('--src_file',
88-
type=str,
89-
default="src-trian.txt",
90-
help="File with function bodies")
91-
92-
parser.add_argument('--tgt_file',
93-
type=str,
94-
default="tgt-trian.txt",
95-
help="File with function texts")
87+
parser.add_argument(
88+
"--data_dir",
89+
type=str,
90+
default="java/final/jsonl/test",
91+
help="Path to the unziped input data (CodeSearchNet)",
92+
)
93+
94+
parser.add_argument(
95+
"--newline", type=str, default="\\n", help="Replace newline with this"
96+
)
97+
98+
parser.add_argument(
99+
"--word-level-targets",
100+
action="store_true",
101+
help="Use word level targets instead of char level ones",
102+
)
103+
104+
parser.add_argument(
105+
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
106+
)
107+
108+
parser.add_argument(
109+
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
110+
)
96111

97112
args = parser.parse_args()
98113
main(args)

0 commit comments

Comments
 (0)