Commit 1d172ce

Add fp4 support
1 parent 100a7aa commit 1d172ce

5 files changed: +158 -0 lines changed

examples/dynamo/vgg16_ptq.py

Lines changed: 4 additions & 0 deletions

@@ -200,6 +200,8 @@ def calibrate_loop(model):
     quant_cfg = mtq.INT8_DEFAULT_CFG
 elif args.quantize_type == "fp8":
     quant_cfg = mtq.FP8_DEFAULT_CFG
+elif args.quantize_type == "fp4":
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
 # PTQ with in-place replacement to quantized modules
 mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 # model has FP8 qdq nodes at this point

@@ -239,6 +241,8 @@ def calibrate_loop(model):
     enabled_precisions = {torch.int8}
 elif args.quantize_type == "fp8":
     enabled_precisions = {torch.float8_e4m3fn}
+elif args.quantize_type == "fp4":
+    enabled_precisions = {torch.float4_e2m1fn_x2}
 trt_model = torchtrt.dynamo.compile(
     exp_program,
     inputs=[input_tensor],
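
With these two branches in place, the example script's existing PTQ-then-compile flow also covers FP4 (presumably selected via --quantize_type fp4, given the args.quantize_type check). A minimal sketch of that flow, assuming model, calibrate_loop, input_tensor, and the exported exp_program are defined as elsewhere in the script and that the installed PyTorch exposes torch.float4_e2m1fn_x2:

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt

# Pick the NVFP4 recipe (the new branch above), then calibrate in place
quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

# Compile the exported program with FP4 listed as an enabled precision
trt_model = torchtrt.dynamo.compile(
    exp_program,
    inputs=[input_tensor],
    enabled_precisions={torch.float4_e2m1fn_x2},
)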

py/torch_tensorrt/_enums.py

Lines changed: 15 additions & 0 deletions

@@ -76,6 +76,12 @@ class dtype(Enum):

     f8 = auto()
     """8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
+
+    :meta hide-value:
+    """
+
+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

     :meta hide-value:
     """

@@ -90,6 +96,7 @@ class dtype(Enum):

     float8 = f8
     fp8 = f8
+    fp4 = f4

     half = f16
     fp16 = f16

@@ -162,6 +169,8 @@ def _from(
             return dtype.i32
         elif t == torch.float8_e4m3fn:
             return dtype.f8
+        elif t == torch.float4_e2m1fn_x2:
+            return dtype.f4
         elif t == torch.half:
             return dtype.f16
         elif t == torch.float:

@@ -188,6 +197,8 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
+        elif t == trt.DataType.FP4:
+            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:

@@ -357,6 +368,8 @@ def to(
             return torch.long
         elif self == dtype.f8:
             return torch.float8_e4m3fn
+        elif self == dtype.f4:
+            return torch.float4_e2m1fn_x2
         elif self == dtype.f16:
             return torch.half
         elif self == dtype.f32:

@@ -410,6 +423,8 @@ def to(
             return np.int64
         elif self == dtype.f16:
             return np.float16
+        elif self == dtype.f4:
+            return np.float4_e2m1fn_x2
         elif self == dtype.f32:
             return np.float32
         elif self == dtype.f64:
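
With the new enum member and its aliases, the FP4 type should round-trip between the PyTorch, TensorRT, and torch_tensorrt representations. A minimal sketch, assuming a PyTorch build that exposes torch.float4_e2m1fn_x2 and a TensorRT build that exposes trt.DataType.FP4:

import tensorrt as trt
import torch
from torch_tensorrt import dtype

assert dtype._from(torch.float4_e2m1fn_x2) is dtype.f4     # torch dtype -> enum
assert dtype._from(trt.DataType.FP4) is dtype.f4           # TensorRT dtype -> enum (fp4 aliases f4)
assert dtype.f4.to(torch.dtype) is torch.float4_e2m1fn_x2  # enum -> torch dtype

One caveat: the added NumPy branch returns np.float4_e2m1fn_x2, but stock NumPy does not define an FP4 dtype, so that path likely needs a guard or a different target type.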

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 0 deletions

@@ -617,6 +617,39 @@ def aten_ops_quantize_op(
     )


+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception as e:
+    _LOGGER.warning(
+        "Unable to import dynamic block quantize op. Please install the modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic block quantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.quantize.dynamic_block_quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            args[4],
+            args[5],
+            args[6],
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
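
The seven positional arguments forwarded here line up with the parameters of impl.quantize.dynamic_block_quantize in the next file; the layout below is inferred from that signature rather than from ModelOpt documentation:

# Inferred argument layout of the registered op (names are hypothetical):
# torch.ops.tensorrt.dynamic_block_quantize_op(
#     input_tensor,        # args[0]
#     block_size,          # args[1]
#     amax,                # args[2]
#     num_bits,            # args[3]
#     exponent_bits,       # args[4]
#     scale_num_bits,      # args[5]
#     scale_exponent_bits, # args[6]
# )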

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 52 additions & 0 deletions

@@ -67,3 +67,55 @@ def quantize(
     dq_output = dequantize_layer.get_output(0)

     return dq_output
+
+def dynamic_block_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    block_size: int,
+    amax: Union[np.ndarray, torch.Tensor],
+    num_bits: int,
+    exponent_bits: int,
+    scale_num_bits: int,
+    scale_exponent_bits: int,
+) -> TRTTensor:
+    """
+    Adds quantize and dequantize ops (QDQ) which quantize to FP4 based
+    on the output_type set and dequantize the result back.
+    """
+    with unset_fake_temporarily():
+        if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
+            trt.float32,
+            trt.float16,
+            trt.bfloat16,
+        ):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
+            )
+        if len(input_tensor.shape) not in (2, 3):
+            raise ValueError(
+                f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
+            )
+        print(
+            f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}"
+        )
+        # Derive the per-tensor scale from the calibrated amax
+        max_bound = 6
+        amax = to_torch(amax, None)
+        scale = torch.divide(amax, max_bound)
+        scale = get_trt_tensor(ctx, scale, name + "_scale")
+
+        output_type = trt.DataType.FP4
+        # Add Q node (block size is fixed at 16 here; the incoming block_size argument is not used)
+        dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+            input_tensor, axis=-1, block_size=16, output_type=output_type
+        )
+        dynamic_quantize_layer.set_output_type(0, output_type)
+        set_layer_name(dynamic_quantize_layer, target, name + "_quantize", source_ir)
+        q_output = dynamic_quantize_layer.get_output(0)
+        # Add DQ node
+        dequantize_layer = ctx.net.add_dequantize(q_output, scale)
+        set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
+        dequantize_layer.precision = output_type
+        dq_output = dequantize_layer.get_output(0)
+
+        return dq_output
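
For reference, the scale derived above divides the calibrated amax by 6.0, the largest magnitude representable in FP4 (E2M1); that per-tensor scale is what feeds the dequantize layer. A small standalone sketch with a hypothetical amax value:

import torch

amax = torch.tensor(4.8)       # hypothetical calibrated abs-max of the tensor
max_bound = 6                  # largest magnitude representable in FP4 (E2M1)
scale = torch.divide(amax, max_bound)
print(scale)                   # tensor(0.8000)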

tests/py/dynamo/models/test_models_export.py

Lines changed: 54 additions & 0 deletions

@@ -199,6 +199,60 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()


+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (8, 9),
+    "FP4 quantization requires compute capability 8.9 or later",
+)
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
+            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = torch.nn.ReLU()(x)
+            x = self.linear2(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(1, 10).cuda()
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has FP4 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,), strict=False)
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.float4_e2m1fn_x2},
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=5e-1)
+
+
 @unittest.skipIf(
     torch.cuda.get_device_capability() < (8, 9),
     "FP8 quantization requires compute capability 8.9 or later",
