Skip to content

Commit a93116c

Browse files
committed
[feature] Support Flux TensorRT Pipeline
1 parent 800aa13 commit a93116c

File tree

7 files changed

+1386
-0
lines changed

7 files changed

+1386
-0
lines changed

examples/flux-tensorrt/README.md

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# 🚀 Flux TensorRT Pipelines
2+
3+
This project provides **TensorRT-accelerated pipelines** for Flux models, enabling **faster inference** with static and dynamic shapes.
4+
5+
## ✅ Supported Pipelines
6+
-`FluxPipeline` (Supported)
7+
-`FluxImg2ImgPipeline` (Coming soon)
8+
-`FluxInpaintPipeline` (Coming soon)
9+
-`FluxFillPipeline` (Coming soon)
10+
-`FluxKontextPipeline` (Coming soon)
11+
-`FluxKontextInpaintPipeline` (Coming soon)
12+
13+
---
14+
15+
## ⚙️ Building Flux with TensorRT
16+
17+
We follow the official [NVIDIA/TensorRT](https://github.com/NVIDIA/TensorRT) repository to build TensorRT.
18+
19+
> **Note:**
20+
> TensorRT was originally built with `diffusers==0.31.1`.
21+
> Currently, we recommend using:
22+
> - one **venv** for building, and
23+
> - another **venv** for inference.
24+
25+
(🔜 TODO: Build scripts for the latest `diffusers` will be added later.)
26+
27+
### Installation
28+
```bash
29+
git clone https://github.com/NVIDIA/TensorRT
30+
cd TensorRT/demo/Diffusion
31+
32+
pip install tensorrt-cu12==10.13.2.6
33+
pip install -r requirements.txt
34+
```
35+
36+
### ⚡ Fast Building with Static Shapes
37+
```bash
38+
# BF16
39+
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --bf16 --download-onnx-models
40+
41+
# FP8
42+
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --quantization-level 4 --fp8 --download-onnx-models
43+
44+
# FP4
45+
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --fp4 --download-onnx-models
46+
```
47+
48+
- To build with dynamic shape, add: `--build-dynamic-shape`.
49+
- To build with static batch, add `--build-static-batch`.
50+
51+
ℹ️ For more details, run:
52+
`python demo_txt2img_flux.py --help`
53+
54+
## 🖼️ Inference with Flux TensorRT
55+
Create a new venv (or update diffusers, peft in your existing one), then run fast inference using TensorRT engines.
56+
57+
Example: Full Pipeline with All Engines
58+
59+
```python
60+
from pipeline_flux_trt import FluxPipelineTRT
61+
from cuda import cudart
62+
import torch
63+
64+
from module.transformers import FluxTransformerModel
65+
from module.vae import VAEModel
66+
from module.t5xxl import T5XXLModel
67+
from module.clip import CLIPModel
68+
import time
69+
70+
# Local path for each engine
71+
engine_transformer_path = "path/to/transformer/engine_trt10.13.2.6.plan"
72+
engine_vae_path = "path/to/vae/engine_trt10.13.2.6.plan"
73+
engine_t5xxl_path = "path/to/t5/engine_trt10.13.2.6.plan"
74+
engine_clip_path = "path/to/clip/engine_trt10.13.2.6.plan"
75+
76+
# Create stream for each engine
77+
stream = cudart.cudaStreamCreate()[1]
78+
79+
# Create engine for each model
80+
engine_transformer = FluxTransformerModel(engine_transformer_path, stream)
81+
engine_vae = VAEModel(engine_vae_path, stream)
82+
engine_t5xxl = T5XXLModel(engine_t5xxl_path, stream)
83+
engine_clip = CLIPModel(engine_clip_path, stream)
84+
85+
# Create pipeline
86+
pipeline = FluxPipelineTRT.from_pretrained(
87+
"black-forest-labs/FLUX.1-dev",
88+
torch_dtype=torch.bfloat16,
89+
engine_transformer=engine_transformer,
90+
engine_vae=engine_vae,
91+
engine_text_encoder=engine_clip,
92+
engine_text_encoder_2= engine_t5xxl,
93+
)
94+
pipeline.to("cuda")
95+
96+
97+
prompt = "A cat holding a sign that says hello world"
98+
generator = torch.Generator(device="cuda").manual_seed(42)
99+
image = pipeline(prompt, num_inference_steps=28, guidance_scale=3.0, generator=generator).images[0]
100+
101+
102+
image.save("test_pipeline.png")
103+
```
104+
105+
Example: Transformer Only (Other Modules on Torch)
106+
```python
107+
pipeline = FluxPipelineTRT.from_pretrained(
108+
"black-forest-labs/FLUX.1-dev",
109+
torch_dtype=torch.bfloat16,
110+
engine_transformer=engine_transformer,
111+
)
112+
pipeline.to("cuda")
113+
```
114+
115+
## 📌 Notes
116+
117+
- Ensure correct CUDA / TensorRT versions are installed.
118+
119+
- Always match the `.plan` engine files with the TensorRT version used for building.
120+
121+
- For best performance, prefer static shapes unless dynamic batching is required.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from .engine import Engine
2+
3+
class CLIPModel(Engine):
4+
def __init__(self, engine_path: str, stream = None):
5+
super().__init__(engine_path, stream)
6+
self.text_maxlen = 77
7+
self.embedding_dim = 768
8+
self.keep_pooled_output = True
9+
10+
# Load engine before
11+
self.load_engine()
12+
13+
def get_shape_dict(self, batch_size, image_height, image_width):
14+
self.check_dims(batch_size, image_height, image_width)
15+
output = {
16+
"input_ids": (batch_size, self.text_maxlen),
17+
"text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
18+
}
19+
if self.keep_pooled_output:
20+
output["pooled_embeddings"] = (batch_size, self.embedding_dim)
21+
return output
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import gc
2+
from cuda import cudart
3+
import torch
4+
import tensorrt as trt
5+
6+
from collections import OrderedDict
7+
from polygraphy.backend.common import bytes_from_path
8+
from polygraphy.backend.trt import engine_from_bytes
9+
10+
trt_to_torch_dtype_dict = {
11+
trt.DataType.BOOL: torch.bool,
12+
trt.DataType.UINT8: torch.uint8,
13+
trt.DataType.INT8: torch.int8,
14+
trt.DataType.INT32: torch.int32,
15+
trt.DataType.INT64: torch.int64,
16+
trt.DataType.HALF: torch.float16,
17+
trt.DataType.FLOAT: torch.float32,
18+
trt.DataType.BF16: torch.bfloat16,
19+
}
20+
21+
class Engine:
22+
def __init__(self, engine_path: str, stream = None):
23+
self.engine_path = engine_path
24+
self._binding_indices = {}
25+
self.stream = stream
26+
self.tensors = OrderedDict()
27+
28+
def load_engine(self,stream = None):
29+
self.engine_bytes_cpu = bytes_from_path(self.engine_path)
30+
self.engine = engine_from_bytes(self.engine_bytes_cpu)
31+
self.context = self.engine.create_execution_context()
32+
33+
if stream is None:
34+
self.stream = cudart.cudaStreamCreate()[1]
35+
36+
def allocate_buffers(self, shape_dict=None, device="cuda"):
37+
for binding in range(self.engine.num_io_tensors):
38+
name = self.engine.get_tensor_name(binding)
39+
if shape_dict and name in shape_dict:
40+
shape = shape_dict[name]
41+
else:
42+
shape = self.engine.get_tensor_shape(name)
43+
print(
44+
f"[W]: {self.engine_path}: Could not find '{name}' in shape dict {shape_dict}. Using shape {shape} inferred from the engine."
45+
)
46+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
47+
self.context.set_input_shape(name, shape)
48+
dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(name)]
49+
tensor = torch.empty(tuple(shape), dtype=dtype).to(device=device)
50+
self.tensors[name] = tensor
51+
52+
def infer(self, feed_dict, stream):
53+
for name, buf in feed_dict.items():
54+
self.tensors[name].copy_(buf)
55+
56+
for name, tensor in self.tensors.items():
57+
self.context.set_tensor_address(name, tensor.data_ptr())
58+
59+
noerror = self.context.execute_async_v3(stream)
60+
if not noerror:
61+
raise ValueError(f"ERROR: inference of {self.engine_path} failed.")
62+
63+
return self.tensors
64+
65+
def unload_engine(self):
66+
del self.engine
67+
self.engine = None
68+
gc.collect()
69+
70+
def get_shape_dict(self):
71+
pass
72+
73+
def check_dims(self, batch_size, image_height, image_width, compression_factor = 8, min_batch = 1, max_batch = 16, min_latent_shape = 16, max_latent_shape = 1024):
74+
assert batch_size >= min_batch and batch_size <= max_batch
75+
latent_height = image_height // compression_factor
76+
latent_width = image_width // compression_factor
77+
assert latent_height >= min_latent_shape and latent_height <= max_latent_shape
78+
assert latent_width >= min_latent_shape and latent_width <= max_latent_shape
79+
return (latent_height, latent_width)
80+
81+
82+
83+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from .engine import Engine
2+
3+
class T5XXLModel(Engine):
4+
def __init__(self, engine_path: str, stream = None):
5+
super().__init__(engine_path, stream)
6+
self.text_maxlen = 512
7+
self.d_model = 4096
8+
9+
# Load engine before
10+
self.load_engine()
11+
12+
def get_shape_dict(self, batch_size, image_height, image_width):
13+
self.check_dims(batch_size, image_height, image_width)
14+
output = {
15+
"input_ids": (batch_size, self.text_maxlen),
16+
"text_embeddings": (batch_size, self.text_maxlen, self.d_model),
17+
}
18+
return output
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from .engine import Engine
2+
3+
class FluxTransformerModel(Engine):
4+
def __init__(self, engine_path: str, stream = None):
5+
super().__init__(engine_path, stream)
6+
self.in_channels = 64
7+
8+
# Load engine before
9+
self.load_engine()
10+
11+
def get_shape_dict(self,batch_size, image_height, image_width):
12+
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
13+
shape_dict = {
14+
"hidden_states": (batch_size, (latent_height // 2) * (latent_width // 2), 64),
15+
"encoder_hidden_states": (batch_size, 512, 4096),
16+
"pooled_projections": (batch_size, 768),
17+
"timestep": (batch_size,),
18+
"img_ids": ((latent_height // 2) * (latent_width // 2), 3),
19+
"txt_ids": (512, 3),
20+
"latent": (batch_size, (latent_height // 2) * (latent_width // 2), 64),
21+
"guidance": (batch_size,),
22+
}
23+
24+
return shape_dict
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from .engine import Engine
2+
3+
class VAEModel(Engine):
4+
def __init__(self, engine_path: str, stream = None):
5+
super().__init__(engine_path, stream)
6+
self.latent_channels = 16
7+
self.scaling_factor = 0.3611
8+
self.shift_factor = 0.1159
9+
10+
# Load engine before
11+
self.load_engine()
12+
13+
def get_shape_dict(self, batch_size, image_height, image_width):
14+
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
15+
return {
16+
"latent": (batch_size, self.latent_channels, latent_height, latent_width),
17+
"images": (batch_size, 3, image_height, image_width),
18+
}

0 commit comments

Comments
 (0)