@@ -27,12 +27,13 @@ class CodeSearchNetRAM(object):
     def __init__(self, split_path: Path, newline_repl: str):
         super().__init__()
         self.pd = pd
+        self.newline_repl = newline_repl
 
         files = sorted(split_path.glob("**/*.gz"))
         logging.info(f"Total number of files: {len(files):,}")
         assert files, "could not find files under %s" % split_path
 
-        columns_list = ["code", "func_name"]
+        columns_list = ["code", "func_name", "code_tokens"]
 
         start = time()
         self.pd = self._jsonl_list_to_dataframe(files, columns_list)
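
For context on the new `code_tokens` column: each CodeSearchNet `*.jsonl.gz` shard holds one JSON object per line, and every record carries `code`, `func_name`, and `code_tokens` fields among others. A minimal sketch of what `_jsonl_list_to_dataframe` could look like under that assumption (its actual body is outside this hunk):

    import pandas as pd

    def _jsonl_list_to_dataframe(files, columns):
        # Read each gzipped JSONL shard, keep only the requested
        # columns, and stack all shards into a single DataFrame.
        return pd.concat(
            [
                pd.read_json(f, orient="records", compression="gzip", lines=True)[columns]
                for f in files
            ],
            sort=False,
        )
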
@@ -63,10 +64,21 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:
 
         # drop fn signature
         code = row["code"]
-        fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
-        fn_body = fn_body.replace("\n", "\\n")
+        fn_body = (
+            code[
+                code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
+            ]
+            .lstrip()
+            .rstrip()
+        )
+        fn_body = fn_body.replace("\n", self.newline_repl)
         # fn_body_enc = self.enc.encode(fn_body)
-        return (fn_name, fn_body)
+
+        tokens = row["code_tokens"]
+        body_tokens = tokens[tokens.index(fn_name) + 2 :]
+        fn_body_tokens = body_tokens[body_tokens.index("{") + 1 : len(body_tokens) - 1]
+
+        return (fn_name, fn_body, fn_body_tokens)
 
     def __len__(self) -> int:
         return len(self.pd)
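
The token-level slicing in `__getitem__` leans on the CodeSearchNet token layout: `fn_name` is followed by `(`, so `index(fn_name) + 2` skips the name and the opening parenthesis, `index("{") + 1` steps past the body's opening brace, and the final token (the closing `}`) is dropped. A worked example with a made-up token stream:

    # Hypothetical tokens for: int add(int a, int b) { return a + b; }
    tokens = ["int", "add", "(", "int", "a", ",", "int", "b", ")",
              "{", "return", "a", "+", "b", ";", "}"]
    fn_name = "add"

    body_tokens = tokens[tokens.index(fn_name) + 2:]  # from the first parameter on
    fn_body_tokens = body_tokens[body_tokens.index("{") + 1 : len(body_tokens) - 1]
    print(fn_body_tokens)  # ['return', 'a', '+', 'b', ';']

Note that `tokens.index(fn_name)` raises `ValueError` when the name is not a single token (e.g. a qualified name), so callers may want to guard for that.
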
@@ -78,14 +90,15 @@ def main(args: Namespace) -> None:
         with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
             args.tgt_file % split_name, mode="w", encoding="utf8"
         ) as t:
-            for fn_name, fn_body in dataset:
+            for fn_name, fn_body, fn_body_tokens in dataset:
                 if not fn_name or not fn_body:
                     continue
+                src = " ".join(fn_body_tokens) if args.token_level_sources else fn_body
                 tgt = fn_name if args.word_level_targets else " ".join(fn_name)
                 if args.print:
-                    print(f"'{fn_name[:40]:40}' - '{tgt[:40]:40}'")
+                    print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
                 else:
-                    print(fn_body, file=s)
+                    print(src, file=s)
                     print(tgt, file=t)
 
 
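
To make the two source modes concrete, here is roughly what the parallel files would contain for a hypothetical `getName` function whose body tokens are `["return", "name", ";"]`, using the new `.token` default filenames:

    src-train.token:  return name ;     (token-level source)
    tgt-train.token:  g e t N a m e     (char-level target, the default)

With `--word-level-targets` the target line would instead be `getName`, and without `--token-level-sources` the source falls back to the raw `fn_body` string with newlines replaced.
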
@@ -102,18 +115,27 @@ def main(args: Namespace) -> None:
102115 "--newline" , type = str , default = "\\ n" , help = "Replace newline with this"
103116 )
104117
118+ parser .add_argument (
119+ "--token-level-sources" ,
120+ action = "store_true" ,
121+ help = "Use language-specific token sources instead of word level ones" ,
122+ )
123+
105124 parser .add_argument (
106125 "--word-level-targets" ,
107126 action = "store_true" ,
108127 help = "Use word level targets instead of char level ones" ,
109128 )
110129
111130 parser .add_argument (
112- "--src_file" , type = str , default = "src-%s.txt" , help = "File with function bodies" ,
131+ "--src_file" ,
132+ type = str ,
133+ default = "src-%s.token" ,
134+ help = "File with function bodies" ,
113135 )
114136
115137 parser .add_argument (
116- "--tgt_file" , type = str , default = "tgt-%s.txt " , help = "File with function texts"
138+ "--tgt_file" , type = str , default = "tgt-%s.token " , help = "File with function texts"
117139 )
118140
119141 parser .add_argument (
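
A hypothetical invocation wiring the new options together (the script filename is a placeholder; it is not visible in this diff):

    python preprocess.py \
        --token-level-sources \
        --word-level-targets \
        --src_file src-%s.token \
        --tgt_file tgt-%s.token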