1616
1717def main ():
1818 parser = argparse .ArgumentParser (
19- description = ' Generate low-precision samples from existing float32 samples'
19+ description = " Generate low-precision samples from existing float32 samples"
2020 )
2121 parser .add_argument (
22- ' --source-sample-path' ,
22+ " --source-sample-path" ,
2323 type = str ,
2424 required = True ,
25- help = ' Path to the source sample directory'
25+ help = " Path to the source sample directory" ,
2626 )
2727 parser .add_argument (
28- ' --output-base-path' ,
28+ " --output-base-path" ,
2929 type = str ,
3030 required = True ,
31- help = ' Base path for output samples'
31+ help = " Base path for output samples" ,
3232 )
3333 parser .add_argument (
34- ' --dtype-list' ,
34+ " --dtype-list" ,
3535 type = str ,
36- nargs = '+' ,
37- default = [' float16' , ' bfloat16' ],
38- choices = [' float16' , ' bfloat16' , ' float8' ],
39- help = ' List of target dtypes to generate (default: float16 bfloat16)'
36+ nargs = "+" ,
37+ default = [" float16" , " bfloat16" ],
38+ choices = [" float16" , " bfloat16" , " float8" ],
39+ help = " List of target dtypes to generate (default: float16 bfloat16)" ,
4040 )
4141 parser .add_argument (
42- ' --filter-config' ,
42+ " --filter-config" ,
4343 type = str ,
4444 default = None ,
45- help = ' Path to filter configuration file (JSON format)'
45+ help = " Path to filter configuration file (JSON format)" ,
4646 )
47-
47+
4848 args = parser .parse_args ()
49-
49+
5050 # Load filter config if provided
5151 filter_config = {}
5252 if args .filter_config :
5353 import json
54- with open (args .filter_config , 'r' ) as f :
54+
55+ with open (args .filter_config , "r" ) as f :
5556 filter_config = json .load (f )
56-
57+
5758 # Create generator
5859 generator = MultiDtypeGenerator (
5960 source_sample_path = args .source_sample_path ,
6061 output_base_path = args .output_base_path ,
6162 dtype_list = args .dtype_list ,
6263 filter_config = filter_config ,
6364 )
64-
65+
6566 # Generate samples
6667 print (f"Generating samples from: { args .source_sample_path } " )
6768 print (f"Output base path: { args .output_base_path } " )
6869 print (f"Target dtypes: { args .dtype_list } " )
69-
70+
7071 generated_paths = generator .generate ()
71-
72+
7273 if generated_paths :
7374 print (f"\n Successfully generated { len (generated_paths )} sample(s):" )
7475 for path in generated_paths :
@@ -79,6 +80,5 @@ def main():
7980 return 1
8081
8182
82- if __name__ == ' __main__' :
83+ if __name__ == " __main__" :
8384 sys .exit (main ())
84-
0 commit comments