@@ -166,9 +166,18 @@ def estimate_time(num_layers, iters, nsamples, batch_size):
166166 return total_seconds
167167
168168
169- def dry_run_estimate (model_name , scheme_bits , group_size , model_dtype = "float16" ,
170- batch_size = 8 , seqlen = 2048 , nsamples = 128 , iters = 200 ,
171- trust_remote_code = True , platform = "hf" ):
169+ def dry_run_estimate (
170+ model_name ,
171+ scheme_bits ,
172+ group_size ,
173+ model_dtype = "float16" ,
174+ batch_size = 8 ,
175+ seqlen = 2048 ,
176+ nsamples = 128 ,
177+ iters = 200 ,
178+ trust_remote_code = True ,
179+ platform = "hf" ,
180+ ):
172181 """Run a dry-run estimation and return a dict of estimates.
173182
174183 Args:
@@ -249,8 +258,10 @@ def print_dry_run_report(estimates):
249258 print (f" Estimated peak VRAM: { estimates ['peak_vram_str' ]} " )
250259 print (f" Estimated output size: { estimates ['output_size_str' ]} " )
251260 print (f" Estimated time: { estimates ['estimated_time_str' ]} " )
252- print (f" (batch_size={ estimates ['batch_size' ]} , seqlen={ estimates ['seqlen' ]} , "
253- f"nsamples={ estimates ['nsamples' ]} , iters={ estimates ['iters' ]} )" )
261+ print (
262+ f" (batch_size={ estimates ['batch_size' ]} , seqlen={ estimates ['seqlen' ]} , "
263+ f"nsamples={ estimates ['nsamples' ]} , iters={ estimates ['iters' ]} )"
264+ )
254265 print (border )
255266 print (" NOTE: These are rough estimates. Actual values depend on" )
256267 print (" hardware, model architecture, and runtime conditions." )
0 commit comments