Commit 1aafbbc

Update benchmarking for diffusers

Signed-off-by: ajrasane <[email protected]>
1 parent ca94c96

File tree: 1 file changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 44 additions & 42 deletions
```diff
@@ -49,6 +49,7 @@
 }


+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
```
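Side note on the new decorator: `torch.inference_mode()` is a stricter, faster variant of `torch.no_grad()` that also skips view and version-counter tracking, and it works both as a context manager and as a decorator, as used here. A minimal sketch of the decorator form (the model and input are placeholders, not from this file):

```python
import torch

@torch.inference_mode()
def run_backbone(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded; outputs are inference-mode tensors,
    # which is safe here because nothing downstream needs gradients.
    return model(x)
```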
```diff
@@ -61,56 +62,52 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev",
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)

-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
+    print(f"\nBackbone-only inference latency:")
+    print(f" Average: {avg_latency:.2f} ms")
+    print(f" P50: {p50:.2f} ms")
+    print(f" P95: {p95:.2f} ms")
+    print(f" P99: {p99:.2f} ms")

     return avg_latency

```
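The timing pattern introduced above (warmup iterations, then CUDA events with a synchronize before reading each measurement) generalizes beyond diffusers, and it avoids the removed hook-based approach's exposure to pipeline overhead (text encoding, VAE decode, scheduler). A minimal self-contained sketch of the same technique; `time_cuda_callable` is a hypothetical helper, not part of the commit:

```python
import torch

@torch.inference_mode()
def time_cuda_callable(fn, num_warmup: int = 10, num_benchmark: int = 100) -> dict:
    """Measure latency of a CUDA-bound callable in milliseconds."""
    for _ in range(num_warmup):  # warm caches, allocator, kernel autotuning
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(num_benchmark):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # elapsed_time is valid only after both events complete
        times.append(start.elapsed_time(end))  # milliseconds

    times.sort()
    return {
        "avg": sum(times) / len(times),
        "p50": times[len(times) // 2],
        "p95": times[int(len(times) * 0.95)],  # int(n * p) with p < 1 is always a valid index
        "p99": times[int(len(times) * 0.99)],
    }
```

Reusing one event pair across iterations, as the commit does, is fine because the elapsed time is read after each synchronize, before the events are recorded again.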

```diff
@@ -196,7 +193,12 @@ def main():
     pipe.to("cuda")

     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+        )

     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
```
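For anyone calling the new function outside this script, a hypothetical usage sketch, assuming the pipeline is a diffusers `FluxPipeline` (suggested by the `"flux-dev"` default above); the checkpoint name and dtype here are illustrative assumptions, not taken from the commit:

```python
# Hypothetical usage -- checkpoint and dtype are illustrative, not from this commit.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

avg_ms = benchmark_backbone_standalone(
    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev"
)
print(f"Average backbone latency: {avg_ms:.2f} ms")
```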
