
Commit 6cfc883

cli: add --token-level-targets w/ dpu_utils.codeutils.identifiersplitting
Signed-off-by: Alexander Bezzubov <[email protected]>
1 parent aa65cc4 commit 6cfc883

File tree: 1 file changed

notebooks/codesearchnet-opennmt.py (76 additions, 1 deletion)
@@ -86,6 +86,69 @@ def __len__(self) -> int:
         return len(self.pd)


+# id splitting from
+# https://github.com/microsoft/dpu-utils/blob/dfc44e354b57a4e2617828bdf4d76c1c4d81c021/python/dpu_utils/codeutils/identifiersplitting.py
+from functools import lru_cache
+from typing import List
+
+def split_camelcase(camel_case_identifier: str) -> List[str]:
+    """
+    Split camelCase identifiers.
+    """
+    if not len(camel_case_identifier):
+        return []
+
+    # split into words based on adjacent cases being the same
+    result = []
+    current = str(camel_case_identifier[0])
+    prev_upper = camel_case_identifier[0].isupper()
+    prev_digit = camel_case_identifier[0].isdigit()
+    prev_special = not camel_case_identifier[0].isalnum()
+    for c in camel_case_identifier[1:]:
+        upper = c.isupper()
+        digit = c.isdigit()
+        special = not c.isalnum()
+        new_upper_word = upper and not prev_upper
+        new_digit_word = digit and not prev_digit
+        new_special_word = special and not prev_special
+        if new_digit_word or new_upper_word or new_special_word:
+            result.append(current)
+            current = c
+        elif not upper and prev_upper and len(current) > 1:
+            result.append(current[:-1])
+            current = current[-1] + c
+        elif not digit and prev_digit:
+            result.append(current)
+            current = c
+        elif not special and prev_special:
+            result.append(current)
+            current = c
+        else:
+            current += c
+        prev_digit = digit
+        prev_upper = upper
+        prev_special = special
+    result.append(current)
+    return result
+
+
+@lru_cache(maxsize=5000)
+def split_identifier_into_parts(identifier: str) -> List[str]:
+    """
+    Split a single identifier into parts on snake_case and camelCase
+    """
+    snake_case = identifier.split("_")
+
+    identifier_parts = []  # type: List[str]
+    for i in range(len(snake_case)):
+        part = snake_case[i]
+        if len(part) > 0:
+            identifier_parts.extend(s.lower() for s in split_camelcase(part))
+    if len(identifier_parts) == 0:
+        return [identifier]
+    return identifier_parts
+
+
 def main(args: Namespace) -> None:
     dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
     split_name = Path(args.data_dir).name
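
Note: a quick, illustrative sketch of what the vendored splitter above produces. The identifiers are made-up examples, and the snippet assumes split_identifier_into_parts from this diff is in scope.

# Illustrative usage of the vendored splitter (made-up identifiers);
# assumes split_identifier_into_parts from the code above is in scope.
print(split_identifier_into_parts("readFileToString"))  # ['read', 'file', 'to', 'string']
print(split_identifier_into_parts("snake_case_name"))   # ['snake', 'case', 'name']
print(split_identifier_into_parts("HTTPServer"))        # ['http', 'server']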
@@ -96,7 +159,13 @@ def main(args: Namespace) -> None:
         if not fn_name or not fn_body:
             continue
         src = " ".join(fn_body_tokens) if args.token_level_sources else fn_body
-        tgt = fn_name if args.word_level_targets else " ".join(fn_name)
+
+        if args.word_level_targets:
+            tgt = fn_name
+        elif args.token_level_targets:
+            tgt = " ".join(split_identifier_into_parts(fn_name))
+        else:
+            tgt = " ".join(fn_name)
         if args.print:
             print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
         else:
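
Note: to see how the three target modes above differ, here is an illustrative sketch with a made-up fn_name (again assuming the vendored splitter from this diff is in scope).

# Illustrative only: the three target modes for a hypothetical fn_name.
fn_name = "readFile"
word_level = fn_name                                           # "readFile"
token_level = " ".join(split_identifier_into_parts(fn_name))   # "read file"
char_level = " ".join(fn_name)                                 # "r e a d F i l e"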
@@ -121,6 +190,12 @@ def main(args: Namespace) -> None:
         help="Use language-specific token sources instead of word level ones",
     )

+    parser.add_argument(
+        "--token-level-targets",
+        action="store_true",
+        help="Use camelCase and snake_case split token targets instead of word or char level ones",
+    )
+
     parser.add_argument(
         "--word-level-targets",
         action="store_true",
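
Note: a minimal, standalone sketch (not the script's actual parser setup) of how the new flag lands on the parsed Namespace; argparse maps the dashes to the args.token_level_targets attribute used in the diff above.

# Standalone sketch only; the real script defines more arguments than this.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--token-level-targets", action="store_true")
args = parser.parse_args(["--token-level-targets"])
print(args.token_level_targets)  # True -> targets get identifier-split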
