# 🚀 Flux TensorRT Pipelines

This project provides **TensorRT-accelerated pipelines** for Flux models, enabling **faster inference** with static and dynamic shapes.

## ✅ Supported Pipelines
- ✅ `FluxPipeline` (Supported)
- ⏳ `FluxImg2ImgPipeline` (Coming soon)
- ⏳ `FluxInpaintPipeline` (Coming soon)
- ⏳ `FluxFillPipeline` (Coming soon)
- ⏳ `FluxKontextPipeline` (Coming soon)
- ⏳ `FluxKontextInpaintPipeline` (Coming soon)

---

## ⚙️ Building Flux with TensorRT

We follow the official [NVIDIA/TensorRT](https://github.com/NVIDIA/TensorRT) repository to build the TensorRT engines.

> **Note:**
> The TensorRT engines were originally built with `diffusers==0.31.1`.
> We currently recommend using:
> - one **venv** for building, and
> - another **venv** for inference.

(🔜 TODO: Build scripts for the latest `diffusers` will be added later.)

### Installation
```bash
git clone https://github.com/NVIDIA/TensorRT
cd TensorRT/demo/Diffusion

pip install tensorrt-cu12==10.13.2.6
pip install -r requirements.txt
```
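
Following the note above, the build dependencies can live in their own venv, for example (a minimal sketch; the venv name `.venv-build` is arbitrary):

```bash
# Dedicated virtual environment for building the engines
python3 -m venv .venv-build
source .venv-build/bin/activate

# ...then run the pip installs above inside this venv
```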

### ⚡ Fast Building with Static Shapes
```bash
# BF16
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --bf16 --download-onnx-models

# FP8
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --quantization-level 4 --fp8 --download-onnx-models

# FP4
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" --hf-token=$HF_TOKEN --fp4 --download-onnx-models
```

- To build with dynamic shapes, add `--build-dynamic-shape` (see the example below).
- To build with a static batch size, add `--build-static-batch`.

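For example, a BF16 dynamic-shape build simply appends the flag to the BF16 command above (a sketch; combine flags as your build requires):

```bash
# BF16 build with dynamic shapes
python3 demo_txt2img_flux.py "a beautiful photograph of Mt. Fuji during cherry blossom" \
    --hf-token=$HF_TOKEN --bf16 --download-onnx-models --build-dynamic-shape
```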

ℹ️ For more details, run:
`python3 demo_txt2img_flux.py --help`

## 🖼️ Inference with Flux TensorRT
Create a new venv (or update `diffusers` and `peft` in your existing one), then run fast inference with the TensorRT engines.

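A minimal environment sketch (the venv name and package set are assumptions based on the imports in the examples below; pin the versions your setup requires):

```bash
# Dedicated virtual environment for inference
python3 -m venv .venv-infer
source .venv-infer/bin/activate

pip install torch cuda-python tensorrt-cu12==10.13.2.6
pip install -U diffusers peft transformers
```
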
### Example: Full Pipeline with All Engines

```python
import torch
from cuda import cudart

from pipeline_flux_trt import FluxPipelineTRT
from module.transformers import FluxTransformerModel
from module.vae import VAEModel
from module.t5xxl import T5XXLModel
from module.clip import CLIPModel

# Local path to each TensorRT engine
engine_transformer_path = "path/to/transformer/engine_trt10.13.2.6.plan"
engine_vae_path = "path/to/vae/engine_trt10.13.2.6.plan"
engine_t5xxl_path = "path/to/t5/engine_trt10.13.2.6.plan"
engine_clip_path = "path/to/clip/engine_trt10.13.2.6.plan"

# Create a CUDA stream shared by the engines
stream = cudart.cudaStreamCreate()[1]

# Wrap each engine
engine_transformer = FluxTransformerModel(engine_transformer_path, stream)
engine_vae = VAEModel(engine_vae_path, stream)
engine_t5xxl = T5XXLModel(engine_t5xxl_path, stream)
engine_clip = CLIPModel(engine_clip_path, stream)

# Create the pipeline with all four engines
pipeline = FluxPipelineTRT.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    engine_transformer=engine_transformer,
    engine_vae=engine_vae,
    engine_text_encoder=engine_clip,
    engine_text_encoder_2=engine_t5xxl,
)
pipeline.to("cuda")

prompt = "A cat holding a sign that says hello world"
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipeline(prompt, num_inference_steps=28, guidance_scale=3.0, generator=generator).images[0]

image.save("test_pipeline.png")
```

### Example: Transformer Only (Other Modules on Torch)
```python
# Only the transformer runs on TensorRT; the VAE and text encoders
# fall back to the standard Torch modules loaded from the checkpoint.
pipeline = FluxPipelineTRT.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    engine_transformer=engine_transformer,
)
pipeline.to("cuda")
```

## 📌 Notes

- Ensure the correct CUDA and TensorRT versions are installed.
- Always match the `.plan` engine files with the TensorRT version used to build them (see the quick check below).
- For best performance, prefer static shapes unless dynamic batching is required.
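
As a quick sanity check (a sketch, not part of the project), compare the installed TensorRT version against the `_trt<version>.plan` suffix of your engine files:

```bash
# Should print the same version that appears in the engine filename, e.g. 10.13.2.6
python3 -c "import tensorrt; print(tensorrt.__version__)"
```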