12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -21,14 +21,14 @@ repos:
       - id: clang-format
         types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.4.0
+    rev: 8.0.3
     hooks:
       - id: buildifier
         args:
           - --warnings=all
       - id: buildifier-lint
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.23
+    rev: v0.24.1
     hooks:
       - id: validate-pyproject
   - repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
       - id: isort
         name: isort (python)
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.9.0"
+    rev: "v1.15.0"
     hooks:
       - id: mypy
         exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.3.3
+    rev: v0.11.7
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 25.1.0
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
       - id: typos
   - repo: https://github.com/astral-sh/uv-pre-commit
     # uv version.
-    rev: 0.5.5
+    rev: 0.7.1
     hooks:
       # Update the uv lockfile
       - id: uv-lock
4 changes: 4 additions & 0 deletions examples/dynamo/vgg16_ptq.py
@@ -200,6 +200,8 @@ def calibrate_loop(model):
     quant_cfg = mtq.INT8_DEFAULT_CFG
 elif args.quantize_type == "fp8":
     quant_cfg = mtq.FP8_DEFAULT_CFG
+elif args.quantize_type == "fp4":
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
 # PTQ with in-place replacement to quantized modules
 mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 # model has FP8 qdq nodes at this point
@@ -239,6 +241,8 @@ def calibrate_loop(model):
     enabled_precisions = {torch.int8}
 elif args.quantize_type == "fp8":
     enabled_precisions = {torch.float8_e4m3fn}
+elif args.quantize_type == "fp4":
+    enabled_precisions = {torch.float4_e2m1fn_x2}
 trt_model = torchtrt.dynamo.compile(
     exp_program,
     inputs=[input_tensor],
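Taken together, the two hunks above wire a third quantization mode through the example script. A minimal end-to-end sketch of the new fp4 path, under the assumption that `model`, `calibrate_loop`, and `input_tensor` are defined as elsewhere in vgg16_ptq.py and that modelopt's `export_torch_mode` lives at the import path shown:

```python
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode

# PTQ with in-place replacement of modules by their quantized counterparts
quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

# Export with modelopt's quantize ops preserved, then compile with FP4 enabled
with torch.no_grad():
    with export_torch_mode():
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float4_e2m1fn_x2},
        )
```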
19 changes: 19 additions & 0 deletions py/torch_tensorrt/_enums.py
@@ -80,6 +80,12 @@ class dtype(Enum):
     :meta hide-value:
     """

+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
+
+    :meta hide-value:
+    """
+
     uint8 = u8
     int8 = i8

@@ -91,6 +97,9 @@ class dtype(Enum):
     float8 = f8
     fp8 = f8

+    float4 = f4
+    fp4 = f4
+
     half = f16
     fp16 = f16
     float16 = f16
@@ -162,6 +171,8 @@ def _from(
             return dtype.i32
         elif t == torch.float8_e4m3fn:
             return dtype.f8
+        elif t == torch.float4_e2m1fn_x2:
+            return dtype.f4
         elif t == torch.half:
             return dtype.f16
         elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
+        elif t == trt.DataType.FP4:
+            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
             return torch.long
         elif self == dtype.f8:
             return torch.float8_e4m3fn
+        elif self == dtype.f4:
+            return torch.float4_e2m1fn_x2
         elif self == dtype.f16:
             return torch.half
         elif self == dtype.f32:
@@ -394,6 +409,8 @@ def to(
             return trt.DataType.BOOL
         elif self == dtype.bf16:
             return trt.DataType.BF16
+        elif self == dtype.f4:
+            return trt.DataType.FP4
         elif use_default:
             return trt.DataType.FLOAT
         else:
@@ -410,6 +427,9 @@ def to(
             return np.int64
         elif self == dtype.f16:
             return np.float16
+        elif self == dtype.f4:
+            # NumPy has no FP4 dtype (np.float4_e2m1fn_x2 does not exist), so fail loudly
+            raise TypeError("fp4 (float4_e2m1fn_x2) is not supported by NumPy")
         elif self == dtype.f32:
             return np.float32
         elif self == dtype.f64:
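The new enum member, its aliases, and the conversion branches give `f4` the same round-trip behavior as the existing precisions. A small sketch, assuming a PyTorch build that defines `torch.float4_e2m1fn_x2` and a TensorRT build that exposes `trt.DataType.FP4` (`_from` is the internal constructor shown in the hunks above):

```python
import tensorrt as trt
import torch
from torch_tensorrt._enums import dtype

# torch dtype -> torch_tensorrt enum
d = dtype._from(torch.float4_e2m1fn_x2)
assert d is dtype.f4 and dtype.fp4 is dtype.f4 and dtype.float4 is dtype.f4

# enum -> backend dtypes
print(d.to(trt.DataType))  # trt.DataType.FP4
print(d.to(torch.dtype))   # torch.float4_e2m1fn_x2
```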
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -29,7 +29,14 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+SUPPORTED_KERNEL_PRECISIONS = {
+    dtype.f32,
+    dtype.f16,
+    dtype.bf16,
+    dtype.i8,
+    dtype.f8,
+    dtype.f4,
+}
 TIMING_CACHE_PATH = os.path.join(
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
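With `dtype.f4` in the supported set, a precision check passes for fp4 the same way it does for fp8. An illustrative sketch (the `requested` set here is hypothetical, not code from the PR):

```python
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo._defaults import SUPPORTED_KERNEL_PRECISIONS

requested = {dtype.f16, dtype.f4}  # hypothetical user request
unsupported = requested - SUPPORTED_KERNEL_PRECISIONS
if unsupported:
    raise ValueError(f"Unsupported kernel precision(s): {unsupported}")
```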
33 changes: 33 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -617,6 +617,39 @@ def aten_ops_quantize_op(
     )


+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception:
+    _LOGGER.warning(
+        "Unable to import dynamic block quantize op. Please install the modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling dynamic block-quantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(torch.ops.tensorrt.dynamic_block_quantize_op.default)
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.nvfp4_quantize.nvfp4_quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            args[4],
+            args[5],
+            args[6],
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
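Because the registration lives in the `else` branch, the converter only exists when modelopt is importable and its custom op is registered. A small sketch for probing availability at runtime; that importing `modelopt.torch.quantization` registers the op is inferred from the guarded import above:

```python
import torch

try:
    import modelopt.torch.quantization  # noqa: F401  # registers the custom op

    fp4_available = hasattr(torch.ops.tensorrt, "dynamic_block_quantize_op")
except ImportError:
    fp4_available = False

print(f"dynamic_block_quantize_op available: {fp4_available}")
```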
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -14,6 +14,7 @@
     matmul,
     nccl_ops,
     normalization,
+    nvfp4_quantize,
     pad,
     permutation,
     pool,