Commit 6184567

Benchmark the backbone only
Signed-off-by: ajrasane <[email protected]>
1 parent ecaaf76 · commit 6184567

1 file changed: +38 -21 lines changed
examples/diffusers/quantization/diffusion_trt.py

Lines changed: 38 additions & 21 deletions
```diff
@@ -60,30 +60,47 @@ def generate_image(pipe, prompt, image_name):
 
 
 def benchmark_model(pipe, prompt, num_warmup=3, num_runs=10):
-    """Benchmark the model inference time."""
-    # Warmup runs
-    for _ in range(num_warmup):
-        _ = pipe(
-            prompt,
-            output_type="pil",
-            num_inference_steps=30,
-            generator=torch.Generator("cuda").manual_seed(42),
-        )
+    """Benchmark the backbone model inference time."""
+    backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
 
-    # Benchmark runs
-    torch.cuda.synchronize()
-    start = time.time()
-    for _ in range(num_runs):
-        _ = pipe(
-            prompt,
-            output_type="pil",
-            num_inference_steps=30,
-            generator=torch.Generator("cuda").manual_seed(42),
-        )
-    torch.cuda.synchronize()
-    end = time.time()
+    backbone_times = []
+
+    def forward_pre_hook(module, input):
+        torch.cuda.synchronize()
+        module._start_time = time.time()
 
-    avg_latency = (end - start) / num_runs * 1000  # Convert to ms
+    def forward_hook(module, input, output):
+        torch.cuda.synchronize()
+        module._end_time = time.time()
+        backbone_times.append((module._end_time - module._start_time) * 1000)  # Convert to ms
+
+    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
+    post_handle = backbone.register_forward_hook(forward_hook)
+
+    try:
+        for _ in range(num_warmup):
+            _ = pipe(
+                prompt,
+                output_type="pil",
+                num_inference_steps=10,
+                generator=torch.Generator("cuda").manual_seed(42),
+            )
+
+        backbone_times.clear()
+
+        for _ in range(num_runs):
+            _ = pipe(
+                prompt,
+                output_type="pil",
+                num_inference_steps=10,
+                generator=torch.Generator("cuda").manual_seed(42),
+            )
+    finally:
+        pre_handle.remove()
+        post_handle.remove()
+
+    total_backbone_time = sum(backbone_times)
+    avg_latency = total_backbone_time / num_runs
    return avg_latency
 
 
```
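The patched benchmark no longer times the whole pipeline call end to end; it hooks the denoising backbone (`pipe.transformer` when present, otherwise `pipe.unet`) and records one timing per backbone forward. Since the backbone runs once per denoising step, each pipeline call with `num_inference_steps=10` appends roughly ten entries to `backbone_times`, so the returned `avg_latency` is the summed backbone time per pipeline call in milliseconds, not per forward pass. Below is a minimal, self-contained sketch of the same hook-timing pattern on a stand-in `nn.Linear` (the module, shapes, and run count are illustrative, not from this file), so it runs even without a diffusers pipeline:

```python
import time

import torch
import torch.nn as nn

# Stand-in for pipe.transformer / pipe.unet; any nn.Module works the same way.
backbone = nn.Linear(64, 64)
backbone_times = []


def forward_pre_hook(module, args):
    # Drain pending GPU work so the timestamp marks the true start of this
    # forward (guarded so the sketch also runs on CPU-only machines).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    module._start_time = time.time()


def forward_hook(module, args, output):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    backbone_times.append((time.time() - module._start_time) * 1000)  # ms


pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
post_handle = backbone.register_forward_hook(forward_hook)

try:
    for _ in range(10):  # each forward appends one timing
        _ = backbone(torch.randn(8, 64))
finally:
    # Detach the hooks so repeated benchmarks do not stack extra timers.
    pre_handle.remove()
    post_handle.remove()

print(f"avg forward: {sum(backbone_times) / len(backbone_times):.3f} ms")
```

The `try`/`finally` mirrors the commit: hooks are removed even if a run raises, and clearing the list after warmup (in the patched function) keeps warmup forwards out of the average.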
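For context, a hypothetical call site for the patched function, assuming an SDXL-style pipeline (the checkpoint id, dtype, and prompt are illustrative, and `benchmark_model` is the function from the diff above):

```python
# Hypothetical usage; checkpoint id, dtype, and prompt are illustrative.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Returns the summed backbone (UNet, for SDXL) time per pipeline call in
# milliseconds, averaged over the 10 timed runs.
latency_ms = benchmark_model(pipe, "a photo of an astronaut riding a horse")
print(f"avg backbone latency: {latency_ms:.1f} ms/image")
```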