Skip to content

Commit 16063f1

Browse files
committed
CLI: add --token-level-sources option
Signed-off-by: Alexander Bezzubov <[email protected]>
1 parent dcd7b22 commit 16063f1

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

notebooks/codesearchnet-opennmt.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ class CodeSearchNetRAM(object):
2727
def __init__(self, split_path: Path, newline_repl: str):
2828
super().__init__()
2929
self.pd = pd
30+
self.newline_repl = newline_repl
3031

3132
files = sorted(split_path.glob("**/*.gz"))
3233
logging.info(f"Total number of files: {len(files):,}")
3334
assert files, "could not find files under %s" % split_path
3435

35-
columns_list = ["code", "func_name"]
36+
columns_list = ["code", "func_name", "code_tokens"]
3637

3738
start = time()
3839
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
@@ -63,10 +64,21 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:
6364

6465
# drop fn signature
6566
code = row["code"]
66-
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
67-
fn_body = fn_body.replace("\n", "\\n")
67+
fn_body = (
68+
code[
69+
code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
70+
]
71+
.lstrip()
72+
.rstrip()
73+
)
74+
fn_body = fn_body.replace("\n", self.newline_repl)
6875
# fn_body_enc = self.enc.encode(fn_body)
69-
return (fn_name, fn_body)
76+
77+
tokens = row["code_tokens"]
78+
body_tokens = tokens[tokens.index(fn_name) + 2 :]
79+
fn_body_tokens = body_tokens[body_tokens.index("{") + 1 : len(body_tokens) - 1]
80+
81+
return (fn_name, fn_body, fn_body_tokens)
7082

7183
def __len__(self) -> int:
7284
return len(self.pd)
@@ -78,14 +90,15 @@ def main(args: Namespace) -> None:
7890
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
7991
args.tgt_file % split_name, mode="w", encoding="utf8"
8092
) as t:
81-
for fn_name, fn_body in dataset:
93+
for fn_name, fn_body, fn_body_tokens in dataset:
8294
if not fn_name or not fn_body:
8395
continue
96+
src = " ".join(fn_body_tokens) if args.token_level_sources else fn_body
8497
tgt = fn_name if args.word_level_targets else " ".join(fn_name)
8598
if args.print:
86-
print(f"'{fn_name[:40]:40}' - '{tgt[:40]:40}'")
99+
print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
87100
else:
88-
print(fn_body, file=s)
101+
print(src, file=s)
89102
print(tgt, file=t)
90103

91104

@@ -102,18 +115,27 @@ def main(args: Namespace) -> None:
102115
"--newline", type=str, default="\\n", help="Replace newline with this"
103116
)
104117

118+
parser.add_argument(
119+
"--token-level-sources",
120+
action="store_true",
121+
help="Use language-specific token sources instead of word level ones",
122+
)
123+
105124
parser.add_argument(
106125
"--word-level-targets",
107126
action="store_true",
108127
help="Use word level targets instead of char level ones",
109128
)
110129

111130
parser.add_argument(
112-
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
131+
"--src_file",
132+
type=str,
133+
default="src-%s.token",
134+
help="File with function bodies",
113135
)
114136

115137
parser.add_argument(
116-
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
138+
"--tgt_file", type=str, default="tgt-%s.token", help="File with function texts"
117139
)
118140

119141
parser.add_argument(

0 commit comments

Comments (0)