# Unit-conversion factors used when filling in loadgen settings.
# NANO_SEC scales args.max_latency into the *_latency_ns settings fields.
NANO_SEC = 1e9
# NOTE(review): presumably milliseconds per second; no use of MILLI_SEC is
# visible in this excerpt - confirm against the rest of the file.
MILLI_SEC = 1000
2323
24+
2425def setup_logging (rank ):
2526 """Setup logging configuration for data parallel (all ranks log)."""
2627 logging .basicConfig (
@@ -45,7 +46,8 @@ def load_prompts(dataset_path):
4546
4647
4748class Model :
48- def __init__ (self , model_path , device , config , prompts , fixed_latent = None , rank = 0 ):
49+ def __init__ (self , model_path , device , config ,
50+ prompts , fixed_latent = None , rank = 0 ):
4951 self .device = device
5052 self .rank = rank
5153 self .height = config ['height' ]
@@ -105,7 +107,8 @@ def flush_queries(self):
105107
106108
107109class DebugModel :
108- def __init__ (self , model_path , device , config , prompts , fixed_latent = None , rank = 0 ):
110+ def __init__ (self , model_path , device , config ,
111+ prompts , fixed_latent = None , rank = 0 ):
109112 self .prompts = prompts
110113
111114 def issue_queries (self , query_samples ):
@@ -129,13 +132,15 @@ def flush_queries(self):
def load_query_samples(sample_list):
    """No-op load callback passed to ``lg.ConstructQSL``.

    Args:
        sample_list: Value supplied by the caller; it is ignored.

    Returns:
        None. The function performs no work.
    """
    return None
131134
135+
def unload_query_samples(sample_list):
    """No-op unload callback passed to ``lg.ConstructQSL``.

    Args:
        sample_list: Value supplied by the caller; it is ignored.

    Returns:
        None. The function performs no work.
    """
    return None
134138
139+
135140def get_args ():
136141 parser = argparse .ArgumentParser (
137142 description = "Batch T2V inference with Wan2.2-Diffusers" )
138- ## Model Arguments
143+ # Model Arguments
139144 parser .add_argument (
140145 "--model-path" ,
141146 type = str ,
@@ -178,7 +183,7 @@ def get_args():
178183 default = "./data/fixed_latent.pt" ,
179184 help = "Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)"
180185 )
181- ## MLPerf loadgen arguments
186+ # MLPerf loadgen arguments
182187 parser .add_argument (
183188 "--scenario" ,
184189 default = "SingleStream" ,
@@ -221,6 +226,7 @@ def get_args():
221226
222227 return parser .parse_args ()
223228
229+
224230def run_mlperf (args , config ):
225231 # Load dataset
226232 dataset = load_prompts (args .dataset )
@@ -236,7 +242,6 @@ def run_mlperf(args, config):
236242 setup_logging (rank )
237243
238244 # Generation parameters from config
239-
240245
241246 output_dir = Path (args .output_dir )
242247 output_dir .mkdir (parents = True , exist_ok = True )
@@ -253,7 +258,7 @@ def run_mlperf(args, config):
253258
254259 # Loading model
255260 model = Model (args .model_path , device , config , dataset , fixed_latent , rank )
256- #model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
261+ # model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
257262 logging .info ("Model loaded successfully!" )
258263
259264 # Prepare loadgen for run
@@ -272,7 +277,10 @@ def run_mlperf(args, config):
272277
273278 audit_config = os .path .abspath (args .audit_conf )
274279 if os .path .exists (audit_config ):
275- settings .FromConfig (audit_config , "qwen3-vl-235b-a22b" , args .scenario )
280+ settings .FromConfig (
281+ audit_config ,
282+ "qwen3-vl-235b-a22b" ,
283+ args .scenario )
276284 settings .scenario = SCENARIO_MAP [args .scenario ]
277285
278286 settings .mode = lg .TestMode .PerformanceOnly
@@ -288,12 +296,11 @@ def run_mlperf(args, config):
288296 settings .server_target_qps = qps
289297 settings .offline_expected_qps = qps
290298
291-
292299 count_override = False
293300 count = args .count
294301 if count :
295302 count_override = True
296-
303+
297304 if args .count :
298305 settings .min_query_count = count
299306 settings .max_query_count = count
@@ -302,37 +309,39 @@ def run_mlperf(args, config):
302309 if args .samples_per_query :
303310 settings .multi_stream_samples_per_query = args .samples_per_query
304311 if args .max_latency :
305- settings .server_target_latency_ns = int (args .max_latency * NANO_SEC )
312+ settings .server_target_latency_ns = int (
313+ args .max_latency * NANO_SEC )
306314 settings .multi_stream_expected_latency_ns = int (
307315 args .max_latency * NANO_SEC )
308-
316+
309317 performance_sample_count = (
310318 args .performance_sample_count
311319 if args .performance_sample_count
312320 else min (count , 500 )
313321 )
314-
322+
315323 sut = lg .ConstructSUT (model .issue_queries , model .flush_queries )
316324 qsl = lg .ConstructQSL (
317325 count , performance_sample_count , load_query_samples , unload_query_samples
318326 )
319327
320- lg .StartTestWithLogSettings (sut , qsl , settings , log_settings , audit_config )
328+ lg .StartTestWithLogSettings (
329+ sut , qsl , settings , log_settings , audit_config )
321330 if args .accuracy :
322- ## TODO: output accuracy
331+ # TODO: output accuracy
323332 final_results = {}
324333 with open ("results.json" , "w" ) as f :
325334 json .dump (final_results , f , sort_keys = True , indent = 4 )
326335
327336 lg .DestroyQSL (qsl )
328337 lg .DestroySUT (sut )
329338
339+
def main():
    """Script entry point.

    Parses the command-line arguments, loads the configuration referenced by
    ``args.config``, and hands both to ``run_mlperf``.
    """
    args = get_args()
    config = load_config(args.config)
    run_mlperf(args, config)
334344
335345
336-
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments