Commit 9e52779

Automatically infer backbone dtype from model parameters
Signed-off-by: ajrasane <[email protected]>
1 parent 17e07d0 commit 9e52779

File tree

1 file changed: +6 additions, −12 deletions

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 6 additions & 12 deletions
@@ -40,12 +40,6 @@
     "flux-schnell": ModelType.FLUX_SCHNELL,
 }
 
-dtype_map = {
-    "Half": torch.float16,
-    "BFloat16": torch.bfloat16,
-    "Float": torch.float32,
-}
-
 
 def generate_image(pipe, prompt, image_name):
     seed = 42
@@ -60,7 +54,7 @@ def generate_image(pipe, prompt, image_name):
 
 
 def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
 ):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
@@ -83,7 +77,7 @@ def forward_hook(_module, _input, _output):
     try:
         print(f"Starting warmup: {num_warmup} runs")
         for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+            with torch.amp.autocast("cuda", dtype=model_dtype):
                 _ = pipe(
                     prompt,
                     output_type="pil",
@@ -95,7 +89,7 @@ def forward_hook(_module, _input, _output):
 
         print(f"Starting benchmark: {num_runs} runs")
         for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+            with torch.amp.autocast("cuda", dtype=model_dtype):
                 _ = pipe(
                     prompt,
                     output_type="pil",
@@ -169,17 +163,17 @@ def main():
     # Save the backbone of the pipeline and move it to the GPU
     add_embedding = None
     backbone = None
-    model_dtype = None
     if hasattr(pipe, "transformer"):
         backbone = pipe.transformer
-        model_dtype = "Bfloat16"
     elif hasattr(pipe, "unet"):
         backbone = pipe.unet
         add_embedding = backbone.add_embedding
-        model_dtype = "Half"
     else:
         raise ValueError("Pipeline does not have a transformer or unet backbone")
 
+    # Get dtype directly from the backbone's parameters
+    model_dtype = next(backbone.parameters()).dtype
+
     if args.restore_from:
         mto.restore(backbone, args.restore_from)

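For reference, a minimal standalone sketch of the pattern this commit adopts: read the compute dtype from the backbone's own parameters and pass it to autocast, instead of mapping a hard-coded string like "Half" or "BFloat16" through a dict. DummyBackbone is a hypothetical stand-in for pipe.transformer / pipe.unet; this is an illustration, not code from the repository, and it assumes all of the backbone's parameters share a single dtype (next() only inspects the first one).

import torch
from torch import nn


class DummyBackbone(nn.Module):
    """Hypothetical stand-in for pipe.transformer / pipe.unet."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


backbone = DummyBackbone().to(torch.bfloat16)

# Same pattern as the commit: infer the dtype from the first parameter
# rather than keeping a separate string-to-dtype lookup table.
model_dtype = next(backbone.parameters()).dtype  # torch.bfloat16

x = torch.randn(4, 16, dtype=model_dtype)
if torch.cuda.is_available():
    backbone, x = backbone.cuda(), x.cuda()
    # The inferred dtype drives autocast directly, as in benchmark_model.
    with torch.amp.autocast("cuda", dtype=model_dtype):
        out = backbone(x)
else:
    out = backbone(x)

print(model_dtype, out.dtype)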