
Commit ac0beba: [update] build TRT code
1 parent 4ff5ec2

File tree: 9 files changed (+339, -50 lines)

examples/flux-tensorrt/README.md
Lines changed: 35 additions & 34 deletions

````diff
@@ -12,48 +12,49 @@ This project provides **TensorRT-accelerated pipelines** for Flux models, enabli
 
 ---
 
-## ⚙️ Building Flux with TensorRT
 
-We follow the official [NVIDIA/TensorRT](https://github.com/NVIDIA/TensorRT) repository to build TensorRT.
-
-> **Note:**
-> TensorRT was originally built with `diffusers==0.31.1`.
-> Currently, we recommend using:
-> - one **venv** for building, and
-> - another **venv** for inference.
-
-(🔜 TODO: Build scripts for the latest `diffusers` will be added later.)
-
-### Installation
+## Installation
 ```bash
-git clone https://github.com/NVIDIA/TensorRT
-cd TensorRT/demo/Diffusion
-
-pip install tensorrt-cu12==10.13.2.6
+cd diffusers/examples/flux-tensorrt
 pip install -r requirements.txt
 ```
 
-### ⚡ Fast Building with Static Shapes
-```bash
-# BF16
-python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --bf16 --download-onnx-models
+## ⚙️ Build Flux with TensorRT
 
-# FP8
-python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --quantization-level 4 --fp8 --download-onnx-models
+Before building, make sure you have the ONNX checkpoints ready.
+You can either download the official [Flux ONNX](https://huggingface.co/black-forest-labs/FLUX.1-dev-onnx) checkpoints from Hugging Face, or export your own.
 
-# FP4
-python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --fp4 --download-onnx-models
+```bash
+huggingface-cli download black-forest-labs/FLUX.1-dev-onnx --local-dir onnx
 ```
 
-- To build with dynamic shape, add: `--build-dynamic-shape`.
-- To build with static batch, add `--build-static-batch`.
+Build each component individually. For example, to build the **Transformer engine**:
+```python
+from module.transformers import FluxTransformerModel
+
+engine_path = "checkpoints_trt/transformer/engine.plan"
+engine_transformer = FluxTransformerModel(engine_path=engine_path, build=True)
+
+# Build the transformer engine
+transformer_input_profile = engine_transformer.get_input_profile(
+    opt_batch_size=1,
+    opt_image_height=1024,
+    opt_image_width=1024,
+    static_batch=True,
+    dynamic_shape=True,
+)
+engine_transformer.build(
+    onnx_path="onnx/transformer.opt/bf16/model.onnx",  # Replace with your ONNX path
+    input_profile=transformer_input_profile,
+)
+```
 
-ℹ️ For more details, run:
-`python demo_txt2img_flux.py --help`
+You can convert all ONNX checkpoints to TensorRT engines with a single command:
+```bash
+python convert_trt.py
+```
 
 ## 🖼️ Inference with Flux TensorRT
-Create a new venv (or update diffusers, peft in your existing one), then run fast inference using TensorRT engines.
-
 Example: Full Pipeline with All Engines
 
 ```python
@@ -68,10 +69,10 @@ from module.clip import CLIPModel
 import time
 
 # Local path for each engine
-engine_transformer_path = "path/to/transformer/engine_trt10.13.2.6.plan"
-engine_vae_path = "path/to/vae/engine_trt10.13.2.6.plan"
-engine_t5xxl_path = "path/to/t5/engine_trt10.13.2.6.plan"
-engine_clip_path = "path/to/clip/engine_trt10.13.2.6.plan"
+engine_transformer_path = "checkpoints_trt/transformer/engine.plan"
+engine_vae_path = "checkpoints_trt/vae/engine.plan"
+engine_t5xxl_path = "checkpoints_trt/t5/engine.plan"
+engine_clip_path = "checkpoints_trt/clip/engine.plan"
 
 # Create stream for each engine
 stream = cudart.cudaStreamCreate()[1]
````
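As a quick sanity check after building (not part of this commit), you can deserialize a plan and list its I/O tensors with the TensorRT 10 Python API. A minimal sketch, assuming the `checkpoints_trt` layout from the README:

```python
import tensorrt as trt

# Deserialize the freshly built plan and print each I/O tensor's
# name, mode (input/output), and shape (-1 marks a dynamic dimension).
logger = trt.Logger(trt.Logger.ERROR)
with open("checkpoints_trt/transformer/engine.plan", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    mode = engine.get_tensor_mode(name)  # trt.TensorIOMode.INPUT / .OUTPUT
    print(f"{name}: {mode}, shape={engine.get_tensor_shape(name)}")
```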
Lines changed: 28 additions & 0 deletions

````diff
@@ -0,0 +1,28 @@
+from module.transformers import FluxTransformerModel
+from module.clip import CLIPModel
+from module.t5xxl import T5XXLModel
+from module.vae import VAEModel
+
+models_config = {
+    "transformer": (FluxTransformerModel, "onnx/transformer.opt/bf16/model.onnx"),
+    "clip": (CLIPModel, "onnx/clip.opt/model.onnx"),
+    "t5": (T5XXLModel, "onnx/t5.opt/model.onnx"),
+    "vae": (VAEModel, "onnx/vae.opt/model.onnx"),
+}
+
+engines = {}
+
+for name, (ModelClass, onnx_path) in models_config.items():
+    engine_path = f"checkpoints_trt/{name}/engine.plan"
+    engine = ModelClass(engine_path=engine_path, build=True)
+
+    input_profile = engine.get_input_profile(
+        opt_batch_size=1,
+        opt_image_height=1024,
+        opt_image_width=1024,
+        static_batch=True,
+        dynamic_shape=True,
+    )
+
+    engine.build(onnx_path=onnx_path, input_profile=input_profile)
+    engines[name] = engine
````
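A small follow-up check, not part of the commit: after the loop finishes, each component should have a serialized plan where the inference example expects it. A sketch, with paths taken from the script above:

```python
from pathlib import Path

# Verify that every engine plan was written to checkpoints_trt/<name>/engine.plan.
for name in ("transformer", "clip", "t5", "vae"):
    plan = Path("checkpoints_trt") / name / "engine.plan"
    status = f"{plan.stat().st_size / 1e6:.1f} MB" if plan.exists() else "MISSING"
    print(f"{plan}: {status}")
```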
Lines changed: 43 additions & 0 deletions

````diff
@@ -0,0 +1,43 @@
+from pipeline_flux_trt import FluxPipelineTRT
+from cuda import cudart
+import torch
+
+from module.transformers import FluxTransformerModel
+from module.vae import VAEModel
+from module.t5xxl import T5XXLModel
+from module.clip import CLIPModel
+import time
+
+# Local path for each engine
+engine_transformer_path = "checkpoints_trt/transformer/engine.plan"
+engine_vae_path = "checkpoints_trt/vae/engine.plan"
+engine_t5xxl_path = "checkpoints_trt/t5/engine.plan"
+engine_clip_path = "checkpoints_trt/clip/engine.plan"
+
+# Create stream for each engine
+stream = cudart.cudaStreamCreate()[1]
+
+# Create engine for each model
+engine_transformer = FluxTransformerModel(engine_transformer_path, stream)
+engine_vae = VAEModel(engine_vae_path, stream)
+engine_t5xxl = T5XXLModel(engine_t5xxl_path, stream)
+engine_clip = CLIPModel(engine_clip_path, stream)
+
+# Create pipeline
+pipeline = FluxPipelineTRT.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+    engine_transformer=engine_transformer,
+    engine_vae=engine_vae,
+    engine_text_encoder=engine_clip,
+    engine_text_encoder_2=engine_t5xxl,
+)
+pipeline.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+generator = torch.Generator(device="cuda").manual_seed(42)
+image = pipeline(prompt, num_inference_steps=28, guidance_scale=3.0, generator=generator).images[0]
+
+image.save("test_pipeline.png")
````
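The script imports `time` but never uses it. If you want the speedup numbers this example is meant to demonstrate, here is a hedged sketch of a timing helper (the `pipeline` argument is the object built above; everything else is illustrative):

```python
import time

import torch


def timed_generate(pipeline, prompt, **kwargs):
    # CUDA launches are asynchronous, so synchronize before and after
    # to make the wall-clock measurement cover the full denoising loop.
    torch.cuda.synchronize()
    start = time.perf_counter()
    image = pipeline(prompt, **kwargs).images[0]
    torch.cuda.synchronize()
    print(f"pipeline call took {time.perf_counter() - start:.2f} s")
    return image

# The first call includes warm-up cost; time a second call for steady state:
# image = timed_generate(pipeline, prompt, num_inference_steps=28,
#                        guidance_scale=3.0, generator=generator)
```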
Lines changed: 15 additions & 4 deletions

````diff
@@ -1,14 +1,15 @@
 from .engine import Engine
 
 class CLIPModel(Engine):
-    def __init__(self, engine_path: str, stream = None):
+    def __init__(self, engine_path: str, stream = None, build = False):
         super().__init__(engine_path, stream)
         self.text_maxlen = 77
         self.embedding_dim = 768
         self.keep_pooled_output = True
 
-        # Load engine before
-        self.load_engine()
+        if not build:
+            # Load a prebuilt engine from disk
+            self.load_engine()
 
     def get_shape_dict(self, batch_size, image_height, image_width):
         self.check_dims(batch_size, image_height, image_width)
@@ -18,4 +19,14 @@ def get_shape_dict(self, batch_size, image_height, image_width):
         }
         if self.keep_pooled_output:
             output["pooled_embeddings"] = (batch_size, self.embedding_dim)
-        return output
+        return output
+
+    def get_input_profile(self, opt_batch_size=1, opt_image_height=1024, opt_image_width=1024, min_batch=1, max_batch=8, min_height=512, max_height=1280, min_width=512, max_width=1280, static_batch=True, dynamic_shape=True):
+        min_batch = opt_batch_size if static_batch else min_batch
+        max_batch = opt_batch_size if static_batch else max_batch
+
+        return {
+            "input_ids": [(min_batch, self.text_maxlen), (opt_batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+        }
````
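For reference (not part of the diff): each profile entry maps an input name to a `[min, opt, max]` shape list, which `Engine.build()` forwards to polygraphy. With the defaults above, `static_batch=True` pins the batch dimension, so the values below follow directly from the code:

```python
# CLIPModel.get_input_profile() with defaults (static_batch=True, text_maxlen=77):
# min == opt == max, i.e. the engine is built for exactly batch size 1.
static_profile = {"input_ids": [(1, 77), (1, 77), (1, 77)]}

# With static_batch=False, the batch dimension spans min_batch..max_batch:
dynamic_profile = {"input_ids": [(1, 77), (1, 77), (8, 77)]}
```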

examples/flux-tensorrt/module/engine.py
Lines changed: 120 additions & 1 deletion

````diff
@@ -2,7 +2,8 @@
 from cuda import cudart
 import torch
 import tensorrt as trt
-
+import subprocess
+from collections import defaultdict
 from collections import OrderedDict
 from polygraphy.backend.common import bytes_from_path
 from polygraphy.backend.trt import engine_from_bytes
@@ -78,6 +79,124 @@ def check_dims(self, batch_size, image_height, image_width, compression_factor =
         assert latent_width >= min_latent_shape and latent_width <= max_latent_shape
         return (latent_height, latent_width)
 
+    def build(
+        self,
+        onnx_path,
+        strongly_typed=False,
+        fp16=False,
+        bf16=True,
+        tf32=True,
+        int8=False,
+        fp8=False,
+        input_profile=None,
+        enable_refit=False,
+        enable_all_tactics=False,
+        timing_cache=None,
+        update_output_names=None,
+        native_instancenorm=True,
+        verbose=False,
+        weight_streaming=False,
+        builder_optimization_level=3,
+        precision_constraints='none',
+    ):
+        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+
+        # Handle weight streaming case: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#streaming-weights.
+        if weight_streaming:
+            strongly_typed, fp16, bf16, int8, fp8 = True, False, False, False, False
+
+        # Base command
+        build_command = [f"polygraphy convert {onnx_path} --convert-to trt --output {self.engine_path}"]
+
+        # Precision flags
+        build_args = [
+            "--fp16" if fp16 else "",
+            "--bf16" if bf16 else "",
+            "--tf32" if tf32 else "",
+            "--fp8" if fp8 else "",
+            "--int8" if int8 else "",
+            "--strongly-typed" if strongly_typed else "",
+        ]
+
+        # Additional arguments
+        build_args.extend([
+            "--weight-streaming" if weight_streaming else "",
+            "--refittable" if enable_refit else "",
+            "--tactic-sources" if not enable_all_tactics else "",
+            "--onnx-flags native_instancenorm" if native_instancenorm else "",
+            f"--builder-optimization-level {builder_optimization_level}",
+            f"--precision-constraints {precision_constraints}",
+        ])
+
+        # Timing cache
+        if timing_cache:
+            build_args.extend([
+                f"--load-timing-cache {timing_cache}",
+                f"--save-timing-cache {timing_cache}"
+            ])
+
+        # Verbosity setting
+        verbosity = "extra_verbose" if verbose else "error"
+        build_args.append(f"--verbosity {verbosity}")
+
+        # Output names
+        if update_output_names:
+            print(f"Updating network outputs to {update_output_names}")
+            # build_args.append(f"--trt-outputs {' '.join(update_output_names)}")
+            build_args.append(f"--trt-outputs {update_output_names}")
+
+        # Input profiles
+        if input_profile:
+            profile_args = defaultdict(str)
+            for name, dims in input_profile.items():
+                assert len(dims) == 3
+                profile_args["--trt-min-shapes"] += f"{name}:{str(list(dims[0])).replace(' ', '')} "
+                profile_args["--trt-opt-shapes"] += f"{name}:{str(list(dims[1])).replace(' ', '')} "
+                profile_args["--trt-max-shapes"] += f"{name}:{str(list(dims[2])).replace(' ', '')} "
+
+            build_args.extend(f"{k} {v}" for k, v in profile_args.items())
+
+        # Filter out empty strings and join command
+        build_args = [arg for arg in build_args if arg]
+        final_command = ' '.join(build_command + build_args)
+
+        # Execute command with improved error handling
+        try:
+            print(f"Engine build command: {final_command}")
+            subprocess.run(final_command, check=True, shell=True)
+        except subprocess.CalledProcessError as exc:
+            error_msg = (
+                f"Failed to build TensorRT engine. Error details:\n"
+                f"Command: {exc.cmd}\n"
+            )
+            raise RuntimeError(error_msg) from exc
+
+    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape, compression_factor=8, min_batch=1, max_batch=8, min_image_shape=256, max_image_shape=1344, min_latent_shape=16, max_latent_shape=1024):
+        min_batch = batch_size if static_batch else self.min_batch
+        max_batch = batch_size if static_batch else self.max_batch
+        latent_height = image_height // compression_factor
+        latent_width = image_width // compression_factor
+        min_image_height = image_height if static_shape else min_image_shape
+        max_image_height = image_height if static_shape else max_image_shape
+        min_image_width = image_width if static_shape else min_image_shape
+        max_image_width = image_width if static_shape else max_image_shape
+        min_latent_height = latent_height if static_shape else min_latent_shape
+        max_latent_height = latent_height if static_shape else max_latent_shape
+        min_latent_width = latent_width if static_shape else min_latent_shape
+        max_latent_width = latent_width if static_shape else max_latent_shape
+        return (
+            min_batch,
+            max_batch,
+            min_image_height,
+            max_image_height,
+            min_image_width,
+            max_image_width,
+            min_latent_height,
+            max_latent_height,
+            min_latent_width,
+            max_latent_width,
+        )
````
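To see the shape flags `build()` ends up passing to `polygraphy convert`, you can replay the profile-formatting loop on a small profile. A standalone sketch (the profile values are illustrative; the formatting matches the code above):

```python
from collections import defaultdict

# A [min, opt, max] profile as produced by get_input_profile().
input_profile = {"input_ids": [(1, 77), (1, 77), (8, 77)]}

profile_args = defaultdict(str)
for name, dims in input_profile.items():
    # Same formatting as Engine.build(): "name:[d0,d1]" with spaces stripped.
    profile_args["--trt-min-shapes"] += f"{name}:{str(list(dims[0])).replace(' ', '')} "
    profile_args["--trt-opt-shapes"] += f"{name}:{str(list(dims[1])).replace(' ', '')} "
    profile_args["--trt-max-shapes"] += f"{name}:{str(list(dims[2])).replace(' ', '')} "

print(" ".join(f"{k} {v.strip()}" for k, v in profile_args.items()))
# --trt-min-shapes input_ids:[1,77] --trt-opt-shapes input_ids:[1,77] --trt-max-shapes input_ids:[8,77]
```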

Lines changed: 12 additions & 4 deletions

````diff
@@ -1,18 +1,26 @@
 from .engine import Engine
 
 class T5XXLModel(Engine):
-    def __init__(self, engine_path: str, stream = None):
+    def __init__(self, engine_path: str, stream = None, build = False):
         super().__init__(engine_path, stream)
         self.text_maxlen = 512
         self.d_model = 4096
 
-        # Load engine before
-        self.load_engine()
+        if not build:
+            # Load a prebuilt engine from disk
+            self.load_engine()
 
     def get_shape_dict(self, batch_size, image_height, image_width):
         self.check_dims(batch_size, image_height, image_width)
         output = {
             "input_ids": (batch_size, self.text_maxlen),
             "text_embeddings": (batch_size, self.text_maxlen, self.d_model),
         }
-        return output
+        return output
+
+    def get_input_profile(self, opt_batch_size=1, opt_image_height=1024, opt_image_width=1024, min_batch=1, max_batch=8, min_height=512, max_height=1280, min_width=512, max_width=1280, static_batch=True, dynamic_shape=True):
+        min_batch = opt_batch_size if static_batch else min_batch
+        max_batch = opt_batch_size if static_batch else max_batch
+        return {
+            "input_ids": [(min_batch, self.text_maxlen), (opt_batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+        }
````
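For comparison with CLIP (again, not part of the diff): with `text_maxlen=512` and `d_model=4096`, `get_shape_dict` for a single prompt resolves to the bindings below, independent of image size:

```python
# T5XXLModel.get_shape_dict(batch_size=1, image_height=1024, image_width=1024)
# resolves to these binding shapes, using the class attributes above:
expected = {
    "input_ids": (1, 512),              # token ids, text_maxlen = 512
    "text_embeddings": (1, 512, 4096),  # per-token embeddings, d_model = 4096
}
```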
