
Conversation

@lanluo-nvidia (Collaborator) commented on Apr 29, 2025

Description

Add FP4 support.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions bot added the `component: tests`, `component: conversion`, `component: converters`, `component: api [Python]`, and `component: dynamo` labels on Apr 29, 2025
@github-actions bot requested a review from peri044 on April 29, 2025 16:37
@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_enums.py	2025-04-29 16:37:46.596096+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_enums.py	2025-04-29 16:38:07.872645+00:00
@@ -77,11 +77,11 @@
    f8 = auto()
    """8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
    
    :meta hide-value:
    """
-    
+
    f4 = auto()
    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

    :meta hide-value:
    """
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py	2025-04-29 16:37:46.600096+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py	2025-04-29 16:38:08.284218+00:00
@@ -66,10 +66,11 @@
            dequantize_layer.precision = trt.DataType.FP8
        dq_output = dequantize_layer.get_output(0)

        return dq_output

+
def dynamic_block_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
@@ -97,19 +98,23 @@
            )
        if len(input_tensor.shape) not in (2, 3):
            raise ValueError(
                f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
            )
-        print(f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}")
+        print(
+            f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}"
+        )
        max_bound = 6
        amax = to_torch(amax, None)
        scale = torch.divide(amax, max_bound)
        scale = get_trt_tensor(ctx, scale, name + "_scale")

-        output_type=trt.DataType.FP4
+        output_type = trt.DataType.FP4
        # Add Q node
-        dynamic_quantize_layer = ctx.net.add_dynamic_quantize(input_tensor, axis=-1, block_size=16, output_type=output_type)
+        dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+            input_tensor, axis=-1, block_size=16, output_type=output_type
+        )
        quantize_layer.set_output_type(0, output_type)

        set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
        q_output = quantize_layer.get_output(0)
        # Add DQ node
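
To make the scale arithmetic in this hunk concrete, below is a hedged, self-contained PyTorch sketch that emulates per-block FP4 quantize/dequantize. `max_bound = 6` is the largest magnitude representable in FP4 E2M1 (the representable non-negative values are 0, 0.5, 1, 1.5, 2, 3, 4, 6), so dividing `amax` by 6 places each block's largest element at the top of the FP4 range; `block_size = 16` mirrors the hard-coded value above. The helper name, the nearest-value rounding, and the divisibility assumption on the last axis are all illustrative choices, not what the TensorRT kernel actually does.

```python
# Hedged emulation of FP4 (E2M1) block quantization, not the converter
# itself. Assumes the flattened input length is divisible by block_size.
import torch

# Non-negative magnitudes representable in FP4 E2M1; the maximum is 6.0,
# which is where the converter's max_bound = 6 comes from.
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_fp4_block_qdq(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Quantize-dequantize x per block along the last axis (illustrative)."""
    blocks = x.reshape(-1, block_size)
    # Per-block scale = amax / 6, matching the diff above.
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    scaled = blocks / scale
    # Round each scaled magnitude to the nearest representable E2M1 value
    # (the real kernel's rounding mode may differ).
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1_VALUES).abs().argmin(dim=-1)
    q = FP4_E2M1_VALUES[idx] * scaled.sign()
    return (q * scale).reshape(x.shape)


x = torch.randn(2, 32)
print((x - fake_fp4_block_qdq(x)).abs().max())  # small round-trip error
```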
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-04-29 16:37:46.627096+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-04-29 16:38:13.681784+00:00
@@ -195,11 +195,10 @@
        msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )

    # Clean up model env
    torch._dynamo.reset()
-


@unittest.skipIf(
    torch.cuda.get_device_capability() < (8, 9),
    "FP4 quantization requires compute capability 8.9 or later",

@narendasan (Collaborator) commented:

Removed the image so we aren't leaking internal info.

@github-actions bot added the `component: build system` label on May 1, 2025
@github-actions bot added the `component: lowering` label on May 2, 2025