Skip to content

Commit 3bb45e6

Browse files
committed
Add a mapping between models and their default dtype
1 parent 9e52779 commit 3bb45e6

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@
4040
"flux-schnell": ModelType.FLUX_SCHNELL,
4141
}
4242

43+
DTYPE_MAP = {
44+
"sdxl-1.0": torch.float16,
45+
"sdxl-turbo": torch.float16,
46+
"sd3-medium": torch.float16,
47+
"flux-dev": torch.bfloat16,
48+
"flux-schnell": torch.bfloat16,
49+
}
50+
4351

4452
def generate_image(pipe, prompt, image_name):
4553
seed = 42
@@ -154,9 +162,11 @@ def main():
154162
args = parser.parse_args()
155163

156164
image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
165+
model_dtype = DTYPE_MAP[args.model]
157166

158167
pipe = PipelineManager.create_pipeline_from(
159168
MODEL_ID[args.model],
169+
torch_dtype=model_dtype,
160170
override_model_path=args.override_model_path,
161171
)
162172

@@ -171,9 +181,6 @@ def main():
171181
else:
172182
raise ValueError("Pipeline does not have a transformer or unet backbone")
173183

174-
# Get dtype directly from the backbone's parameters
175-
model_dtype = next(backbone.parameters()).dtype
176-
177184
if args.restore_from:
178185
mto.restore(backbone, args.restore_from)
179186

0 commit comments

Comments
 (0)