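"""MLPerf LoadGen harness for batch text-to-video inference with Wan2.2-Diffusers.

Loads text prompts from a dataset file, runs the WanPipeline under a LoadGen
scenario (SingleStream, MultiStream, Server, or Offline), and returns the raw
generated frames to LoadGen. Rank information is read from the WORLD_SIZE /
RANK / LOCAL_RANK environment variables, and only rank 0 drives LoadGen.

Example launch (illustrative only; the script name, launcher, and flags depend
on your setup):

    torchrun --nproc_per_node=1 main.py --scenario Offline --user_conf user.conf
"""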
import argparse
import array
import json
import logging
import os
from pathlib import Path

import numpy as np
import torch
import yaml

import mlperf_loadgen as lg
from diffusers import WanPipeline, AutoencoderKLWan

SCENARIO_MAP = {
    "SingleStream": lg.TestScenario.SingleStream,
    "MultiStream": lg.TestScenario.MultiStream,
    "Server": lg.TestScenario.Server,
    "Offline": lg.TestScenario.Offline,
}

NANO_SEC = 1e9
MILLI_SEC = 1000


def setup_logging(rank):
    """Set up logging for data-parallel runs (all ranks log, tagged with their rank)."""
    logging.basicConfig(
        level=logging.INFO,
        format=f'[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


def load_config(config_path):
    """Load configuration from a YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config
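

# For reference, the inference_config.yaml consumed above is expected to carry the
# generation parameters read in Model.__init__. The values below are illustrative
# placeholders only, not official benchmark settings:
#
#   height: 480
#   width: 832
#   num_frames: 81
#   fps: 16
#   guidance_scale: 4.0
#   guidance_scale_2: 3.0
#   boundary_ratio: 0.875
#   negative_prompt: "low quality, blurry"
#   sample_steps: 40
#   seed: 42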


def load_prompts(dataset_path):
    """Load prompts from the dataset file."""
    with open(dataset_path, 'r') as f:
        prompts = [line.strip() for line in f if line.strip()]
    return prompts
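

# The --dataset file read by load_prompts() above is plain text with one prompt
# per line, e.g. (illustrative prompts, not from the benchmark set):
#   A corgi running on a beach at sunset, slow motion
#   Aerial view of a futuristic city at night, neon lights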


class Model:
    def __init__(self, model_path, device, config, prompts, fixed_latent=None):
        self.device = device
        # Data-parallel rank; only rank 0 answers LoadGen queries.
        self.rank = int(os.environ.get("RANK", 0))
        # Generation parameters from the YAML config
        self.height = config['height']
        self.width = config['width']
        self.num_frames = config['num_frames']
        self.fps = config['fps']
        self.guidance_scale = config['guidance_scale']
        self.guidance_scale_2 = config['guidance_scale_2']
        self.boundary_ratio = config['boundary_ratio']
        self.negative_prompt = config['negative_prompt'].strip()
        self.sample_steps = config['sample_steps']
        self.base_seed = config['seed']
        # The VAE is loaded in float32; the rest of the pipeline runs in bfloat16.
        self.vae = AutoencoderKLWan.from_pretrained(
            model_path,
            subfolder="vae",
            torch_dtype=torch.float32
        )
        self.pipe = WanPipeline.from_pretrained(
            model_path,
            boundary_ratio=self.boundary_ratio,
            vae=self.vae,
            torch_dtype=torch.bfloat16
        )
        self.pipe.to(self.device)
        self.prompts = prompts
        self.fixed_latent = fixed_latent

    def issue_queries(self, query_samples):
        if self.rank == 0:
            idx = [q.index for q in query_samples]
            query_ids = [q.id for q in query_samples]
            response = []
            for i, q in zip(idx, query_ids):
                pipeline_kwargs = {
                    "prompt": self.prompts[i],
                    "negative_prompt": self.negative_prompt,
                    "height": self.height,
                    "width": self.width,
                    "num_frames": self.num_frames,
                    "guidance_scale": self.guidance_scale,
                    "guidance_scale_2": self.guidance_scale_2,
                    "num_inference_steps": self.sample_steps,
                    "generator": torch.Generator(device=self.device).manual_seed(self.base_seed),
                }
                if self.fixed_latent is not None:
                    pipeline_kwargs["latents"] = self.fixed_latent
                output = self.pipe(**pipeline_kwargs).frames[0]
                # LoadGen expects a (pointer, size) pair; hand it the raw frame bytes.
                response_array = array.array(
                    "B", output.cpu().detach().numpy().tobytes()
                )
                bi = response_array.buffer_info()
                response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
            lg.QuerySamplesComplete(response)

    def flush_queries(self):
        pass


class DebugModel:
    def __init__(self, model_path, device, config, prompts, fixed_latent=None):
        self.prompts = prompts

    def issue_queries(self, query_samples):
        idx = [q.index for q in query_samples]
        query_ids = [q.id for q in query_samples]
        response = []
        for i, q in zip(idx, query_ids):
            print(i, self.prompts[i])
            output = self.prompts[i]
            response_array = array.array(
                "B", output.encode("utf-8")
            )
            bi = response_array.buffer_info()
            response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
        lg.QuerySamplesComplete(response)

    def flush_queries(self):
        pass


# LoadGen QSL callbacks are no-ops: the prompts are small and stay in memory for
# the whole run, so nothing needs to be loaded or evicted per sample.
def load_query_samples(sample_list):
    pass


def unload_query_samples(sample_list):
    pass


def get_args():
    parser = argparse.ArgumentParser(
        description="Batch T2V inference with Wan2.2-Diffusers")
    ## Model arguments
    parser.add_argument(
        "--model-path",
        type=str,
        default="./models/Wan2.2-T2V-A14B-Diffusers",
        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="./data/vbench_prompts.txt",
        help="Path to dataset file (text prompts, one per line) (default: ./data/vbench_prompts.txt)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./output",
        help="Directory to save generated videos (default: ./output)"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./inference_config.yaml",
        help="Path to inference configuration file (default: ./inference_config.yaml)"
    )
    parser.add_argument(
        "--num-iterations",
        type=int,
        default=1,
        help="Number of generation iterations per prompt (default: 1)"
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=-1,
        help="Process only the first N prompts (for testing; default: all)"
    )
    parser.add_argument(
        "--fixed-latent",
        type=str,
        default="./data/fixed_latent.pt",
        help="Path to fixed latent .pt file for deterministic generation (default: ./data/fixed_latent.pt)"
    )
    ## MLPerf LoadGen arguments
    parser.add_argument(
        "--scenario",
        default="SingleStream",
        help="mlperf benchmark scenario, one of " +
        str(list(SCENARIO_MAP.keys())),
    )
    parser.add_argument(
        "--user_conf",
        default="user.conf",
        help="user config for user LoadGen settings such as target QPS",
    )
    parser.add_argument(
        "--audit_conf", default="audit.config", help="config for LoadGen audit settings"
    )
    parser.add_argument(
        "--performance-sample-count",
        type=int,
        help="performance sample count",
        default=5000,
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="enable accuracy pass"
    )
    # Don't override these for an official submission
    parser.add_argument("--count", type=int, help="dataset items to use")
    parser.add_argument("--time", type=int, help="time to scan in seconds")
    parser.add_argument("--qps", type=int, help="target qps")
    parser.add_argument("--debug", action="store_true", help="debug")
    parser.add_argument(
        "--samples-per-query",
        default=8,
        type=int,
        help="mlperf multi-stream samples per query",
    )
    parser.add_argument(
        "--max-latency", type=float, help="mlperf target latency in seconds (percentile tail latency)"
    )

    return parser.parse_args()


def run_mlperf(args, config):
    # Load dataset
    dataset = load_prompts(args.dataset)

    # Data-parallelism parameters (set by the launcher, e.g. torchrun)
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    setup_logging(rank)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_dir_lg = str(args.output_dir)

    fixed_latent = None
    if args.fixed_latent:
        fixed_latent = torch.load(args.fixed_latent)
        logging.info(
            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}")
        logging.info("This latent will be reused for all generations")
    else:
        logging.info("No fixed latent provided - using random initial latents")

    # Load the model (swap in DebugModel to exercise the harness without running the pipeline)
    model = Model(args.model_path, device, config, dataset, fixed_latent)
    # model = DebugModel(args.model_path, device, config, dataset, fixed_latent)
    logging.info("Model loaded successfully!")

    # Prepare LoadGen for the run; only rank 0 drives LoadGen
    if rank == 0:
        log_output_settings = lg.LogOutputSettings()
        log_output_settings.outdir = output_dir_lg
        log_output_settings.copy_summary_to_stdout = False

        log_settings = lg.LogSettings()
        log_settings.enable_trace = args.debug
        log_settings.log_output = log_output_settings

        user_conf = os.path.abspath(args.user_conf)
        settings = lg.TestSettings()
        settings.FromConfig(user_conf, "qwen3-vl-235b-a22b", args.scenario)

        audit_config = os.path.abspath(args.audit_conf)
        if os.path.exists(audit_config):
            settings.FromConfig(audit_config, "qwen3-vl-235b-a22b", args.scenario)
        settings.scenario = SCENARIO_MAP[args.scenario]

        settings.mode = lg.TestMode.PerformanceOnly
        if args.accuracy:
            settings.mode = lg.TestMode.AccuracyOnly

        if args.time:
            # Override how long the test runs
            settings.min_duration_ms = args.time * MILLI_SEC
            settings.max_duration_ms = args.time * MILLI_SEC
        if args.qps:
            qps = float(args.qps)
            settings.server_target_qps = qps
            settings.offline_expected_qps = qps

        # Query-count override (not for official submissions)
        if args.count:
            settings.min_query_count = args.count
            settings.max_query_count = args.count
        # The QSL always advertises the full dataset
        count = len(dataset)

        if args.samples_per_query:
            settings.multi_stream_samples_per_query = args.samples_per_query
        if args.max_latency:
            settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
            settings.multi_stream_expected_latency_ns = int(
                args.max_latency * NANO_SEC)

        performance_sample_count = (
            args.performance_sample_count
            if args.performance_sample_count
            else min(count, 500)
        )

        sut = lg.ConstructSUT(model.issue_queries, model.flush_queries)
        qsl = lg.ConstructQSL(
            count, performance_sample_count, load_query_samples, unload_query_samples
        )

        lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)
        if args.accuracy:
            # TODO: post-process the accuracy log and write real results
            final_results = {}
            with open("results.json", "w") as f:
                json.dump(final_results, f, sort_keys=True, indent=4)

        lg.DestroyQSL(qsl)
        lg.DestroySUT(sut)


def main():
    args = get_args()
    config = load_config(args.config)
    run_mlperf(args, config)


if __name__ == "__main__":
    main()