Skip to content

Commit 1ead6d6

Browse files
committed
[NVBUG: 5619158] Enforce high precision model dtype for diffusion trt
1 parent 5adb9ba commit 1ead6d6

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ def main():
126126
default=None,
127127
help="Path to the model if not using default paths in MODEL_ID mapping.",
128128
)
129-
parser.add_argument(
130-
"--model-dtype",
131-
type=str,
132-
default="Half",
133-
choices=["Half", "BFloat16", "Float"],
134-
help="Precision used to load the model.",
135-
)
136129
parser.add_argument(
137130
"--restore-from", type=str, default=None, help="Path to the modelopt quantized checkpoint"
138131
)
@@ -170,28 +163,27 @@ def main():
170163

171164
pipe = PipelineManager.create_pipeline_from(
172165
MODEL_ID[args.model],
173-
dtype_map[args.model_dtype],
174166
override_model_path=args.override_model_path,
175167
)
176168

177169
# Save the backbone of the pipeline and move it to the GPU
178170
add_embedding = None
179171
backbone = None
172+
model_dtype = None
180173
if hasattr(pipe, "transformer"):
181174
backbone = pipe.transformer
175+
model_dtype = "BFloat16"
182176
elif hasattr(pipe, "unet"):
183177
backbone = pipe.unet
184178
add_embedding = backbone.add_embedding
179+
model_dtype = "Half"
185180
else:
186181
raise ValueError("Pipeline does not have a transformer or unet backbone")
187182

188183
if args.restore_from:
189184
mto.restore(backbone, args.restore_from)
190185

191186
if args.torch_compile:
192-
assert args.model_dtype in ["BFloat16", "Float", "Half"], (
193-
"torch.compile() only supports BFloat16 and Float"
194-
)
195187
print("Compiling backbone with torch.compile()...")
196188
backbone = torch.compile(backbone, mode="max-autotune")
197189

@@ -203,7 +195,7 @@ def main():
203195
pipe.to("cuda")
204196

205197
if args.benchmark:
206-
benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
198+
benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
207199

208200
if not args.skip_image:
209201
generate_image(pipe, args.prompt, image_name)

0 commit comments

Comments
 (0)