11from argparse import ArgumentParser
2- import os
3- from pathlib import Path
4- import time
5- from typing import Dict , List , Tuple
62import logging
3+ from pathlib import Path
4+ from time import time
5+ from typing import List
76
87import pandas as pd
98
9+
1010logging .basicConfig (level = logging .INFO )
1111
12+
1213class CodeSearchNetRAM (object ):
1314 """Stores one split of CodeSearchNet data in memory
1415
1516 Usage example:
1617 wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
1718 unzip java.zip
18- python notebooks/codesearchnet-opennmt.py --data_dir='java/final/jsonl/valid' --newline='\\ n'
19+ python notebooks/codesearchnet-opennmt.py \
20+ --data_dir='java/final/jsonl/valid' \
21+ --newline='\\ n'
1922 """
2023
2124 def __init__ (self , split_path : Path , newline_repl : str ):
2225 super ().__init__ ()
2326 self .pd = pd
2427
25- files = sorted (split_path .glob (' **/*.gz' ))
26- logging .info (f' Total number of files: { len (files ):,} ' )
27- assert len ( files ) != 0 , "could not find files under %s" % split_path
28+ files = sorted (split_path .glob (" **/*.gz" ))
29+ logging .info (f" Total number of files: { len (files ):,} " )
30+ assert files , "could not find files under %s" % split_path
2831
29- columns_list = [' code' , ' func_name' ]
32+ columns_list = [" code" , " func_name" ]
3033
31- start = time . time ()
34+ start = time ()
3235 self .pd = self ._jsonl_list_to_dataframe (files , columns_list )
33- logging .info (f"Loading took { time . time () - start :.2f} s for { len (self )} rows" )
36+ logging .info (f"Loading took { time () - start :.2f} s for { len (self )} rows" )
3437
3538 @staticmethod
36- def _jsonl_list_to_dataframe (file_list : List [Path ],
37- columns : List [str ]) -> pd .DataFrame :
39+ def _jsonl_list_to_dataframe (
40+ file_list : List [Path ], columns : List [str ]
41+ ) -> pd .DataFrame :
3842 """Load a list of jsonl.gz files into a pandas DataFrame."""
39- return pd .concat ([pd .read_json (f ,
40- orient = 'records' ,
41- compression = 'gzip' ,
42- lines = True )[columns ]
43- for f in file_list ], sort = False )
44-
43+ return pd .concat (
44+ [
45+ pd .read_json (f , orient = "records" , compression = "gzip" , lines = True )[
46+ columns
47+ ]
48+ for f in file_list
49+ ],
50+ sort = False ,
51+ )
4552
4653 def __getitem__ (self , idx : int ):
4754 row = self .pd .iloc [idx ]
4855
4956 # drop class name
5057 fn_name = row ["func_name" ]
51- fn_name = fn_name .split ('.' )[- 1 ] # drop the class name
58+ fn_name = fn_name .split ("." )[- 1 ] # drop the class name
5259 # fn_name_enc = self.enc.encode(fn_name)
5360
5461 # drop fn signature
5562 code = row ["code" ]
56- fn_body = code [code .find ("{" ) + 1 : code .find ("}" )].lstrip ().rstrip ()
63+ fn_body = code [code .find ("{" ) + 1 : code .rfind ("}" )].lstrip ().rstrip ()
5764 fn_body = fn_body .replace ("\n " , "\\ n" )
5865 # fn_body_enc = self.enc.encode(fn_body)
5966 return (fn_name , fn_body )
@@ -63,36 +70,44 @@ def __len__(self):
6370
6471
def _fill_split_name(template, split_name):
    """Return ``template`` with its ``%``-placeholder filled by ``split_name``.

    A template that is not a valid one-argument format string (for example a
    plain file name without ``%s``) is returned unchanged instead of raising.
    """
    try:
        return template % split_name
    except (TypeError, ValueError):
        return template


def main(args):
    """Write function bodies and function names of one split to two text files.

    Reads ``args.data_dir``, ``args.newline``, ``args.src_file``,
    ``args.tgt_file`` and ``args.word_level_targets``. The last path component
    of ``args.data_dir`` is used as the split name and substituted into the
    two output file templates. Each kept sample contributes one line to the
    source file (the body) and one line to the target file (the name, either
    as is or split into space-separated characters).
    """
    data_dir = Path(args.data_dir)
    dataset = CodeSearchNetRAM(data_dir, args.newline)
    split_name = data_dir.name

    src_path = _fill_split_name(args.src_file, split_name)
    tgt_path = _fill_split_name(args.tgt_file, split_name)
    with open(src_path, mode="w", encoding="utf8") as s, open(
        tgt_path, mode="w", encoding="utf8"
    ) as t:
        for fn_name, fn_body in dataset:
            # Skip samples with an empty name or body so both files stay
            # aligned line by line.
            if not fn_name or not fn_body:
                continue
            print(fn_body, file=s)
            # Char-level targets are the name's characters joined by spaces.
            target = fn_name if args.word_level_targets else " ".join(fn_name)
            print(target, file=t)
7383
7484
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    # Automatic -h/--help is left enabled so the help texts below are reachable.
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="java/final/jsonl/test",
        help="Path to the unziped input data (CodeSearchNet)",
    )

    parser.add_argument(
        "--newline", type=str, default="\\n", help="Replace newline with this"
    )

    parser.add_argument(
        "--word-level-targets",
        action="store_true",
        help="Use word level targets instead of char level ones",
    )

    # "%s" in the two defaults below is a placeholder that main() fills with
    # the last path component of --data_dir.
    parser.add_argument(
        "--src_file",
        type=str,
        default="src-%s.txt",
        help="File with function bodies",
    )

    parser.add_argument(
        "--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
    )

    args = parser.parse_args()
    main(args)
0 commit comments