 except ImportError as e:
     has_functorch = False
 
-try:
-    import torch._dynamo
-    has_dynamo = True
-except ImportError:
-    has_dynamo = False
-    pass
-
+has_compile = hasattr(torch, 'compile')
 
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
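The probing import of `torch._dynamo` is dropped in favor of a capability check on the public API. A minimal sketch of the same feature-detection pattern, assuming nothing beyond stock torch (the Linear model and input here are illustrative, not from the script):

```python
import torch
import torch.nn as nn

# torch.compile() only exists in PyTorch 2.0+ (nightly builds at the time of this change),
# so detect the attribute instead of importing the private torch._dynamo module.
has_compile = hasattr(torch, 'compile')

model = nn.Linear(16, 16)
if has_compile:
    # 'inductor' is the default backend; other registered backends can be named explicitly
    model = torch.compile(model, backend='inductor')

out = model(torch.randn(2, 16))
```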
@@ -81,8 +75,10 @@
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--no-retry', action='store_true', default=False,
                     help='Do not decay batch size and retry on error.')
-parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
+parser.add_argument('--results-file', default='', type=str,
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--results-format', default='csv', type=str,
+                    help='Format for results file one of (csv, json) (default: csv).')
 parser.add_argument('--num-warm-iter', default=10, type=int,
                     metavar='N', help='Number of warmup iterations (default: 10)')
 parser.add_argument('--num-bench-iter', default=40, type=int,
@@ -113,19 +109,18 @@
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
-parser.add_argument('--dynamo-backend', default=None, type=str,
-                    help="Select dynamo backend. Default: None")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='convert model torchscript for inference')
+scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
+                             help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")
-scripting_group.add_argument('--dynamo', default=False, action='store_true',
-                             help="Enable Dynamo optimization.")
+
 
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
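Because `--torchcompile` is declared with `nargs='?'` and `const='inductor'`, it behaves as a three-state option: absent (None), bare flag (the const value), or an explicit backend name. A self-contained sketch, using a throwaway parser rather than the script's own, to show that behavior:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor')

print(p.parse_args([]).torchcompile)                             # None -> compilation stays off
print(p.parse_args(['--torchcompile']).torchcompile)             # 'inductor' (const fills the bare flag)
print(p.parse_args(['--torchcompile', 'nvfuser']).torchcompile)  # 'nvfuser' (explicit backend string)
```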
@@ -218,9 +213,8 @@ def __init__(
             detail=False,
             device='cuda',
             torchscript=False,
+            torchcompile=None,
             aot_autograd=False,
-            dynamo=False,
-            dynamo_backend=None,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -259,20 +253,19 @@ def __init__(
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)
 
-        self.scripted = False
+        self.compiled = False
         if torchscript:
             self.model = torch.jit.script(self.model)
-            self.scripted = True
-        elif dynamo:
-            assert has_dynamo, "torch._dynamo is needed for --dynamo"
+            self.compiled = True
+        elif torchcompile:
+            assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            if dynamo_backend is not None:
-                self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
-            else:
-                self.model = torch._dynamo.optimize()(self.model)
+            self.model = torch.compile(self.model, backend=torchcompile)
+            self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
             self.model = memory_efficient_fusion(self.model)
+            self.compiled = True
 
         self.example_inputs = None
         self.num_warm_iter = num_warm_iter
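For context, a minimal standalone sketch of the same dispatch logic, assuming a plain nn.Module and PyTorch 2.0+ for the compile path (the `prepare_model` helper and toy model are illustrative, not part of the benchmark script):

```python
import torch
import torch.nn as nn

has_compile = hasattr(torch, 'compile')

def prepare_model(model, torchscript=False, torchcompile=None):
    # Mirrors the branch above: TorchScript, torch.compile with a named backend, or eager.
    if torchscript:
        return torch.jit.script(model), True
    if torchcompile:
        assert has_compile, 'torch.compile() requires PyTorch 2.0 or a nightly build'
        torch._dynamo.reset()  # drop previously captured graphs, as the code above does
        return torch.compile(model, backend=torchcompile), True
    return model, False

model, compiled = prepare_model(nn.Sequential(nn.Linear(8, 4), nn.ReLU()), torchcompile='inductor')
```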
@@ -344,7 +337,7 @@ def _step():
             param_count=round(self.param_count / 1e6, 2),
         )
 
-        retries = 0 if self.scripted else 2  # skip profiling if model is scripted
+        retries = 0 if self.compiled else 2  # skip profiling if model is scripted
         while retries:
             retries -= 1
             try:
@@ -642,7 +635,6 @@ def main():
         model_cfgs = [(n, None) for n in model_names]
 
     if len(model_cfgs):
-        results_file = args.results_file or './benchmark.csv'
         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         results = []
         try:
@@ -663,22 +655,30 @@ def main():
             sort_key = 'infer_gmacs'
         results = filter(lambda x: sort_key in x, results)
         results = sorted(results, key=lambda x: x[sort_key], reverse=True)
-        if len(results):
-            write_results(results_file, results)
     else:
         results = benchmark(args)
 
+    if args.results_file:
+        write_results(args.results_file, results, format=args.results_format)
+
     # output results in JSON to stdout w/ delimiter for runner script
     print(f'--result\n{json.dumps(results, indent=4)}')
 
 
-def write_results(results_file, results):
+def write_results(results_file, results, format='csv'):
     with open(results_file, mode='w') as cf:
-        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
-        dw.writeheader()
-        for r in results:
-            dw.writerow(r)
-        cf.flush()
+        if format == 'json':
+            json.dump(results, cf, indent=4)
+        else:
+            if not isinstance(results, (list, tuple)):
+                results = [results]
+            if not results:
+                return
+            dw = csv.DictWriter(cf, fieldnames=results[0].keys())
+            dw.writeheader()
+            for r in results:
+                dw.writerow(r)
+            cf.flush()
 
 
 if __name__ == '__main__':
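A hedged, self-contained sketch of what the two branches of the reworked `write_results()` produce, using dummy result dicts and illustrative filenames (the numbers are placeholders, not benchmark output):

```python
import csv
import json

results = [
    {'model': 'resnet50', 'infer_samples_per_sec': 1234.5, 'param_count': 25.56},
    {'model': 'vit_base_patch16_224', 'infer_samples_per_sec': 987.6, 'param_count': 86.57},
]

# CSV branch: DictWriter derives the header row from the first result's keys.
with open('benchmark.csv', 'w') as cf:
    dw = csv.DictWriter(cf, fieldnames=results[0].keys())
    dw.writeheader()
    dw.writerows(results)

# JSON branch: the whole list is dumped as a single indented array.
with open('benchmark.json', 'w') as jf:
    json.dump(results, jf, indent=4)
```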