Commit 89f6c25 (parent: 4b522e0)

Update benchmarking for diffusers

Signed-off-by: ajrasane <[email protected]>

1 file changed: examples/diffusers/quantization/diffusion_trt.py (+45 −42 lines)
@@ -47,6 +47,7 @@
 }


+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
@@ -59,56 +60,52 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev", model_dtype="Half"
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)

-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
+    print(f"\nBackbone-only inference latency ({model_dtype}):")
+    print(f"  Average: {avg_latency:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P95: {p95:.2f} ms")
+    print(f"  P99: {p99:.2f} ms")
+
     return avg_latency


@@ -203,7 +200,13 @@ def main():
     pipe.to("cuda")

     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+            model_dtype=args.model_dtype,
+        )

     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
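The new benchmark_backbone_standalone path times only the denoiser backbone (transformer or UNet) with CUDA events instead of hooking the full pipeline. Below is a minimal standalone sketch of that timing pattern for reference; the module name backbone, the input dict dummy_inputs_dict, and the helper time_backbone are hypothetical stand-ins, not part of this commit, and only torch and tqdm calls already used in the diff appear.

import torch
from tqdm import tqdm


@torch.inference_mode()
def time_backbone(backbone, dummy_inputs_dict, num_warmup=10, num_benchmark=100):
    """Warm up, then time num_benchmark forward passes with CUDA events (milliseconds)."""
    # Warmup: run a few untimed forward passes so kernels and caches settle
    for _ in tqdm(range(num_warmup), desc="Warmup"):
        _ = backbone(**dummy_inputs_dict)

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    times = []
    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
        start_event.record()
        _ = backbone(**dummy_inputs_dict)
        end_event.record()
        torch.cuda.synchronize()  # wait for the GPU so elapsed_time covers the whole forward pass
        times.append(start_event.elapsed_time(end_event))

    times.sort()
    return {
        "avg": sum(times) / len(times),
        "p50": times[len(times) // 2],
        # clamp the percentile index so small num_benchmark values stay in range
        "p95": times[min(int(len(times) * 0.95), len(times) - 1)],
        "p99": times[min(int(len(times) * 0.99), len(times) - 1)],
    }

Timing each forward pass with CUDA events plus an explicit synchronize avoids attributing asynchronous kernel launch time to the wrong iteration, which is why the committed code synchronizes inside the benchmark loop before reading elapsed_time.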
