Commit d45ce00

Add Flux fp4 support (#3689)
1 parent 75b7774 commit d45ce00

6 files changed: +94 −26 lines

examples/apps/README.md

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,11 @@ python flux_demo.py
 
 ### Using Different Precision Modes
 
+- FP4 mode:
+```bash
+python flux_demo.py --dtype fp4
+```
+
 - FP8 mode:
 ```bash
 python flux_demo.py --dtype fp8
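
FP4 can also be combined with the new `--fp4_mha` switch (added to `flux_demo.py` in this commit) to quantize attention with the NVFP4_FP8_MHA recipe instead of the default NVFP4 config, for example:

```bash
python flux_demo.py --dtype fp4 --fp4_mha
```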

examples/apps/flux_demo.py

Lines changed: 43 additions & 13 deletions
@@ -12,10 +12,6 @@
 from diffusers import FluxPipeline
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 
-# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
-sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
-from register_sdpa import *
-
 DEVICE = "cuda:0"
 
 
@@ -24,8 +20,25 @@ def compile_model(
 ) -> tuple[
     FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
 ]:
+    use_explicit_typing = False
+    if args.use_sdpa:
+        # SDPA currently does not work correctly with the Flux model, so it is disabled by default.
+        # Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+        sys.path.append(
+            os.path.join(os.path.dirname(__file__), "../../tools/llm/torchtrt_ext")
+        )
+        import register_sdpa
+
+    if args.dtype == "fp4":
+        use_explicit_typing = True
+        enabled_precisions = {torch.float4_e2m1fn_x2}
+        ptq_config = mtq.NVFP4_DEFAULT_CFG
+        if args.fp4_mha:
+            from modelopt.core.torch.quantization.config import NVFP4_FP8_MHA_CONFIG
+
+            ptq_config = NVFP4_FP8_MHA_CONFIG
 
-    if args.dtype == "fp8":
+    elif args.dtype == "fp8":
         enabled_precisions = {torch.float8_e4m3fn, torch.float16}
         ptq_config = mtq.FP8_DEFAULT_CFG
 
@@ -107,26 +120,33 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": True,
+        "use_python_runtime": False,
         "immutable_weights": False,
-        "offload_module_to_cpu": True,
+        "offload_module_to_cpu": args.low_vram_mode,
+        "use_explicit_typing": use_explicit_typing,
     }
     if args.low_vram_mode:
         pipe.remove_all_hooks()
         pipe.enable_sequential_cpu_offload()
         remove_hook_from_module(pipe.transformer, recurse=True)
         pipe.transformer.to(DEVICE)
+
     trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
     if dynamic_shapes:
         trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
     pipe.transformer = trt_gm
-
+    seed = 42
     image = pipe(
-        "Test",
+        [
+            "enchanted winter forest, soft diffuse light on a snow-filled day, serene nature scene, the forest is illuminated by the snow"
+        ],
         output_type="pil",
-        num_inference_steps=2,
+        num_inference_steps=30,
         num_images_per_prompt=batch_size,
+        generator=torch.Generator("cuda").manual_seed(seed),
     ).images
+    print(f"generated {len(image)} images")
+    image[0].save("/tmp/forest.png")
 
     torch.cuda.empty_cache()
 
@@ -242,12 +262,22 @@ def main(args):
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
+    parser.add_argument(
+        "--use_sdpa",
+        action="store_true",
+        help="Use sdpa",
+        default=False,
+    )
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
         "--low_vram_mode",

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 4 deletions
@@ -258,10 +258,11 @@ def cross_compile_for_windows(
 
     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
@@ -591,10 +592,11 @@ def compile(
 
     if use_explicit_typing:
         if len(enabled_precisions) != 1 or not any(
-            x in enabled_precisions for x in {torch.float32, dtype.f32}
+            x in enabled_precisions
+            for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
         ):
             raise AssertionError(
-                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
            )
 
     if use_fp32_acc:
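
The practical effect: with explicit typing enabled, `torch.float4_e2m1fn_x2` (or `dtype.f4`) is now accepted as the sole entry in `enabled_precisions`, where previously only fp32 passed the check. A hedged sketch of a call that passes validation after this commit (`model` and `example_inputs` are hypothetical placeholders):

```python
import torch
import torch_tensorrt

# Before this commit, only fp32 was allowed together with
# use_explicit_typing=True; fp4 is now whitelisted as well.
trt_module = torch_tensorrt.compile(
    model,                  # hypothetical traceable module
    ir="dynamo",
    inputs=example_inputs,  # hypothetical example inputs
    enabled_precisions={torch.float4_e2m1fn_x2},
    use_explicit_typing=True,
)
```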

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 5 additions & 1 deletion
@@ -103,10 +103,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.quantization_ops: Set[torch._ops.OpOverload] = set()
         try:
             # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
-            import modelopt.torch.quantization as mtq
+            import modelopt.torch.quantization as mtq  # noqa: F401
 
             assert torch.ops.tensorrt.quantize_op.default
+            assert torch.ops.tensorrt.dynamic_block_quantize_op.default
             self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
+            self.quantization_ops.add(
+                torch.ops.tensorrt.dynamic_block_quantize_op.default
+            )
         except Exception as e:
             pass
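
Tracking `dynamic_block_quantize_op` in `quantization_ops` lets the constant folder leave NVFP4's block-quantization nodes in place rather than pre-computing them into constants. A quick local check that the modelopt import side effect provides both operators (assuming a modelopt build with NVFP4 support; this mirrors the asserts above):

```python
import modelopt.torch.quantization  # noqa: F401  # registers torch.ops.tensorrt.*
import torch

# Both ops must resolve for NVFP4 graphs; dynamic_block_quantize_op is the
# node that fp4-quantized models carry.
print(torch.ops.tensorrt.quantize_op.default)
print(torch.ops.tensorrt.dynamic_block_quantize_op.default)
```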

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ def export_fn() -> torch.export.ExportedProgram:
             # Check if any quantization precision is enabled
             if self.enabled_precisions and any(
                 precision in self.enabled_precisions
-                for precision in (torch.float8_e4m3fn, torch.int8)
+                for precision in (torch.float8_e4m3fn, torch.int8, torch.float4_e2m1fn_x2)
             ):
                 try:
                     from modelopt.torch.quantization.utils import export_torch_mode
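
With fp4 added to the tuple, re-export of the mutable module now goes through modelopt's `export_torch_mode`, the same path fp8 and int8 already used. A minimal sketch of that export pattern (assuming a quantized `backbone` and hypothetical `example_args`):

```python
import torch
from modelopt.torch.quantization.utils import export_torch_mode

# Export inside export_torch_mode so modelopt's quantize ops are captured
# as regular nodes in the exported graph.
with export_torch_mode():
    exported = torch.export.export(backbone, args=(example_args,))
```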

tools/perf/Flux/flux_perf.py

Lines changed: 34 additions & 7 deletions
@@ -3,12 +3,29 @@
 import sys
 from time import time
 
+import torch
+
 sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps"))
 from flux_demo import compile_model
 
 
 def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
+    print(f"Running warmup with {batch_size=} {inference_step=} iterations=10")
+    # warmup
+    for i in range(10):
+        start = time()
+        images = pipe(
+            prompt,
+            output_type="pil",
+            num_inference_steps=inference_step,
+            num_images_per_prompt=batch_size,
+        ).images
+        print(
+            f"Warmup {i} done in {time() - start} seconds, with {batch_size=} {inference_step=}, generated {len(images)} images"
+        )
 
+    # actual benchmark
+    print(f"Running benchmark with {batch_size=} {inference_step=} {iterations=}")
     start = time()
     for i in range(iterations):
         image = pipe(
@@ -18,32 +35,42 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
             num_images_per_prompt=batch_size,
         ).images
     end = time()
-
     print(f"Batch Size: {batch_size}")
     print("Time Elapse for", iterations, "iterations:", end - start)
     print(
         "Average Latency Per Step:",
         (end - start) / inference_step / iterations / batch_size,
     )
-    return image
+    return
 
 
 def main(args):
+    print(f"Running flux_perf with args: {args}")
     pipe, backbone, trt_gm = compile_model(args)
-    for batch_size in range(1, args.max_batch_size + 1):
-        benchmark(pipe, ["Test"], 20, batch_size=batch_size, iterations=3)
+
+    benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Run Flux quantization with different dtypes"
     )
-
+    parser.add_argument(
+        "--use_sdpa",
+        action="store_true",
+        help="Use sdpa",
+        default=False,
+    )
     parser.add_argument(
         "--dtype",
-        choices=["fp8", "int8", "fp16"],
+        choices=["fp4", "fp8", "int8", "fp16"],
         default="fp16",
-        help="Select the data type to use (fp8 or int8 or fp16)",
+        help="Select the data type to use (fp4 or fp8 or int8 or fp16)",
+    )
+    parser.add_argument(
+        "--fp4_mha",
+        action="store_true",
+        help="Use NVFP4_FP8_MHA_CONFIG config instead of NVFP4_DEFAULT_CFG",
     )
     parser.add_argument(
         "--low_vram_mode",
