Commit 2290533

[NVBUG: 5619158] Enforce high precision model dtype for diffusion trt (#526)
## What does this PR do?

**Type of change:** Minor code change

**Overview:**

- Select the high-precision dtype directly based on model type: FP16 for Stable Diffusion models, BF16 for Flux.

## Testing

```bash
python diffusion_trt.py --model flux-dev --benchmark
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: No (there is no longer an option to specify the dtype when loading the pipeline)
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

---

Signed-off-by: ajrasane <[email protected]>
1 parent: 5adb9ba · commit: 2290533

3 files changed: +20 −21 lines

CHANGELOG.rst (1 addition, 0 deletions)

```diff
@@ -30,6 +30,7 @@ Model Optimizer Changelog (Linux)
 - Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/llm_ptq#multi-node-post-training-quantization-with-fsdp2>`_ for more details.
 - Add support for Nemotron Nano VL v1 & v2 models in FP8/NVFP4 PTQ workflow.
 - Add flags ``nodes_to_include`` and ``op_types_to_include`` in AutoCast to force-include nodes in low precision, even if they would otherwise be excluded by other rules.
+- Add support for ``torch.compile`` and benchmarking in ``examples/diffusers/quantization/diffusion_trt.py``.
 
 **Documentation**
```

examples/diffusers/README.md (7 additions, 2 deletions)

````diff
@@ -307,11 +307,16 @@ Generate images for the quantized checkpoint with the following [Script](./quant
 python diffusion_trt.py \
     --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \
     --prompt "A cat holding a sign that says hello world" \
+    [--override-model-path /path/to/model] \
     [--restore-from ./{MODEL}_fp8.pt] \
     [--onnx-load-path {ONNX_DIR}] \
     [--trt-engine-load-path {ENGINE_DIR}] \
-    [--dq_only] \
-    [--torch]
+    [--dq-only] \
+    [--torch] \
+    [--save-image-as /path/to/image] \
+    [--benchmark] \
+    [--torch-compile] \
+    [--skip-image]
 ```
 
 This script will save the output image as `./{MODEL}.png` and report the latency of the TensorRT backbone.
````

examples/diffusers/quantization/diffusion_trt.py (12 additions, 19 deletions)

```diff
@@ -40,10 +40,12 @@
     "flux-schnell": ModelType.FLUX_SCHNELL,
 }
 
-dtype_map = {
-    "Half": torch.float16,
-    "BFloat16": torch.bfloat16,
-    "Float": torch.float32,
+DTYPE_MAP = {
+    "sdxl-1.0": torch.float16,
+    "sdxl-turbo": torch.float16,
+    "sd3-medium": torch.float16,
+    "flux-dev": torch.bfloat16,
+    "flux-schnell": torch.bfloat16,
 }
 
 
```
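The `--model-dtype` flag is gone; precision is now keyed on the model name, so Stable Diffusion variants always load in FP16 and Flux variants in BF16. A minimal sketch of the new lookup (only the mapping itself comes from the diff; the surrounding demo lines are illustrative):

```python
import torch

# Mirrors the new DTYPE_MAP: the high-precision dtype follows the model
# family instead of a user-supplied flag.
DTYPE_MAP = {
    "sdxl-1.0": torch.float16,
    "sdxl-turbo": torch.float16,
    "sd3-medium": torch.float16,
    "flux-dev": torch.bfloat16,
    "flux-schnell": torch.bfloat16,
}

print(DTYPE_MAP["flux-dev"])   # torch.bfloat16
print(DTYPE_MAP["sdxl-1.0"])   # torch.float16
```

An unknown key would raise `KeyError`, but `--model` is presumably restricted by argparse `choices` to exactly these names, so the lookup cannot miss.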

```diff
@@ -60,7 +62,7 @@ def generate_image(pipe, prompt, image_name):
 
 
 def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
 ):
     """Benchmark the backbone model inference time."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
```
```diff
@@ -83,7 +85,7 @@ def forward_hook(_module, _input, _output):
     try:
         print(f"Starting warmup: {num_warmup} runs")
         for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+            with torch.amp.autocast("cuda", dtype=model_dtype):
                 _ = pipe(
                     prompt,
                     output_type="pil",
```
```diff
@@ -95,7 +97,7 @@ def forward_hook(_module, _input, _output):
 
         print(f"Starting benchmark: {num_runs} runs")
         for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
+            with torch.amp.autocast("cuda", dtype=model_dtype):
                 _ = pipe(
                     prompt,
                     output_type="pil",
```
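`benchmark_model` now receives a `torch.dtype` directly instead of a string key into the old `dtype_map`. For orientation, here is a self-contained sketch of the kind of autocast-wrapped timing loop the warmup and benchmark phases perform; the `time_pipeline` helper and its defaults are illustrative, since the actual script measures only the backbone, apparently via a forward hook:

```python
import time
import torch

def time_pipeline(pipe, prompt, model_dtype=torch.bfloat16, num_runs=5):
    """Illustrative: time full pipeline calls under autocast. The real
    script hooks pipe.transformer / pipe.unet and times just the backbone."""
    latencies = []
    for _ in range(num_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        # Autocast at the model's high-precision dtype, as in the diff.
        with torch.amp.autocast("cuda", dtype=model_dtype):
            _ = pipe(prompt, output_type="pil", num_inference_steps=20)
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start)
    return sum(latencies) / len(latencies)
```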
```diff
@@ -126,13 +128,6 @@ def main():
         default=None,
         help="Path to the model if not using default paths in MODEL_ID mapping.",
     )
-    parser.add_argument(
-        "--model-dtype",
-        type=str,
-        default="Half",
-        choices=["Half", "BFloat16", "Float"],
-        help="Precision used to load the model.",
-    )
     parser.add_argument(
         "--restore-from", type=str, default=None, help="Path to the modelopt quantized checkpoint"
     )
```
```diff
@@ -167,10 +162,11 @@
     args = parser.parse_args()
 
     image_name = args.save_image_as if args.save_image_as else f"{args.model}.png"
+    model_dtype = DTYPE_MAP[args.model]
 
     pipe = PipelineManager.create_pipeline_from(
         MODEL_ID[args.model],
-        dtype_map[args.model_dtype],
+        torch_dtype=model_dtype,
         override_model_path=args.override_model_path,
     )
 
```
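`PipelineManager` is internal to this repository. The same pattern in plain `diffusers` looks like the sketch below; the Hugging Face repo names are assumptions for illustration, not the script's actual `MODEL_ID` table:

```python
import torch
from diffusers import DiffusionPipeline

DTYPE_MAP = {"sdxl-1.0": torch.float16, "flux-dev": torch.bfloat16}
# Assumed Hugging Face model IDs, for illustration only.
MODEL_ID = {
    "sdxl-1.0": "stabilityai/stable-diffusion-xl-base-1.0",
    "flux-dev": "black-forest-labs/FLUX.1-dev",
}

model = "flux-dev"
# The load dtype now follows the model name; no user-facing dtype flag.
pipe = DiffusionPipeline.from_pretrained(MODEL_ID[model], torch_dtype=DTYPE_MAP[model])
```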

```diff
@@ -189,9 +185,6 @@
         mto.restore(backbone, args.restore_from)
 
     if args.torch_compile:
-        assert args.model_dtype in ["BFloat16", "Float", "Half"], (
-            "torch.compile() only supports BFloat16 and Float"
-        )
         print("Compiling backbone with torch.compile()...")
         backbone = torch.compile(backbone, mode="max-autotune")
 
```
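The removed assertion was in any case vacuous (it accepted every value `--model-dtype` allowed), and both dtypes in `DTYPE_MAP` are ones `torch.compile` handles, so there is no user input left to validate. Note that a compiled backbone pays its compilation and autotuning cost on the first calls, which is what the warmup phase in `benchmark_model` absorbs before timing begins; a tiny self-contained illustration (the `Linear` stand-in is not the script's backbone):

```python
import torch

# Stand-in module; the script compiles the pipeline's transformer/unet.
net = torch.nn.Linear(8, 8).cuda()
net = torch.compile(net, mode="max-autotune")  # same mode as the diff

x = torch.randn(4, 8, device="cuda")
for _ in range(3):
    _ = net(x)  # first calls trigger compilation and run slowly
```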

```diff
@@ -203,7 +196,7 @@
     pipe.to("cuda")
 
     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
 
     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
```
