
Commit 83aa657

Add SingleStream text2video loadgen integration
1 parent 571fa92 commit 83aa657

File tree

2 files changed: +341 −0 lines changed
Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
from diffusers import WanPipeline, AutoencoderKLWan
import argparse
import yaml
import json
import logging
import os
import torch
import array
import numpy as np
import mlperf_loadgen as lg
from pathlib import Path

SCENARIO_MAP = {
    "SingleStream": lg.TestScenario.SingleStream,
    "MultiStream": lg.TestScenario.MultiStream,
    "Server": lg.TestScenario.Server,
    "Offline": lg.TestScenario.Offline,
}

# Loadgen takes latencies in nanoseconds and durations in milliseconds.
NANO_SEC = 1e9
MILLI_SEC = 1000

def setup_logging(rank):
    """Set up logging configuration for data parallel (all ranks log)."""
    logging.basicConfig(
        level=logging.INFO,
        format=f'[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


def load_config(config_path):
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config


def load_prompts(dataset_path):
    """Load prompts from dataset file (one per line, blank lines skipped)."""
    with open(dataset_path, 'r') as f:
        prompts = [line.strip() for line in f if line.strip()]
    return prompts


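# Model wraps the Wan2.2 diffusers pipeline as the loadgen SUT: loadgen calls
# issue_queries() with batches of QuerySample objects and flush_queries() at
# the end of a test phase.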
class Model:
    def __init__(self, model_path, device, config, prompts, fixed_latent=None):
        self.device = device
        # Rank is read here so issue_queries can gate generation on rank 0.
        self.rank = int(os.environ.get("RANK", 0))
        self.height = config['height']
        self.width = config['width']
        self.num_frames = config['num_frames']
        self.fps = config['fps']
        self.guidance_scale = config['guidance_scale']
        self.guidance_scale_2 = config['guidance_scale_2']
        self.boundary_ratio = config['boundary_ratio']
        self.negative_prompt = config['negative_prompt'].strip()
        self.sample_steps = config['sample_steps']
        self.base_seed = config['seed']
        # The VAE is kept in float32; the rest of the pipeline runs in bfloat16.
        self.vae = AutoencoderKLWan.from_pretrained(
            model_path,
            subfolder="vae",
            torch_dtype=torch.float32
        )
        self.pipe = WanPipeline.from_pretrained(
            model_path,
            boundary_ratio=self.boundary_ratio,
            vae=self.vae,
            torch_dtype=torch.bfloat16
        )
        self.pipe.to(self.device)
        self.prompts = prompts
        self.fixed_latent = fixed_latent

    def issue_queries(self, query_samples):
        # Only rank 0 runs generation and completes queries with loadgen.
        if self.rank == 0:
            idx = [q.index for q in query_samples]
            query_ids = [q.id for q in query_samples]
            response = []
            buffers = []  # keep response buffers alive until loadgen copies them
            for i, q in zip(idx, query_ids):
                pipeline_kwargs = {
                    "prompt": self.prompts[i],
                    "negative_prompt": self.negative_prompt,
                    "height": self.height,
                    "width": self.width,
                    "num_frames": self.num_frames,
                    "guidance_scale": self.guidance_scale,
                    "guidance_scale_2": self.guidance_scale_2,
                    "num_inference_steps": self.sample_steps,
                    "generator": torch.Generator(device=self.device).manual_seed(self.base_seed),
                }
                if self.fixed_latent is not None:
                    pipeline_kwargs["latents"] = self.fixed_latent
                output = self.pipe(**pipeline_kwargs).frames[0]
                # Hand loadgen a (pointer, size) view of the raw frame bytes.
                response_array = array.array(
                    "B", output.cpu().detach().numpy().tobytes()
                )
                buffers.append(response_array)
                bi = response_array.buffer_info()
                response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
            lg.QuerySamplesComplete(response)

    def flush_queries(self):
        pass


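# DebugModel mirrors Model's interface but simply echoes each prompt back as
# the response bytes, so the loadgen plumbing can be exercised without
# loading the diffusion pipeline.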
class DebugModel:
    def __init__(self, model_path, device, config, prompts, fixed_latent=None):
        self.prompts = prompts

    def issue_queries(self, query_samples):
        idx = [q.index for q in query_samples]
        query_ids = [q.id for q in query_samples]
        response = []
        buffers = []  # keep buffers alive until loadgen copies the responses
        for i, q in zip(idx, query_ids):
            print(i, self.prompts[i])
            output = self.prompts[i]
            response_array = array.array("B", output.encode("utf-8"))
            buffers.append(response_array)
            bi = response_array.buffer_info()
            response.append(lg.QuerySampleResponse(q, bi[0], bi[1]))
        lg.QuerySamplesComplete(response)

    def flush_queries(self):
        pass


# Prompts are held in memory for the whole run, so the QSL load/unload
# callbacks are no-ops.
def load_query_samples(sample_list):
    pass


def unload_query_samples(sample_list):
    pass


def get_args():
    parser = argparse.ArgumentParser(
        description="Batch T2V inference with Wan2.2-Diffusers")
    ## Model Arguments
    parser.add_argument(
        "--model-path",
        type=str,
        default="./models/Wan2.2-T2V-A14B-Diffusers",
        help="Path to model checkpoint directory (default: ./models/Wan2.2-T2V-A14B-Diffusers)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="./data/vbench_prompts.txt",
        help="Path to dataset file (text prompts, one per line) (default: ./data/vbench_prompts.txt)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./output",
        help="Directory to save generated videos (default: ./output)"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="./inference_config.yaml",
        help="Path to inference configuration file (default: ./inference_config.yaml)"
    )
    parser.add_argument(
        "--num-iterations",
        type=int,
        default=1,
        help="Number of generation iterations per prompt (default: 1)"
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=-1,
        help="Process only the first N prompts (for testing; default: all)"
    )
    parser.add_argument(
        "--fixed-latent",
        type=str,
        default="./data/fixed_latent.pt",
        help="Path to fixed latent .pt file for deterministic generation (default: ./data/fixed_latent.pt)"
    )
    ## MLPerf loadgen arguments
    parser.add_argument(
        "--scenario",
        default="SingleStream",
        help="mlperf benchmark scenario, one of " +
        str(list(SCENARIO_MAP.keys())),
    )
    parser.add_argument(
        "--user_conf",
        default="user.conf",
        help="user config for user LoadGen settings such as target QPS",
    )
    parser.add_argument(
        "--audit_conf", default="audit.config", help="config for LoadGen audit settings"
    )
    parser.add_argument(
        "--performance-sample-count",
        type=int,
        help="performance sample count",
        default=5000,
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="enable accuracy pass"
    )
    # Don't override these for an official submission
    parser.add_argument("--count", type=int, help="dataset items to use")
    parser.add_argument("--time", type=int, help="time to scan in seconds")
    parser.add_argument("--qps", type=int, help="target qps")
    parser.add_argument("--debug", action="store_true", help="debug")
    parser.add_argument(
        "--samples-per-query",
        default=8,
        type=int,
        help="mlperf multi-stream samples per query",
    )
    parser.add_argument(
        "--max-latency", type=float, help="mlperf max latency target in seconds"
    )

    return parser.parse_args()

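# Wires the benchmark together: load prompts and config, build the model,
# apply loadgen settings from user.conf/audit.config, and run the test.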
def run_mlperf(args, config):
    # Load dataset
    dataset = load_prompts(args.dataset)

    # Parallelism parameters
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    setup_logging(rank)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_dir_lg = str(args.output_dir)

    fixed_latent = None
    if args.fixed_latent:
        fixed_latent = torch.load(args.fixed_latent)
        logging.info(
            f"Loaded fixed latent from {args.fixed_latent} with shape: {fixed_latent.shape}")
        logging.info("This latent will be reused for all generations")
    else:
        logging.info("No fixed latent provided - using random initial latents")

    # Load the model (swap in DebugModel to exercise the harness without a GPU)
    model = Model(args.model_path, device, config, dataset, fixed_latent)
    # model = DebugModel(args.model_path, device, config, dataset, fixed_latent)
    logging.info("Model loaded successfully!")

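    # Loadgen flow from here: configure log and test settings, construct the
    # SUT and QSL, then StartTestWithLogSettings drives issue_queries.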
    # Prepare loadgen for run (only rank 0 drives loadgen)
    if rank == 0:
        log_output_settings = lg.LogOutputSettings()
        log_output_settings.outdir = output_dir_lg
        log_output_settings.copy_summary_to_stdout = False

        log_settings = lg.LogSettings()
        log_settings.enable_trace = args.debug
        log_settings.log_output = log_output_settings

        # Look up loadgen settings for this model key and scenario.
        user_conf = os.path.abspath(args.user_conf)
        settings = lg.TestSettings()
        settings.FromConfig(user_conf, "qwen3-vl-235b-a22b", args.scenario)

        audit_config = os.path.abspath(args.audit_conf)
        if os.path.exists(audit_config):
            settings.FromConfig(audit_config, "qwen3-vl-235b-a22b", args.scenario)
        settings.scenario = SCENARIO_MAP[args.scenario]

        settings.mode = lg.TestMode.PerformanceOnly
        if args.accuracy:
            settings.mode = lg.TestMode.AccuracyOnly

        if args.time:
            # override the time we want to run
            settings.min_duration_ms = args.time * MILLI_SEC
            settings.max_duration_ms = args.time * MILLI_SEC
        if args.qps:
            qps = float(args.qps)
            settings.server_target_qps = qps
            settings.offline_expected_qps = qps

        # An optional count override pins the query count; the QSL itself
        # always spans the full dataset.
        if args.count:
            settings.min_query_count = args.count
            settings.max_query_count = args.count
        count = len(dataset)

        if args.samples_per_query:
            settings.multi_stream_samples_per_query = args.samples_per_query
        if args.max_latency:
            settings.server_target_latency_ns = int(args.max_latency * NANO_SEC)
            settings.multi_stream_expected_latency_ns = int(
                args.max_latency * NANO_SEC)

        # The performance sample count must not exceed the QSL size.
        performance_sample_count = min(args.performance_sample_count, count)

        sut = lg.ConstructSUT(model.issue_queries, model.flush_queries)
        qsl = lg.ConstructQSL(
            count, performance_sample_count, load_query_samples, unload_query_samples
        )

        lg.StartTestWithLogSettings(sut, qsl, settings, log_settings, audit_config)
        if args.accuracy:
            ## TODO: output accuracy
            final_results = {}
            with open("results.json", "w") as f:
                json.dump(final_results, f, sort_keys=True, indent=4)

        lg.DestroyQSL(qsl)
        lg.DestroySUT(sut)


def main():
    args = get_args()
    config = load_config(args.config)
    run_mlperf(args, config)


if __name__ == "__main__":
    main()
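
A minimal usage sketch, assuming the harness above is saved as main.py (the filename is not shown in this diff):

# SingleStream performance run, loadgen logs written to ./output
python main.py --scenario SingleStream --user_conf user.conf --output-dir ./output
# Accuracy pass (currently writes an empty results.json; see the TODO above)
python main.py --scenario SingleStream --accuracy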
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# The format of this config file is 'key = value'.
# The key has the format 'model.scenario.key'. Value is mostly int64_t.
# Model may be '*' as a wildcard; in that case the value applies to all models.
# All times are in milliseconds.
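
# Example entry (hypothetical value, not part of this commit):
# qwen3-vl-235b-a22b.SingleStream.target_latency = 1000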
