Skip to content

Commit 3a3c0ff

Browse files
[Automated Commit] Format Codebase
1 parent 3a61b54 commit 3a3c0ff

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

text_to_video/wan2.2-t2v-14b/run_mlperf.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
NANO_SEC = 1e9
2222
MILLI_SEC = 1000
2323

24+
2425
def setup_logging(rank):
2526
"""Setup logging configuration for data parallel (all ranks log)."""
2627
logging.basicConfig(
@@ -45,7 +46,8 @@ def load_prompts(dataset_path):
4546

4647

4748
class Model:
48-
def __init__(self, model_path, device, config, prompts, fixed_latent = None, rank = 0):
49+
def __init__(self, model_path, device, config,
50+
prompts, fixed_latent=None, rank=0):
4951
self.device = device
5052
self.rank = rank
5153
self.height = config['height']
@@ -105,7 +107,8 @@ def flush_queries(self):
105107

106108

107109
class DebugModel:
108-
def __init__(self, model_path, device, config, prompts, fixed_latent = None, rank = 0):
110+
def __init__(self, model_path, device, config,
111+
prompts, fixed_latent=None, rank=0):
109112
self.prompts = prompts
110113

111114
def issue_queries(self, query_samples):
@@ -129,13 +132,15 @@ def flush_queries(self):
129132
def load_query_samples(sample_list):
130133
pass
131134

135+
132136
def unload_query_samples(sample_list):
133137
pass
134138

139+
135140
def get_args():
136141
parser = argparse.ArgumentParser(
137142
description="Batch T2V inference with Wan2.2-Diffusers")
138-
## Model Arguments
143+
# Model Arguments
139144
parser.add_argument(
140145
"--model-path",
141146
type=str,
@@ -178,7 +183,7 @@ def get_args():
178183
default="./data/fixed_latent.pt",
179184
help="Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)"
180185
)
181-
## MLPerf loadgen arguments
186+
# MLPerf loadgen arguments
182187
parser.add_argument(
183188
"--scenario",
184189
default="SingleStream",
@@ -221,6 +226,7 @@ def get_args():
221226

222227
return parser.parse_args()
223228

229+
224230
def run_mlperf(args, config):
225231
# Load dataset
226232
dataset = load_prompts(args.dataset)
@@ -236,7 +242,6 @@ def run_mlperf(args, config):
236242
setup_logging(rank)
237243

238244
# Generation parameters from config
239-
240245

241246
output_dir = Path(args.output_dir)
242247
output_dir.mkdir(parents=True, exist_ok=True)
@@ -253,7 +258,7 @@ def run_mlperf(args, config):
253258

254259
# Loading model
255260
model = Model(args.model_path, device, config, dataset, fixed_latent, rank)
256-
#model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
261+
# model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
257262
logging.info("Model loaded successfully!")
258263

259264
# Prepare loadgen for run
@@ -272,7 +277,10 @@ def run_mlperf(args, config):
272277

273278
audit_config = os.path.abspath(args.audit_conf)
274279
if os.path.exists(audit_config):
275-
settings.FromConfig(audit_config, "qwen3-vl-235b-a22b", args.scenario)
280+
settings.FromConfig(
281+
audit_config,
282+
"qwen3-vl-235b-a22b",
283+
args.scenario)
276284
settings.scenario = SCENARIO_MAP[args.scenario]
277285

278286
settings.mode = lg.TestMode.PerformanceOnly
@@ -288,12 +296,11 @@ def run_mlperf(args, config):
288296
settings.server_target_qps = qps
289297
settings.offline_expected_qps = qps
290298

291-
292299
count_override = False
293300
count = args.count
294301
if count:
295302
count_override = True
296-
303+
297304
if args.count:
298305
settings.min_query_count = count
299306
settings.max_query_count = count
@@ -302,37 +309,39 @@ def run_mlperf(args, config):
302309
if args.samples_per_query:
303310
settings.multi_stream_samples_per_query = args.samples_per_query
304311
if args.max_latency:
305-
settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
312+
settings.server_target_latency_ns = int(
313+
args.max_latency * NANO_SEC)
306314
settings.multi_stream_expected_latency_ns = int(
307315
args.max_latency * NANO_SEC)
308-
316+
309317
performance_sample_count = (
310318
args.performance_sample_count
311319
if args.performance_sample_count
312320
else min(count, 500)
313321
)
314-
322+
315323
sut = lg.ConstructSUT(model.issue_queries, model.flush_queries)
316324
qsl = lg.ConstructQSL(
317325
count, performance_sample_count, load_query_samples, unload_query_samples
318326
)
319327

320-
lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)
328+
lg.StartTestWithLogSettings(
329+
sut, qsl, settings, log_settings, audit_config)
321330
if args.accuracy:
322-
## TODO: output accuracy
331+
# TODO: output accuracy
323332
final_results = {}
324333
with open("results.json", "w") as f:
325334
json.dump(final_results, f, sort_keys=True, indent=4)
326335

327336
lg.DestroyQSL(qsl)
328337
lg.DestroySUT(sut)
329338

339+
330340
def main():
331341
args = get_args()
332342
config = load_config(args.config)
333343
run_mlperf(args, config)
334344

335345

336-
337346
if __name__ == "__main__":
338-
main()
347+
main()

0 commit comments

Comments
 (0)