Skip to content

Commit e0a6efb

Browse files
fix trt engine building of the diffusers pipelines (NVIDIA#637)
## What does this PR do? **Type of change:** Bug fix **Overview:** 1. diffusion_trt.py needs the dynamic_shapes when running trtexec for engine building. A previous change altered the format of dynamic_shapes; this PR fixes it. 2. The dynamic_shapes logic is cleaned up, since the existing logic was very confusing. 3. Restore the min batch_size config for some pipelines. Previously some pipelines set the min batch_size to be > 1, which seemed odd, so an earlier change set them to 1 — but it turns out the oddity had a reason: TRT engine building fails with the altered batch_size min/opt, so the original values are restored. ## Testing pytest tests/examples/diffusers --------- Signed-off-by: Shengliang Xu <[email protected]> Signed-off-by: Keval Morabia <[email protected]> Co-authored-by: Keval Morabia <[email protected]>
1 parent 422c58b commit e0a6efb

File tree

5 files changed

+43
-47
lines changed

5 files changed

+43
-47
lines changed

examples/diffusers/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ Install Model Optimizer with `onnx` and `hf` dependencies using `pip` from [PyPI
3939

4040
```bash
4141
pip install nvidia-modelopt[onnx,hf]
42+
pip install -r requirements.txt
4243
```
4344

44-
Each subsection (cache_diffusion, quantization, etc.) have their own `requirements.txt` file that needs to be installed separately.
45+
Each subsection (eval, etc.) may have their own `requirements.txt` file that needs to be installed separately.
4546

4647
You can find the latest TensorRT [here](https://developer.nvidia.com/tensorrt/download).
4748

examples/diffusers/cache_diffusion/requirements.txt

Lines changed: 0 additions & 6 deletions
This file was deleted.

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import torch
2121
from onnx_utils.export import (
22+
_create_trt_dynamic_shapes,
2223
generate_dummy_inputs_and_dynamic_axes_and_shapes,
2324
get_io_shapes,
2425
remove_nesting,
@@ -186,11 +187,13 @@ def main():
186187

187188
if args.torch_compile:
188189
assert args.torch, "Torch mode must be enabled when torch_compile is used"
189-
# Save the backbone of the pipeline and move it to the GPU
190+
# Save the backbone (and other attributes) of the pipeline and move it to the GPU
190191
add_embedding = None
191-
backbone = None
192+
cache_context = None
192193
if hasattr(pipe, "transformer"):
193194
backbone = pipe.transformer
195+
if hasattr(backbone, "cache_context"):
196+
cache_context = backbone.cache_context
194197
elif hasattr(pipe, "unet"):
195198
backbone = pipe.unet
196199
add_embedding = backbone.add_embedding
@@ -234,13 +237,13 @@ def main():
234237
if args.onnx_load_path == "":
235238
update_dynamic_axes(args.model, dynamic_axes)
236239

237-
compilation_args = dynamic_shapes
240+
trt_dynamic_shapes = _create_trt_dynamic_shapes(dynamic_shapes)
238241

239242
# We only need to remove the nesting for SDXL models as they contain the nested input added_cond_kwargs
240243
# which are renamed by the DeviceModel
241244
ignore_nesting = False
242245
if args.onnx_load_path != "" and args.model in ["sdxl-1.0", "sdxl-turbo"]:
243-
remove_nesting(compilation_args)
246+
remove_nesting(trt_dynamic_shapes)
244247
ignore_nesting = True
245248

246249
# Define deployment configuration
@@ -268,6 +271,7 @@ def main():
268271
del backbone
269272
torch.cuda.empty_cache()
270273

274+
compilation_args = {"dynamic_shapes": trt_dynamic_shapes}
271275
if not args.trt_engine_load_path:
272276
# Compile the TRT engine from the exported ONNX model
273277
compiled_model = client.ir_to_compiled(onnx_bytes, compilation_args)
@@ -289,18 +293,18 @@ def main():
289293
compiled_model,
290294
metadata,
291295
compilation_args,
292-
get_io_shapes(args.model, args.onnx_load_path, dynamic_shapes),
296+
get_io_shapes(args.model, args.onnx_load_path, trt_dynamic_shapes),
293297
ignore_nesting,
294298
)
295299

296-
if hasattr(pipe, "unet") and add_embedding:
297-
setattr(device_model, "add_embedding", add_embedding)
298-
299-
# Set the backbone to the device model
300-
if hasattr(pipe, "unet"):
301-
pipe.unet = device_model
302-
elif hasattr(pipe, "transformer"):
300+
# Set the backbone and other attributes to the device model
301+
if hasattr(pipe, "transformer"):
303302
pipe.transformer = device_model
303+
if cache_context:
304+
device_model.cache_context = cache_context
305+
elif hasattr(pipe, "unet"):
306+
pipe.unet = device_model
307+
device_model.add_embedding = add_embedding
304308
else:
305309
raise ValueError("Pipeline does not have a transformer or unet backbone")
306310
pipe.to("cuda")

examples/diffusers/quantization/onnx_utils/export.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -368,38 +368,36 @@ def update_dynamic_axes(model_id, dynamic_axes):
368368
dynamic_axes["out.0"] = dynamic_axes.pop("out_hidden_states")
369369

370370

371-
def _create_dynamic_shapes(dynamic_shapes):
371+
def _create_trt_dynamic_shapes(dynamic_shapes):
372372
min_shapes = {}
373373
opt_shapes = {}
374374
for key, value in dynamic_shapes.items():
375375
min_shapes[key] = value["min"]
376376
opt_shapes[key] = value["opt"]
377377
return {
378-
"dynamic_shapes": {
379-
"minShapes": min_shapes,
380-
"optShapes": opt_shapes,
381-
"maxShapes": opt_shapes,
382-
}
378+
"minShapes": min_shapes,
379+
"optShapes": opt_shapes,
380+
"maxShapes": opt_shapes,
383381
}
384382

385383

386384
def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
387385
"""Generate dummy inputs, dynamic axes, and dynamic shapes for the given model."""
388386
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
389387
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
390-
backbone, min_bs=1, opt_bs=16
388+
backbone, min_bs=2, opt_bs=16
391389
)
392390
elif model_id in ["sd3-medium", "sd3.5-medium"]:
393391
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3(
394-
backbone, min_bs=1, opt_bs=16
392+
backbone, min_bs=2, opt_bs=16
395393
)
396394
elif model_id in ["flux-dev", "flux-schnell"]:
397395
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux(
398-
backbone, min_bs=1, opt_bs=2
396+
backbone, min_bs=1, opt_bs=1
399397
)
400398
elif model_id == "ltx-video-dev":
401399
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx(
402-
backbone, min_bs=1, opt_bs=2
400+
backbone, min_bs=2, opt_bs=2
403401
)
404402
elif model_id == "wan2.2-t2v-14b":
405403
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_wan(
@@ -414,7 +412,7 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
414412
return dummy_kwargs, dynamic_axes, dynamic_shapes
415413

416414

417-
def get_io_shapes(model_id, onnx_load_path, dynamic_shapes):
415+
def get_io_shapes(model_id, onnx_load_path, trt_dynamic_shapes):
418416
output_name = "out.0"
419417
if onnx_load_path != "":
420418
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
@@ -429,28 +427,28 @@ def get_io_shapes(model_id, onnx_load_path, dynamic_shapes):
429427
raise NotImplementedError(f"Unsupported model_id: {model_id}")
430428

431429
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
432-
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["sample"]}
430+
io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["sample"]}
433431
elif model_id in ["sd3-medium", "sd3.5-medium"]:
434-
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["hidden_states"]}
432+
io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["hidden_states"]}
435433
elif model_id in ["flux-dev", "flux-schnell"]:
436434
io_shapes = {}
437435

438436
return io_shapes
439437

440438

441-
def remove_nesting(dynamic_shapes):
442-
dynamic_shapes["dynamic_shapes"]["minShapes"]["text_embeds"] = dynamic_shapes["dynamic_shapes"][
443-
"minShapes"
444-
].pop("added_cond_kwargs.text_embeds")
445-
dynamic_shapes["dynamic_shapes"]["minShapes"]["time_ids"] = dynamic_shapes["dynamic_shapes"][
446-
"minShapes"
447-
].pop("added_cond_kwargs.time_ids")
448-
dynamic_shapes["dynamic_shapes"]["optShapes"]["text_embeds"] = dynamic_shapes["dynamic_shapes"][
449-
"optShapes"
450-
].pop("added_cond_kwargs.text_embeds")
451-
dynamic_shapes["dynamic_shapes"]["optShapes"]["time_ids"] = dynamic_shapes["dynamic_shapes"][
452-
"optShapes"
453-
].pop("added_cond_kwargs.time_ids")
439+
def remove_nesting(trt_dynamic_shapes):
440+
trt_dynamic_shapes["minShapes"]["text_embeds"] = trt_dynamic_shapes["minShapes"].pop(
441+
"added_cond_kwargs.text_embeds"
442+
)
443+
trt_dynamic_shapes["minShapes"]["time_ids"] = trt_dynamic_shapes["minShapes"].pop(
444+
"added_cond_kwargs.time_ids"
445+
)
446+
trt_dynamic_shapes["optShapes"]["text_embeds"] = trt_dynamic_shapes["optShapes"].pop(
447+
"added_cond_kwargs.text_embeds"
448+
)
449+
trt_dynamic_shapes["optShapes"]["time_ids"] = trt_dynamic_shapes["optShapes"].pop(
450+
"added_cond_kwargs.time_ids"
451+
)
454452

455453

456454
def save_onnx(onnx_model, output):
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
cuda-python
2-
diffusers
1+
cuda-python<13
32
nvtx
43
opencv-python>=4.8.1.78,<4.12.0.88
54
sentencepiece

0 commit comments

Comments (0)