-
import argparse
import array
import json
import torch
import yaml
from diffusers import AutoencoderKLWan, WanPipeline
-from diffusers.utils import export_to_video

SCENARIO_MAP = {
    "SingleStream": lg.TestScenario.SingleStream,
@@ -28,52 +26,47 @@ def setup_logging(rank):
2826 """Setup logging configuration for data parallel (all ranks log)."""
2927 logging .basicConfig (
3028 level = logging .INFO ,
31- format = f' [Rank { rank } ] %(asctime)s - %(levelname)s - %(message)s' ,
32- datefmt = ' %Y-%m-%d %H:%M:%S'
29+ format = f" [Rank { rank } ] %(asctime)s - %(levelname)s - %(message)s" ,
30+ datefmt = " %Y-%m-%d %H:%M:%S" ,
3331 )
3432
3533
3634def load_config (config_path ):
3735 """Load configuration from YAML file."""
38- with open (config_path , 'r' ) as f :
36+ with open (config_path , "r" ) as f :
3937 config = yaml .safe_load (f )
4038 return config
4139
4240
4341def load_prompts (dataset_path ):
4442 """Load prompts from dataset file."""
45- with open (dataset_path , 'r' ) as f :
43+ with open (dataset_path , "r" ) as f :
4644 prompts = [line .strip () for line in f if line .strip ()]
4745 return prompts
4846
4947
5048class Model :
51- def __init__ (
52- self , model_path , video_output_path , device , config , prompts , fixed_latent = None , rank = 0
53- ):
54- self .video_output_path = video_output_path
49+ def __init__ (self , model_path , device , config , prompts , fixed_latent = None , rank = 0 ):
5550 self .device = device
5651 self .rank = rank
57- self .height = config [' height' ]
58- self .width = config [' width' ]
59- self .num_frames = config [' num_frames' ]
60- self .fps = config [' fps' ]
61- self .guidance_scale = config [' guidance_scale' ]
62- self .guidance_scale_2 = config [' guidance_scale_2' ]
63- self .boundary_ratio = config [' boundary_ratio' ]
64- self .negative_prompt = config [' negative_prompt' ].strip ()
65- self .sample_steps = config [' sample_steps' ]
66- self .base_seed = config [' seed' ]
52+ self .height = config [" height" ]
53+ self .width = config [" width" ]
54+ self .num_frames = config [" num_frames" ]
55+ self .fps = config [" fps" ]
56+ self .guidance_scale = config [" guidance_scale" ]
57+ self .guidance_scale_2 = config [" guidance_scale_2" ]
58+ self .boundary_ratio = config [" boundary_ratio" ]
59+ self .negative_prompt = config [" negative_prompt" ].strip ()
60+ self .sample_steps = config [" sample_steps" ]
61+ self .base_seed = config [" seed" ]
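+        # Wan2.2 A14B denoises with two experts: guidance_scale drives the
+        # high-noise stage and guidance_scale_2 the low-noise stage, with the
+        # hand-off point controlled by boundary_ratio.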
        self.vae = AutoencoderKLWan.from_pretrained(
-            model_path,
-            subfolder="vae",
-            torch_dtype=torch.float32
+            model_path, subfolder="vae", torch_dtype=torch.float32
        )
        self.pipe = WanPipeline.from_pretrained(
            model_path,
            boundary_ratio=self.boundary_ratio,
            vae=self.vae,
-            torch_dtype=torch.bfloat16
+            torch_dtype=torch.bfloat16,
        )
        self.pipe.to(self.device)
        self.prompts = prompts
@@ -94,24 +87,15 @@ def issue_queries(self, query_samples):
9487 "guidance_scale" : self .guidance_scale ,
9588 "guidance_scale_2" : self .guidance_scale_2 ,
9689 "num_inference_steps" : self .sample_steps ,
97- "generator" : torch .Generator (device = self .device ).manual_seed (self .base_seed ),
90+ "generator" : torch .Generator (device = self .device ).manual_seed (
91+ self .base_seed
92+ ),
9893 }
9994 if self .fixed_latent is not None :
10095 pipeline_kwargs ["latents" ] = self .fixed_latent
10196 output = self .pipe (** pipeline_kwargs ).frames [0 ]
102-
103- # Save to video to reduce mlperf_log_accuracy.json size
104- output_path = Path (
105- self .video_output_path ,
106- f"{ self .prompts [i ]} -0.mp4" )
107- logging .info (f"Saving { q } to { output_path } " )
108- export_to_video (output [0 ], str (output_path ), fps = self .fps )
109-
110- with open (output_path , "rb" ) as f :
111- resp = f .read ()
112-
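+            # The generated frames are handed back to loadgen as raw bytes; in an
+            # accuracy run they are hex-encoded into mlperf_log_accuracy.json.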
            response_array = array.array(
-                "B", resp
+                "B", output.cpu().detach().numpy().tobytes()
            )
            bi = response_array.buffer_info()
            response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
@@ -122,23 +106,21 @@ def flush_queries(self):


class DebugModel:
-    def __init__(
-        self, model_path, device, config, prompts, fixed_latent=None, rank=0
-    ):
+    def __init__(self, model_path, device, config, prompts, fixed_latent=None, rank=0):
        self.prompts = prompts

    def issue_queries(self, query_samples):
        idx = [q.index for q in query_samples]
        query_ids = [q.id for q in query_samples]
        response = []
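+        # Keep references to the response buffers so they are not garbage-collected
+        # before loadgen has copied the data behind the buffer_info() pointers.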
+        response_array_refs = []
        for i, q in zip(idx, query_ids):
            print(i, self.prompts[i])
            output = self.prompts[i]
-            response_array = array.array(
-                "B", output.encode("utf-8")
-            )
+            response_array = array.array("B", output.encode("utf-8"))
            bi = response_array.buffer_info()
            response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
+            response_array_refs.append(response_array)
        lg.QuerySamplesComplete(response)

    def flush_queries(self):
@@ -155,56 +137,56 @@ def unload_query_samples(sample_list):

def get_args():
    parser = argparse.ArgumentParser(
-        description="Batch T2V inference with Wan2.2-Diffusers")
+        description="Batch T2V inference with Wan2.2-Diffusers"
+    )
    # Model Arguments
    parser.add_argument(
        "--model-path",
        type=str,
        default="./models/Wan2.2-T2V-A14B-Diffusers",
-        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)"
+        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="./data/vbench_prompts.txt",
-        help="Path to dataset file (text prompts, one per line) (default: ./data/prompts.txt)"
+        help="Path to dataset file (text prompts, one per line) (default: ./data/vbench_prompts.txt)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./output",
-        help="Directory to save generated videos (default: ./data/outputs)"
+        help="Directory to save generated videos (default: ./output)",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./inference_config.yaml",
-        help="Path to inference configuration file (default: ./inference_config.yaml)"
+        help="Path to inference configuration file (default: ./inference_config.yaml)",
    )
    parser.add_argument(
        "--num-iterations",
        type=int,
        default=1,
-        help="Number of generation iterations per prompt (default: 1)"
+        help="Number of generation iterations per prompt (default: 1)",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=-1,
-        help="Process only first N prompts (for testing, default: all)"
+        help="Process only first N prompts (for testing, default: all)",
    )
    parser.add_argument(
        "--fixed-latent",
        type=str,
        default="./data/fixed_latent.pt",
-        help="Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)"
+        help="Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)",
    )
    # MLPerf loadgen arguments
    parser.add_argument(
        "--scenario",
        default="SingleStream",
-        help="mlperf benchmark scenario, one of " +
-        str(list(SCENARIO_MAP.keys())),
+        help="mlperf benchmark scenario, one of " + str(list(SCENARIO_MAP.keys())),
    )
    parser.add_argument(
        "--user_conf",
@@ -218,19 +200,9 @@ def get_args():
218200 "--performance-sample-count" ,
219201 type = int ,
220202 help = "performance sample count" ,
221- default = 248 ,
222- )
223- parser .add_argument (
224- "--accuracy" ,
225- action = "store_true" ,
226- help = "enable accuracy pass"
227- )
228- parser .add_argument (
229- "--video_output_path" ,
230- type = str ,
231- default = "./videos" ,
232- help = "path to store output videos"
203+ default = 5000 ,
233204 )
205+ parser .add_argument ("--accuracy" , action = "store_true" , help = "enable accuracy pass" )
    # Don't overwrite these for official submission
    parser.add_argument("--count", type=int, help="dataset items to use")
    parser.add_argument("--time", type=int, help="time to scan in seconds")
@@ -272,20 +244,14 @@ def run_mlperf(args, config):
    if args.fixed_latent:
        fixed_latent = torch.load(args.fixed_latent)
        logging.info(
-            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}")
+            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}"
+        )
        logging.info("This latent will be reused for all generations")
    else:
        logging.info("No fixed latent provided - using random initial latents")

    # Loading model
-    model = Model(
-        args.model_path,
-        args.video_output_path,
-        device,
-        config,
-        dataset,
-        fixed_latent,
-        rank)
+    model = Model(args.model_path, device, config, dataset, fixed_latent, rank)
    # model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
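    # DebugModel skips the diffusion pipeline entirely and echoes the prompt text
    # back as the response, which is useful for checking the loadgen plumbing.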
    logging.info("Model loaded successfully!")

@@ -305,10 +271,7 @@ def run_mlperf(args, config):

    audit_config = os.path.abspath(args.audit_conf)
    if os.path.exists(audit_config):
-        settings.FromConfig(
-            audit_config,
-            "wan-2.2-t2v-a14b",
-            args.scenario)
+        settings.FromConfig(audit_config, "wan-2.2-t2v-a14b", args.scenario)
    settings.scenario = SCENARIO_MAP[args.scenario]

    settings.mode = lg.TestMode.PerformanceOnly
@@ -324,24 +287,18 @@ def run_mlperf(args, config):
        settings.server_target_qps = qps
        settings.offline_expected_qps = qps

-    count_override = False
    count = args.count
-    if count:
-        count_override = True

    if args.count:
        settings.min_query_count = count
        settings.max_query_count = count
-    if not count_override:
-        count = len(dataset)
+    count = len(dataset)
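+    # The QSL is always sized to the full dataset; --count only bounds the
+    # min/max query counts issued by loadgen.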

    if args.samples_per_query:
        settings.multi_stream_samples_per_query = args.samples_per_query
    if args.max_latency:
-        settings.server_target_latency_ns = int(
-            args.max_latency * NANO_SEC)
-        settings.multi_stream_expected_latency_ns = int(
-            args.max_latency * NANO_SEC)
+        settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
+        settings.multi_stream_expected_latency_ns = int(args.max_latency * NANO_SEC)

    performance_sample_count = (
        args.performance_sample_count
@@ -354,13 +311,7 @@ def run_mlperf(args, config):
        count, performance_sample_count, load_query_samples, unload_query_samples
    )

-    lg.StartTestWithLogSettings(
-        sut, qsl, settings, log_settings, audit_config)
-    if args.accuracy:
-        # TODO: output accuracy
-        final_results = {}
-        with open("results.json", "w") as f:
-            json.dump(final_results, f, sort_keys=True, indent=4)
+    lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)

    lg.DestroyQSL(qsl)
    lg.DestroySUT(sut)
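Note: with the in-script mp4 export and results.json dump removed, an accuracy run leaves the generated frames as raw response bytes in loadgen's mlperf_log_accuracy.json. A minimal sketch for pulling them back out, assuming float32 frames laid out as (num_frames, height, width, 3) and the default log location; both are assumptions to verify against the pipeline's actual output.

import json
import numpy as np
import yaml

with open("inference_config.yaml") as f:
    cfg = yaml.safe_load(f)

with open("mlperf_log_accuracy.json") as f:
    entries = json.load(f)

for entry in entries:
    raw = bytes.fromhex(entry["data"])  # loadgen hex-encodes the response bytes
    frames = np.frombuffer(raw, dtype=np.float32).reshape(
        -1, cfg["height"], cfg["width"], 3  # assumed dtype and layout
    )
    print(entry["qsl_idx"], frames.shape)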