
Commit 2a3f7cf

nvidia-modelopt 0.15.0 examples release
1 parent 54cd4a3 commit 2a3f7cf


41 files changed, +2185 -772 lines

README.md

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@

## Model Optimizer Overview

-Minimizing inference costs presents a significant challenge as generative AI models continue to grow in complexity and size. The **NVIDIA TensorRT Model Optimizer** (referred to as **Model Optimizer**, or **ModelOpt**) is a library comprising state-of-the-art model optimization techniques including [quantization](#quantization) and [sparsity](#sparsity) to compress model. It accepts a torch or [ONNX](https://github.com/onnx/onnx) model as inputs and provides Python APIs for users to easily stack different model optimization techniques to produce quantized checkpoint. Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization) or [TensorRT](https://github.com/NVIDIA/TensorRT). Further integrations are planned for [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) for training-in-the-loop optimization techniques. For enterprise users, the 8-bit quantization with Stable Diffusion is also available on [NVIDIA NIM](https://developer.nvidia.com/blog/nvidia-nim-offers-optimized-inference-microservices-for-deploying-ai-models-at-scale/).
+Minimizing inference costs presents a significant challenge as generative AI models continue to grow in complexity and size. The **NVIDIA TensorRT Model Optimizer** (referred to as **Model Optimizer**, or **ModelOpt**) is a library comprising state-of-the-art model optimization techniques including [quantization](#quantization) and [sparsity](#sparsity) to compress models. It accepts a torch or [ONNX](https://github.com/onnx/onnx) model as inputs and provides Python APIs for users to easily stack different model optimization techniques to produce an optimized quantized checkpoint. Seamlessly integrated within the NVIDIA AI software ecosystem, the quantized checkpoint generated from Model Optimizer is ready for deployment in downstream inference frameworks like [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/quantization) or [TensorRT](https://github.com/NVIDIA/TensorRT). Further integrations are planned for [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) for training-in-the-loop optimization techniques. For enterprise users, the 8-bit quantization with Stable Diffusion is also available on [NVIDIA NIM](https://developer.nvidia.com/blog/nvidia-nim-offers-optimized-inference-microservices-for-deploying-ai-models-at-scale/).

Model Optimizer is available for free for all developers on [NVIDIA PyPI](https://pypi.org/project/nvidia-modelopt/). This repository is for sharing examples and GPU-optimized recipes as well as collecting feedback from the community.

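The overview paragraph above describes stacking optimization techniques through Python APIs. As a rough illustration (a minimal sketch, assuming the `modelopt.torch.quantization` preset configs; the model and calibration dataloader are placeholders), post-training quantization typically looks like this:

```python
import modelopt.torch.quantization as mtq

# model: the torch.nn.Module to optimize; calib_loader: a small calibration dataloader (both placeholders)
def forward_loop(model):
    # Run a few calibration batches so activation statistics can be collected
    for batch in calib_loader:
        model(batch)

# Stack a quantization technique onto the model (FP8 preset shown; INT4 AWQ and other presets also exist)
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```

The resulting quantized checkpoint can then be exported for deployment in downstream frameworks such as TensorRT-LLM or TensorRT, as described above.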
@@ -46,7 +46,7 @@ Model Optimizer is available for free for all developers on [NVIDIA PyPI](https:
### [PIP](https://pypi.org/project/nvidia-modelopt/)

```bash
-pip install "nvidia-modelopt[all]~=0.13.0" --extra-index-url https://pypi.nvidia.com
+pip install "nvidia-modelopt[all]~=0.15.0" --extra-index-url https://pypi.nvidia.com
```

See the [installation guide](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html) for more fine-grained control over the installation.
@@ -68,7 +68,7 @@ docker run --gpus all -it --shm-size 20g --rm docker.io/library/modelopt_example
python -c "import modelopt"
```

-Alternatively for PyTorch, you can also use [NVIDIA NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) with Model Optimizer pre-installed starting from 24.06 PyTorch container. Make sure to update the Model Optimizer version to the latest one if not already.
+Alternatively for PyTorch, you can also use [NVIDIA NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) with Model Optimizer pre-installed starting from 24.06 container. Make sure to update the Model Optimizer version to the latest one if not already.

## Techniques

benchmark.md

Lines changed: 7 additions & 7 deletions
@@ -8,18 +8,18 @@ performance** that can be delivered by Model Optimizer. All performance numbers

#### 1.1 Performanace

-Config: H100, nvidia-modelopt v0.11.0, TensorRT-LLM v0.9, latency measured with full batch inference (no inflight batching).
+Config: H100, nvidia-modelopt v0.15.0, TensorRT-LLM v0.11, latency measured with full batch inference (no inflight batching).
Memory saving and inference speedup are compared to the FP16 baseline. Speedup is normalized to the GPU count.

| | | | FP8 | | | | INT4 AWQ | |
|:----------:|:----------:|:----------:|:----------:|:-------:|:-:|:----------:|:----------:|:-------:|
| Model | Batch Size | Mem Saving | Tokens/sec | Speedup | | Mem Saving | Tokens/sec | Speedup |
-| Llama3-8B | 2 | 1.66x | 337.67 | 1.39x | | 2.37x | 392.99 | 1.61x |
-| | 32 | 1.56x | 2368.69 | 1.66x | | 1.86x | 2037.54 | 1.43x |
-| | 64 | 1.54x | 2404.86 | 1.43x | | 1.76x | 2308.57 | 1.37x |
-| Llama3-70B | 2 | 1.98x | 64.35 | 2.11x | | 3.49x | 77.36 | 2.54x |
-| | 32 | 1.95x | 391.73 | 3.03x | | 2.94x | 479.11 | 3.71x |
-| | 64 | 1.91x | 383.42 | 2.41x | | 2.46x | 348.65 | 2.19x |
+| Llama3-8B | 1 | 1.63x | 175.42 | 1.26x | | 2.34x | 213.45 | 1.53x |
+| | 32 | 1.62x | 3399.84 | 1.49x | | 1.89x | 2546.12 | 1.11x |
+| | 64 | 1.58x | 3311.03 | 1.34x | | 1.97x | 3438.08 | 1.39x |
+| Llama3-70B | 1 | 1.96x | 32.85 | 1.87x | | 3.47x | 47.49 | 2.70x |
+| | 32 | 1.93x | 462.69 | 1.82x | | 2.62x | 365.06 | 1.44x |
+| | 64 | 1.99x | 449.09 | 1.91x | | 2.90x | 483.51 | 2.05x |

### 1.2 Accuracy

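As a reading aid for the updated table: the FP16 baseline throughput is not listed, but it is implied by the ratios. For the Llama3-8B, batch-size-1 row, 175.42 FP8 tokens/sec at a 1.26x speedup corresponds to an FP16 baseline of roughly 175.42 / 1.26 ≈ 139 tokens/sec on the same GPU count.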
diffusers/cache_diffusion/README.md

Lines changed: 31 additions & 7 deletions
@@ -1,15 +1,11 @@
# Cache Diffusion

-## News
-
-- [Utilizing DeepCache to Accelerate Stable Diffusion-XL Benchmarks in MLPerf Yields Leading Results](https://developer.nvidia.com/blog/nvidia-h200-tensor-core-gpus-and-nvidia-tensorrt-llm-set-mlperf-llm-inference-records/)
-
## Introduction

| Supported Framework | Supported Models |
|----------|----------|
| **PyTorch** | [**PixArt-α**](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS), [**Stable Diffusion - XL**](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [**SVD**](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) |
-| **TensorRT** | **WIP** |
+| **TensorRT** | [**Stable Diffusion - XL**](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |

Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.

@@ -52,9 +48,37 @@ Two parameters are essential: `wildcard_or_filter_func` and `select_cache_step_f

Multiple configurations can be set up, but ensure that the `wildcard_or_filter_func` works correctly. If you input more than one pair of parameters with the same `wildcard_or_filter_func`, the later one in the list will overwrite the previous ones.

-## Demo
+### TensorRT support
+
+#### Quick Start
+
+Install [TensorRT](https://developer.nvidia.com/tensorrt) then run:
+
+```bash
+python run_cache_diffusion.py
+```
+
+You can find the latest TensorRT at [here](https://developer.nvidia.com/tensorrt/download).
+
+To execute cache diffusion in TensorRT, follow these steps:

-The following demo images are generated using `PyTorch==2.3.0 with 1xAda 6000 GPU backend`. TensorRT support will be available in the next ModelOPT release.
+```python
+# Load the model
+
+compile(
+    pipe.unet,
+    onnx_path=Path("./onnx"),
+    engine_path=Path("./engine"),
+)
+
+cachify.prepare(pipe, num_inference_steps, SDXL_DEFAULT_CONFIG)
+```
+
+Afterward, use it as a standard cache diffusion pipeline to generate the image.
+
+Please note that only the UNET component is running in TensorRT, while the other parts remain in PyTorch.
+
+## Demo

Comparing with naively reducing the generation steps, cache diffusion can achieve the same speedup and also much better image quality, even close to the reference image. If the image quality does not meet your needs or product requirements, you can replace our default configuration with your customized settings.

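To make the `wildcard_or_filter_func` / `select_cache_step_func` pairing described in the hunk above concrete, here is a hypothetical single-entry config (the pattern, step rule, and the `SDXL_EXAMPLE_CONFIG` name are illustrative assumptions, not values from this commit):

```python
# Hypothetical config: mirrors the shape of the config_list described above, with made-up values.
SDXL_EXAMPLE_CONFIG = [
    {
        # fnmatch-style wildcard (or a callable filter) selecting which modules to cache
        "wildcard_or_filter_func": "*up_blocks*",
        # Reuse the cached output on odd denoising steps, recompute on even ones
        # (assumes the callable receives the current step index)
        "select_cache_step_func": lambda step: step % 2 == 1,
    },
]
```

Because a later entry with the same `wildcard_or_filter_func` overwrites earlier ones, keep one entry per pattern unless the overwrite is intended.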
diffusers/cache_diffusion/cache_diffusion/cachify.py

Lines changed: 46 additions & 12 deletions
@@ -21,9 +21,21 @@

import fnmatch

-from diffusers.models.attention import FeedForward
-from diffusers.models.attention_processor import Attention
-from diffusers.models.resnet import ResnetBlock2D, TemporalResnetBlock
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unets.unet_2d_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UpBlock2D,
+)
+from diffusers.models.unets.unet_3d_blocks import (
+    CrossAttnDownBlockSpatioTemporal,
+    CrossAttnUpBlockSpatioTemporal,
+    DownBlockSpatioTemporal,
+    UNetMidBlockSpatioTemporal,
+    UpBlockSpatioTemporal,
+)
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import PixArtAlphaPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
@@ -35,15 +47,37 @@
from .module import CachedModule
from .utils import replace_module

-SUPPORTED_METHODS = {PixArtAlphaPipeline, StableDiffusionXLPipeline, StableVideoDiffusionPipeline}
-
-
-def cachify(model, num_inference_steps, config_list):
+CACHED_PIPE = {
+    StableDiffusionXLPipeline: (
+        DownBlock2D,
+        CrossAttnDownBlock2D,
+        UNetMidBlock2DCrossAttn,
+        CrossAttnUpBlock2D,
+        UpBlock2D,
+    ),
+    PixArtAlphaPipeline: (BasicTransformerBlock),
+    StableVideoDiffusionPipeline: (
+        CrossAttnDownBlockSpatioTemporal,
+        DownBlockSpatioTemporal,
+        UpBlockSpatioTemporal,
+        CrossAttnUpBlockSpatioTemporal,
+        UNetMidBlockSpatioTemporal,
+    ),
+}
+
+
+def cachify(model, num_inference_steps, config_list, modules):
+    if hasattr(model, "use_trt_infer") and model.use_trt_infer:
+        for key, _ in model.engines.items():
+            for config in config_list:
+                if _pass(key, config["wildcard_or_filter_func"]):
+                    model.engines[key] = CachedModule(
+                        model.engines[key], num_inference_steps, config["select_cache_step_func"]
+                    )
+        return
    for name, module in model.named_modules():
        for config in config_list:
-            if _pass(name, config["wildcard_or_filter_func"]) and isinstance(
-                module, (Attention, ResnetBlock2D, TemporalResnetBlock, FeedForward)
-            ):
+            if _pass(name, config["wildcard_or_filter_func"]) and isinstance(module, modules):
                replace_module(
                    model,
                    name,
@@ -86,8 +120,8 @@ def get_model(pipe):


def prepare(pipe, num_inference_steps, config_list):
-    assert pipe.__class__ in SUPPORTED_METHODS, f"{pipe.__class__} is not supported!"
+    assert pipe.__class__ in CACHED_PIPE.keys(), f"{pipe.__class__} is not supported!"

    model = get_model(pipe)

-    cachify(model, num_inference_steps, config_list)
+    cachify(model, num_inference_steps, config_list, CACHED_PIPE[pipe.__class__])
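The reworked `prepare` above looks up the per-pipeline block types in `CACHED_PIPE` and forwards them to `cachify`. A minimal usage sketch, assuming the diffusers SDXL pipeline and the illustrative config from the earlier sketch (the import path and step count are assumptions):

```python
import torch
from diffusers import StableDiffusionXLPipeline

from cache_diffusion import cachify  # assumed import path for this sketch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

num_inference_steps = 30  # placeholder
# prepare() asserts the pipeline class is in CACHED_PIPE and wraps the matching block types
cachify.prepare(pipe, num_inference_steps, SDXL_EXAMPLE_CONFIG)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=num_inference_steps).images[0]
```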
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+SDXL_ONNX_CONFIG = {
+    "down_blocks.0": {
+        "dummy_input": {
+            "hidden_states": (2, 320, 128, 128),
+            "temb": (2, 1280),
+        },
+        "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+        },
+    },
+    "down_blocks.1": {
+        "dummy_input": {
+            "hidden_states": (2, 320, 64, 64),
+            "temb": (2, 1280),
+            "encoder_hidden_states": (2, 77, 2048),
+        },
+        "output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "encoder_hidden_states": {0: "batch_size"},
+        },
+    },
+    "down_blocks.2": {
+        "dummy_input": {
+            "hidden_states": (2, 640, 32, 32),
+            "temb": (2, 1280),
+            "encoder_hidden_states": (2, 77, 2048),
+        },
+        "output_names": ["sample", "res_samples_0", "res_samples_1"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "encoder_hidden_states": {0: "batch_size"},
+        },
+    },
+    "mid_block": {
+        "dummy_input": {
+            "hidden_states": (2, 1280, 32, 32),
+            "temb": (2, 1280),
+            "encoder_hidden_states": (2, 77, 2048),
+        },
+        "output_names": ["sample"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "encoder_hidden_states": {0: "batch_size"},
+        },
+    },
+    "up_blocks.0": {
+        "dummy_input": {
+            "hidden_states": (2, 1280, 32, 32),
+            "res_hidden_states_0": (2, 640, 32, 32),
+            "res_hidden_states_1": (2, 1280, 32, 32),
+            "res_hidden_states_2": (2, 1280, 32, 32),
+            "temb": (2, 1280),
+            "encoder_hidden_states": (2, 77, 2048),
+        },
+        "output_names": ["sample"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "encoder_hidden_states": {0: "batch_size"},
+            "res_hidden_states_0": {0: "batch_size"},
+            "res_hidden_states_1": {0: "batch_size"},
+            "res_hidden_states_2": {0: "batch_size"},
+        },
+    },
+    "up_blocks.1": {
+        "dummy_input": {
+            "hidden_states": (2, 1280, 64, 64),
+            "res_hidden_states_0": (2, 320, 64, 64),
+            "res_hidden_states_1": (2, 640, 64, 64),
+            "res_hidden_states_2": (2, 640, 64, 64),
+            "temb": (2, 1280),
+            "encoder_hidden_states": (2, 77, 2048),
+        },
+        "output_names": ["sample"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "encoder_hidden_states": {0: "batch_size"},
+            "res_hidden_states_0": {0: "batch_size"},
+            "res_hidden_states_1": {0: "batch_size"},
+            "res_hidden_states_2": {0: "batch_size"},
+        },
+    },
+    "up_blocks.2": {
+        "dummy_input": {
+            "hidden_states": (2, 640, 128, 128),
+            "res_hidden_states_0": (2, 320, 128, 128),
+            "res_hidden_states_1": (2, 320, 128, 128),
+            "res_hidden_states_2": (2, 320, 128, 128),
+            "temb": (2, 1280),
+        },
+        "output_names": ["sample"],
+        "dynamic_axes": {
+            "hidden_states": {0: "batch_size"},
+            "temb": {0: "steps"},
+            "res_hidden_states_0": {0: "batch_size"},
+            "res_hidden_states_1": {0: "batch_size"},
+            "res_hidden_states_2": {0: "batch_size"},
+        },
+    },
+}

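Each entry in the new config above pairs a UNet sub-block with the dummy input shapes, output names, and dynamic axes needed to export it to ONNX individually. As a rough sketch of how one entry could drive such an export (the helper below is illustrative, not the repository's actual exporter; dtype and device are assumptions):

```python
import torch

def export_block(block, name, cfg, onnx_dir="onnx"):
    # Build dummy tensors matching the configured shapes (fp16 on GPU is an assumption)
    dummy_inputs = tuple(
        torch.randn(shape, dtype=torch.float16, device="cuda")
        for shape in cfg["dummy_input"].values()
    )
    # Export with the recorded input/output names and dynamic batch/step axes
    torch.onnx.export(
        block,
        dummy_inputs,
        f"{onnx_dir}/{name}.onnx",
        input_names=list(cfg["dummy_input"].keys()),
        output_names=cfg["output_names"],
        dynamic_axes=cfg["dynamic_axes"],
    )

# e.g. export_block(pipe.unet.mid_block, "mid_block", SDXL_ONNX_CONFIG["mid_block"])
```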