
Commit 17e07d0 (1 parent: 5adb9ba)

[NVBUG: 5619158] Enforce high precision model dtype for diffusion trt

Signed-off-by: ajrasane <[email protected]>

3 files changed, +12 -14 lines changed

CHANGELOG.rst
Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ Model Optimizer Changelog (Linux)
 - Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/llm_ptq#multi-node-post-training-quantization-with-fsdp2>`_ for more details.
 - Add support for Nemotron Nano VL v1 & v2 models in FP8/NVFP4 PTQ workflow.
 - Add flags ``nodes_to_include`` and ``op_types_to_include`` in AutoCast to force-include nodes in low precision, even if they would otherwise be excluded by other rules.
+- Add support for ``torch.compile`` and benchmarking in ``examples/diffusers/quantization/diffusion_trt.py``.

 **Documentation**
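The compile-and-benchmark feature referenced in the new changelog entry follows a standard diffusers pattern. Below is a minimal sketch, assuming an illustrative sdxl-turbo checkpoint and CUDA-event timing; it does not reproduce the repository's own `benchmark_model()` helper:

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; diffusion_trt.py resolves the model from its
# MODEL_ID mapping instead.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")

# Compile only the denoising backbone, mirroring the script's use of torch.compile.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune")

prompt = "A cat holding a sign that says hello world"

# Warm-up run so one-time compilation cost is excluded from the measurement.
pipe(prompt, num_inference_steps=4, guidance_scale=0.0)

# Time end-to-end generation with CUDA events.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
pipe(prompt, num_inference_steps=4, guidance_scale=0.0)
end.record()
torch.cuda.synchronize()
print(f"Latency: {start.elapsed_time(end):.1f} ms")
```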

examples/diffusers/README.md
Lines changed: 7 additions & 2 deletions

@@ -307,11 +307,16 @@ Generate images for the quantized checkpoint with the following [Script](./quant
 python diffusion_trt.py \
     --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \
     --prompt "A cat holding a sign that says hello world" \
+    [--override-model-path /path/to/model] \
     [--restore-from ./{MODEL}_fp8.pt] \
     [--onnx-load-path {ONNX_DIR}] \
     [--trt-engine-load-path {ENGINE_DIR}] \
-    [--dq_only] \
-    [--torch]
+    [--dq-only] \
+    [--torch] \
+    [--save-image-as /path/to/image] \
+    [--benchmark] \
+    [--torch-compile] \
+    [--skip-image]
 ```

 This script will save the output image as `./{MODEL}.png` and report the latency of the TensorRT backbone.
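For example, a quantized FLUX checkpoint could be benchmarked through the PyTorch backend with the newly added flags (the checkpoint path here is illustrative):

```sh
python diffusion_trt.py \
    --model flux-dev \
    --restore-from ./flux-dev_fp8.pt \
    --torch \
    --torch-compile \
    --benchmark \
    --skip-image
```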

examples/diffusers/quantization/diffusion_trt.py
Lines changed: 4 additions & 12 deletions

@@ -126,13 +126,6 @@ def main():
         default=None,
         help="Path to the model if not using default paths in MODEL_ID mapping.",
     )
-    parser.add_argument(
-        "--model-dtype",
-        type=str,
-        default="Half",
-        choices=["Half", "BFloat16", "Float"],
-        help="Precision used to load the model.",
-    )
     parser.add_argument(
         "--restore-from", type=str, default=None, help="Path to the modelopt quantized checkpoint"
     )
@@ -170,28 +163,27 @@ def main():
 
     pipe = PipelineManager.create_pipeline_from(
         MODEL_ID[args.model],
-        dtype_map[args.model_dtype],
         override_model_path=args.override_model_path,
     )
 
     # Save the backbone of the pipeline and move it to the GPU
     add_embedding = None
     backbone = None
+    model_dtype = None
     if hasattr(pipe, "transformer"):
         backbone = pipe.transformer
model_dtype = "Bfloat16"
     elif hasattr(pipe, "unet"):
         backbone = pipe.unet
         add_embedding = backbone.add_embedding
+        model_dtype = "Half"
     else:
         raise ValueError("Pipeline does not have a transformer or unet backbone")
 
     if args.restore_from:
         mto.restore(backbone, args.restore_from)
 
     if args.torch_compile:
-        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
-            "torch.compile() only supports BFloat16 and Float"
-        )
         print("Compiling backbone with torch.compile()...")
         backbone = torch.compile(backbone, mode="max-autotune")
 
@@ -203,7 +195,7 @@ def main():
         pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
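The net effect of the change is that precision is no longer a user-facing knob: pipelines with a transformer backbone (sd3-medium, flux-dev) are pinned to BFloat16, while UNet-based pipelines (sdxl-1.0, sdxl-turbo) use Half. A condensed sketch of the enforced rule follows; the `dtype_map` table is a plausible reconstruction of how the string resolves to a torch dtype downstream (the script previously indexed a `dtype_map` with these same keys), not the repository's exact code:

```python
import torch

# Plausible resolution table; benchmark_model() receives the string and is
# assumed to map it to a torch dtype along these lines.
dtype_map = {"Half": torch.float16, "BFloat16": torch.bfloat16, "Float": torch.float32}

def enforced_model_dtype(pipe) -> str:
    """Mirror the commit's rule: transformer backbones run in BFloat16,
    UNet backbones in Half; anything else is rejected."""
    if hasattr(pipe, "transformer"):
        return "BFloat16"
    if hasattr(pipe, "unet"):
        return "Half"
    raise ValueError("Pipeline does not have a transformer or unet backbone")
```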
