Commit 1e04ea4
Minor T2V bugfixes (#2495)
* Fixed Wan-2.2-t2v-a14b dataset size
* Fixed invalid memory access in run_mlperf.py + minor fixes
* Move changes
* Added Wan to OFFLINE_MIN_SPQ_SINCE_V4

Co-authored-by: hanyunfan <frank.han@dell.com>
Parent: 0b4aa6f

3 files changed: +51 −99 lines

text_to_video/wan-2.2-t2v-a14b/data/vbench_prompts.txt

1 addition, 1 deletion

@@ -245,4 +245,4 @@ underwater coral reef
 valley
 volcano
 waterfall
-windmill
+windmill

(The last line is removed and re-added unchanged, most likely to append the missing trailing newline after the final prompt; the file still ends at line 248, "windmill".)

text_to_video/wan-2.2-t2v-a14b/run_mlperf.py

45 additions, 94 deletions

@@ -1,4 +1,3 @@
-
 import argparse
 import array
 import json
@@ -11,7 +10,6 @@
 import torch
 import yaml
 from diffusers import AutoencoderKLWan, WanPipeline
-from diffusers.utils import export_to_video

 SCENARIO_MAP = {
     "SingleStream": lg.TestScenario.SingleStream,
@@ -28,52 +26,47 @@ def setup_logging(rank):
     """Setup logging configuration for data parallel (all ranks log)."""
     logging.basicConfig(
         level=logging.INFO,
-        format=f'[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s',
-        datefmt='%Y-%m-%d %H:%M:%S'
+        format=f"[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
     )


 def load_config(config_path):
     """Load configuration from YAML file."""
-    with open(config_path, 'r') as f:
+    with open(config_path, "r") as f:
         config = yaml.safe_load(f)
     return config


 def load_prompts(dataset_path):
     """Load prompts from dataset file."""
-    with open(dataset_path, 'r') as f:
+    with open(dataset_path, "r") as f:
         prompts = [line.strip() for line in f if line.strip()]
     return prompts


 class Model:
-    def __init__(
-        self, model_path, video_output_path, device, config, prompts, fixed_latent=None, rank=0
-    ):
-        self.video_output_path = video_output_path
+    def __init__(self, model_path, device, config, prompts, fixed_latent=None, rank=0):
         self.device = device
         self.rank = rank
-        self.height = config['height']
-        self.width = config['width']
-        self.num_frames = config['num_frames']
-        self.fps = config['fps']
-        self.guidance_scale = config['guidance_scale']
-        self.guidance_scale_2 = config['guidance_scale_2']
-        self.boundary_ratio = config['boundary_ratio']
-        self.negative_prompt = config['negative_prompt'].strip()
-        self.sample_steps = config['sample_steps']
-        self.base_seed = config['seed']
+        self.height = config["height"]
+        self.width = config["width"]
+        self.num_frames = config["num_frames"]
+        self.fps = config["fps"]
+        self.guidance_scale = config["guidance_scale"]
+        self.guidance_scale_2 = config["guidance_scale_2"]
+        self.boundary_ratio = config["boundary_ratio"]
+        self.negative_prompt = config["negative_prompt"].strip()
+        self.sample_steps = config["sample_steps"]
+        self.base_seed = config["seed"]
         self.vae = AutoencoderKLWan.from_pretrained(
-            model_path,
-            subfolder="vae",
-            torch_dtype=torch.float32
+            model_path, subfolder="vae", torch_dtype=torch.float32
         )
         self.pipe = WanPipeline.from_pretrained(
             model_path,
             boundary_ratio=self.boundary_ratio,
             vae=self.vae,
-            torch_dtype=torch.bfloat16
+            torch_dtype=torch.bfloat16,
        )
         self.pipe.to(self.device)
         self.prompts = prompts
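All of the keys Model.__init__ reads from the config are visible in the hunk above. A sketch of loading a matching inference_config.yaml through load_config(), with illustrative values rather than the benchmark's official settings:

import yaml

# Illustrative values only; the real inference_config.yaml ships with the benchmark.
example = """\
height: 480
width: 832
num_frames: 81
fps: 16
guidance_scale: 3.0
guidance_scale_2: 4.0
boundary_ratio: 0.875
negative_prompt: "low quality"
sample_steps: 40
seed: 42
"""
config = yaml.safe_load(example)
assert config["sample_steps"] == 40  # accessed the same way as in Model.__init__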
@@ -94,24 +87,15 @@ def issue_queries(self, query_samples):
             "guidance_scale": self.guidance_scale,
             "guidance_scale_2": self.guidance_scale_2,
             "num_inference_steps": self.sample_steps,
-            "generator": torch.Generator(device=self.device).manual_seed(self.base_seed),
+            "generator": torch.Generator(device=self.device).manual_seed(
+                self.base_seed
+            ),
         }
         if self.fixed_latent is not None:
             pipeline_kwargs["latents"] = self.fixed_latent
         output = self.pipe(**pipeline_kwargs).frames[0]
-
-        # Save to video to reduce mlperf_log_accuracy.json size
-        output_path = Path(
-            self.video_output_path,
-            f"{self.prompts[i]}-0.mp4")
-        logging.info(f"Saving {q} to {output_path}")
-        export_to_video(output[0], str(output_path), fps=self.fps)
-
-        with open(output_path, "rb") as f:
-            resp = f.read()
-
         response_array = array.array(
-            "B", resp
+            "B", output.cpu().detach().numpy().tobytes()
         )
         bi = response_array.buffer_info()
         response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
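With export_to_video and the .mp4 round trip removed, the query response is now the raw frame tensor serialized through NumPy; the old code encoded to video precisely to keep mlperf_log_accuracy.json small, so this trades file-system plumbing for a larger accuracy log. A minimal sketch of the new payload, assuming frames arrive as a float32 torch tensor (the shape below is an illustrative stand-in, not the model's real output size):

import array

import numpy as np
import torch

# Stand-in for self.pipe(**pipeline_kwargs).frames[0]; the real shape depends
# on num_frames/height/width from the config.
frames = torch.rand(4, 3, 8, 8)
payload = frames.cpu().detach().numpy().tobytes()   # what issue_queries() now serializes
response_array = array.array("B", payload)          # same buffer LoadGen logs

# An accuracy script can recover the frames if it knows the dtype and shape:
recovered = np.frombuffer(payload, dtype=np.float32).reshape(tuple(frames.shape))
assert np.array_equal(recovered, frames.numpy())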
@@ -122,23 +106,21 @@ def flush_queries(self):


 class DebugModel:
-    def __init__(
-        self, model_path, device, config, prompts, fixed_latent=None, rank=0
-    ):
+    def __init__(self, model_path, device, config, prompts, fixed_latent=None, rank=0):
         self.prompts = prompts

     def issue_queries(self, query_samples):
         idx = [q.index for q in query_samples]
         query_ids = [q.id for q in query_samples]
         response = []
+        response_array_refs = []
         for i, q in zip(idx, query_ids):
             print(i, self.prompts[i])
             output = self.prompts[i]
-            response_array = array.array(
-                "B", output.encode("utf-8")
-            )
+            response_array = array.array("B", output.encode("utf-8"))
             bi = response_array.buffer_info()
             response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
+            response_array_refs.append(response_array)
         lg.QuerySamplesComplete(response)

     def flush_queries(self):
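The added response_array_refs list is the likely fix for the invalid memory access named in the commit message: lg.QuerySampleResponse stores only the pointer and length returned by buffer_info(), so when the loop rebinds response_array with nothing else referencing the old array.array, CPython can free that buffer before LoadGen reads it. A minimal sketch of the keep-alive pattern, assuming the standard mlperf_loadgen Python bindings:

import array

import mlperf_loadgen as lg

def issue_queries(query_samples):
    responses = []
    response_array_refs = []  # must stay alive until QuerySamplesComplete returns
    for q in query_samples:
        payload = array.array("B", b"illustrative payload")
        ptr, length = payload.buffer_info()
        # QuerySampleResponse records only the raw pointer and length; no copy yet.
        responses.append(lg.QuerySampleResponse(q.id, ptr, length))
        response_array_refs.append(payload)  # drop this and the buffer may be freed early
    lg.QuerySamplesComplete(responses)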
@@ -155,56 +137,56 @@ def unload_query_samples(sample_list):

 def get_args():
     parser = argparse.ArgumentParser(
-        description="Batch T2V inference with Wan2.2-Diffusers")
+        description="Batch T2V inference with Wan2.2-Diffusers"
+    )
     # Model Arguments
     parser.add_argument(
         "--model-path",
         type=str,
         default="./models/Wan2.2-T2V-A14B-Diffusers",
-        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)"
+        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)",
     )
     parser.add_argument(
         "--dataset",
         type=str,
         default="./data/vbench_prompts.txt",
-        help="Path to dataset file (text prompts, one per line) (default: ./data/prompts.txt)"
+        help="Path to dataset file (text prompts, one per line) (default: ./data/prompts.txt)",
     )
     parser.add_argument(
         "--output-dir",
         type=str,
         default="./output",
-        help="Directory to save generated videos (default: ./data/outputs)"
+        help="Directory to save generated videos (default: ./data/outputs)",
     )
     parser.add_argument(
         "--config",
         type=str,
         default="./inference_config.yaml",
-        help="Path to inference configuration file (default: ./inference_config.yaml)"
+        help="Path to inference configuration file (default: ./inference_config.yaml)",
     )
     parser.add_argument(
         "--num-iterations",
         type=int,
         default=1,
-        help="Number of generation iterations per prompt (default: 1)"
+        help="Number of generation iterations per prompt (default: 1)",
     )
     parser.add_argument(
         "--num-prompts",
         type=int,
         default=-1,
-        help="Process only first N prompts (for testing, default: all)"
+        help="Process only first N prompts (for testing, default: all)",
     )
     parser.add_argument(
         "--fixed-latent",
         type=str,
         default="./data/fixed_latent.pt",
-        help="Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)"
+        help="Path to fixed latent .pt file for deterministic generation (default: data/fixed_latent.pt)",
     )
     # MLPerf loadgen arguments
     parser.add_argument(
         "--scenario",
         default="SingleStream",
-        help="mlperf benchmark scenario, one of " +
-        str(list(SCENARIO_MAP.keys())),
+        help="mlperf benchmark scenario, one of " + str(list(SCENARIO_MAP.keys())),
     )
     parser.add_argument(
         "--user_conf",
@@ -218,19 +200,9 @@ def get_args():
         "--performance-sample-count",
         type=int,
         help="performance sample count",
-        default=248,
-    )
-    parser.add_argument(
-        "--accuracy",
-        action="store_true",
-        help="enable accuracy pass"
-    )
-    parser.add_argument(
-        "--video_output_path",
-        type=str,
-        default="./videos",
-        help="path to store output videos"
+        default=5000,
     )
+    parser.add_argument("--accuracy", action="store_true", help="enable accuracy pass")
     # Dont overwrite these for official submission
     parser.add_argument("--count", type=int, help="dataset items to use")
     parser.add_argument("--time", type=int, help="time to scan in seconds")
@@ -272,20 +244,14 @@ def run_mlperf(args, config):
     if args.fixed_latent:
         fixed_latent = torch.load(args.fixed_latent)
         logging.info(
-            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}")
+            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}"
+        )
         logging.info("This latent will be reused for all generations")
     else:
         logging.info("No fixed latent provided - using random initial latents")

     # Loading model
-    model = Model(
-        args.model_path,
-        args.video_output_path,
-        device,
-        config,
-        dataset,
-        fixed_latent,
-        rank)
+    model = Model(args.model_path, device, config, dataset, fixed_latent, rank)
     # model = DebugModel(args.model_path, device, config, dataset, fixed_latent, rank)
     logging.info("Model loaded successfully!")

@@ -305,10 +271,7 @@ def run_mlperf(args, config):

     audit_config = os.path.abspath(args.audit_conf)
     if os.path.exists(audit_config):
-        settings.FromConfig(
-            audit_config,
-            "wan-2.2-t2v-a14b",
-            args.scenario)
+        settings.FromConfig(audit_config, "wan-2.2-t2v-a14b", args.scenario)
     settings.scenario = SCENARIO_MAP[args.scenario]

     settings.mode = lg.TestMode.PerformanceOnly
@@ -324,24 +287,18 @@ def run_mlperf(args, config):
     settings.server_target_qps = qps
     settings.offline_expected_qps = qps

-    count_override = False
     count = args.count
-    if count:
-        count_override = True

     if args.count:
         settings.min_query_count = count
         settings.max_query_count = count
-    if not count_override:
-        count = len(dataset)
+    count = len(dataset)

     if args.samples_per_query:
         settings.multi_stream_samples_per_query = args.samples_per_query
     if args.max_latency:
-        settings.server_target_latency_ns = int(
-            args.max_latency * NANO_SEC)
-        settings.multi_stream_expected_latency_ns = int(
-            args.max_latency * NANO_SEC)
+        settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
+        settings.multi_stream_expected_latency_ns = int(args.max_latency * NANO_SEC)

     performance_sample_count = (
         args.performance_sample_count
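The rewritten block drops the count_override flag and separates two concerns it conflated: args.count still pins min/max_query_count, but count as passed to lg.ConstructQSL is now always len(dataset), so a --count larger than the prompt list can no longer drive sample indices past the end of self.prompts. A toy illustration of the out-of-range case the old sizing allowed (stand-in values, not the real harness):

prompts = ["a cat on a beach", "a volcano erupting"]  # stand-in dataset, len == 2

def lookup(indices):
    # Mirrors Model.issue_queries indexing into self.prompts.
    return [prompts[i] for i in indices]

try:
    # Old behavior: --count 5 sized the QSL, so LoadGen could issue index 4.
    lookup(range(5))
except IndexError:
    pass  # new code avoids this by always sizing the QSL with len(dataset)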
@@ -354,13 +311,7 @@ def run_mlperf(args, config):
         count, performance_sample_count, load_query_samples, unload_query_samples
     )

-    lg.StartTestWithLogSettings(
-        sut, qsl, settings, log_settings, audit_config)
-    if args.accuracy:
-        # TODO: output accuracy
-        final_results = {}
-        with open("results.json", "w") as f:
-            json.dump(final_results, f, sort_keys=True, indent=4)
+    lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)

     lg.DestroyQSL(qsl)
     lg.DestroySUT(sut)

tools/submission/submission_checker/constants.py

5 additions, 4 deletions

@@ -239,7 +239,7 @@
     "whisper": 1633,
     "gpt-oss-120b": 6396,
     "qwen3-vl-235b-a22b": 48289,
-    "wan-2.2-t2v-a14b": 247,
+    "wan-2.2-t2v-a14b": 248,
     "dlrm-v3": 349823,
     "yolo-95": 64,
     "yolo-99": 64,
@@ -270,7 +270,7 @@
     # TODO: Need to add accuracy sample count checkers as well (4395)
     "gpt-oss-120b": 6396,
     "qwen3-vl-235b-a22b": 48289,
-    "wan-2.2-t2v-a14b": 247,
+    "wan-2.2-t2v-a14b": 248,
     "dlrm-v3": 349823,
     "yolo-95": 1525,
     "yolo-99": 1525,
@@ -347,7 +347,7 @@
     "gpt-oss-120b": {"SingleStream": 1024, "Server": 270336, "Offline": 1},
     "qwen3-vl-235b-a22b": {"SingleStream": 1024, "Server": 270336, "Offline": 1},
     "dlrm-v3": {"Server": 270336, "Offline": 1},
-    "wan-2.2-t2v-a14b": {"SingleStream": 247, "Offline": 1},
+    "wan-2.2-t2v-a14b": {"SingleStream": 248, "Offline": 1},
     "yolo-95": {"SingleStream": 1024, "MultiStream": 270336, "Offline": 1},
     "yolo-99": {"SingleStream": 1024, "MultiStream": 270336, "Offline": 1},
 },
@@ -1169,7 +1169,8 @@
     "yolo-99": 1525,
     "yolo-95": 1525,
     "dlrm-v3": 349823,
-    "qwen3-vl-235b-a22b": 48289
+    "qwen3-vl-235b-a22b": 48289,
+    "wan-2.2-t2v-a14b": 248,
 }

 SCENARIO_MAPPING = {
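The 247 → 248 updates bring the submission checker in line with the actual prompt count of vbench_prompts.txt (whose diff above ends at line 248, "windmill"), and the final hunk adds the model to the OFFLINE_MIN_SPQ_SINCE_V4 table named in the commit message. A quick sanity check mirroring load_prompts() from run_mlperf.py, run from the repo root:

# Counts non-empty lines the same way load_prompts() does.
with open("text_to_video/wan-2.2-t2v-a14b/data/vbench_prompts.txt") as f:
    prompts = [line.strip() for line in f if line.strip()]
assert len(prompts) == 248  # the value the checker constants now expect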
