
Conversation

realAsma (Contributor) commented Sep 26, 2025

What does this PR do?

Type of change: new feature & bug fix

Overview:
For FP8 fake quantization, the QDQ scaling was applied outside the CUDA kernel, in the same precision as the input. This means that for a bf16 input, the scale was applied in bf16, which could cause fake and real quantization to diverge. A minimal eager-mode sketch of the intended fp32 QDQ follows the list below.
Hence, in this PR:

  1. Fused the QDQ scaling into the QDQ CUDA kernel.
  2. Removed the previous fused_fp8 kernel - it only supported per-channel quantization along the -1 axis. Per-channel quantization is usually done along axis 0, so this kernel was not useful.
  3. Implemented an FP8 per-channel fake-quantization kernel (similar to the integer per-channel fake-quantization kernel).
  4. Removed unnecessary operations such as masking values close to 0 - these add overhead and are not present in the FP4 or integer quantization implementations.
  5. Added unit tests.
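
As a reference for the intended behavior, here is a minimal eager-mode sketch of the fp32 QDQ, assuming PyTorch's float8_e4m3fn round-trip cast and a per-tensor amax (the actual change fuses the same math into the CUDA kernel; the function name below is illustrative):

import torch

def fp8_fake_quant_reference(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Illustrative eager reference: scale, cast to E4M3, cast back, unscale - all in fp32."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)                  # do the QDQ math in fp32, not in the input dtype
    scale = 448.0 / amax.to(torch.float32)   # 448 is the largest finite E4M3 value
    x = (x * scale).to(torch.float8_e4m3fn)  # quantize (round/saturate to E4M3)
    x = x.to(torch.float32) / scale          # dequantize, still in fp32
    return x.to(orig_dtype)                  # restore the caller's dtype

# Example: bf16 weight, per-tensor amax
w = torch.randn(128, 128, dtype=torch.bfloat16)
w_qdq = fp8_fake_quant_reference(w, w.abs().amax())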

What was the impact of scaling in bf16 for QAT?
I have not tested the end-to-end QAT impact. I measured the absolute quantization error on Llama 3.1 8B's first q_proj weight.

Here it is:

Metric | Value
Max    | 0.0547
Mean   | 1e-04
Min    | 0.0
Std    | 4e-04
Median | 1.5e-05
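
The PR does not spell out the exact methodology; one plausible reading is the element-wise difference between the old bf16-scaled QDQ and the new fp32-scaled QDQ on that weight. A hypothetical way to produce such a table (the weight here is a random stand-in, not the actual checkpoint load):

import torch

def fp8_qdq(x: torch.Tensor, amax: torch.Tensor, compute_dtype: torch.dtype) -> torch.Tensor:
    # Scaling applied in `compute_dtype`: bf16 mimics the old behavior, fp32 the new one.
    scale = 448.0 / amax.to(compute_dtype)
    x_q = (x.to(compute_dtype) * scale).to(torch.float8_e4m3fn)
    return (x_q.to(compute_dtype) / scale).to(x.dtype)

# Stand-in for Llama 3.1 8B's first q_proj weight (bf16), which would be loaded from the checkpoint.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
amax = weight.abs().amax()

err = (fp8_qdq(weight, amax, torch.bfloat16) - fp8_qdq(weight, amax, torch.float32)).abs().float()
print({"max": err.max().item(), "mean": err.mean().item(), "min": err.min().item(),
       "std": err.std().item(), "median": err.median().item()})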

Usage

High level usage does not change. See the updated unit tests.
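
For illustration, a hedged sketch of how the runtime FP8 path can be exercised from Python, based on the fp8_eager helper referenced in this PR (the import path and exact signature are assumptions; see the updated unit tests for the authoritative usage):

import torch
# Assumed import path, based on the module touched in this PR.
from modelopt.torch.quantization.tensor_quant import fp8_eager

x = torch.randn(16, 32, dtype=torch.bfloat16)

# Per-tensor fake quantization with a scalar amax.
y_per_tensor = fp8_eager(x, x.abs().amax())

# Per-channel fake quantization along axis 0: amax kept broadcastable over the input.
amax_per_channel = x.abs().amax(dim=1, keepdim=True)
y_per_channel = fp8_eager(x, amax_per_channel)

# Dynamic quantization: amax=None applies no scaling before the E4M3 round trip.
y_dynamic = fp8_eager(x, None)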

Testing

Updated the unit tests.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: The fused FP8 fake quantization is removed; however, this is not a high-level API and should not have any impact.
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: NA
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • FP8 (E4M3) quantization: per-channel (axis-aware) scaling, optional amax parameter, automatic eager vs GPU-backed selection, outputs preserve original dtype.
  • Refactor

    • CPU compiled FP8 implementation removed in favor of a pure-Python eager fallback; CUDA path streamlined, kernels now respect CUDA stream handling; FP8 extension builds GPU-only sources.
  • Tests

    • Tests updated for runtime FP8 behavior, expanded per-axis and dtype coverage; legacy test scaffolding removed.

realAsma requested a review from a team as a code owner on September 26, 2025 17:27
realAsma requested a review from mxinO on September 26, 2025 17:27

coderabbitai bot commented Sep 26, 2025

Walkthrough

Removed CPU FP8 sources and header declaration; introduced axis-aware FP8 CUDA kernels and host APIs; Python FP8 routing added eager and axis-aware paths; CUDA kernel launches now use the current stream; tests and test utilities simplified.

Changes

  • Header API (modelopt/torch/quantization/src/tensor_quant.h): Removed the declaration for at::Tensor fake_e4m3fy_cuda(at::Tensor inputs).
  • Removed CPU FP8 sources & bindings (modelopt/torch/quantization/src/tensor_quant_fp8.cpp): File deleted; CPU FP8 implementations and pybind11 bindings for fake_e4m3fy, fake_e4m3fy_cuda, and fused_fake_e4m3fy removed.
  • CUDA FP8 kernels / host, GPU-only (modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu): Replaced the fused kernel with an axis-aware kernel using per-axis scale/inv_scale; kernel/host signatures changed; added fake_e4m3fy(inputs, amax) and fake_e4m3fy_with_axis(inputs, amax, axis); removed per-block/zero-threshold logic; added #include <optional>.
  • CUDA extension loader (modelopt/torch/quantization/extensions.py): get_cuda_ext_fp8 now builds from only src/tensor_quant_gpu_fp8.cu (CPU FP8 source removed).
  • Python quantization routing (modelopt/torch/quantization/tensor_quant.py): Added _fp8_eager and fp8_eager; updated scaled_e4m3_impl(inputs, amax=None) to route between the eager path and the CUDA scalar/axis paths; removed legacy TensorQuantFunction, LegacyFakeTensorQuantFunction, and FakeAffineTensorQuantFunction; _tensor_quant now returns only outputs and always casts to float internally.
  • CUDA stream usage, non-FP8 (modelopt/torch/quantization/src/tensor_quant_gpu.cu): Added #include <c10/cuda/CUDAStream.h>; obtains the current CUDA stream and passes it to three kernel launches.
  • Tests, GPU (tests/gpu/torch/quantization/test_tensor_quant_cuda.py): Removed the get_cuda_ext_fp8 import and legacy tester usages; tests adapted to runtime FP8 via fp8_eager, with per-axis/device/dtype parametrization and adjusted assertions.
  • Tests, test utils (tests/_test_utils/torch_quantization/tensor_quant_common.py): Removed TensorQuantTester and FakeAffineTensorQuantTester; pruned legacy comparison tests from FakeTensorQuantTester; retained adjusted assertions and overflow/NaN checks.
  • Tests, CPU (tests/unit/torch/quantization/test_tensor_quant_cpu.py): Simplified imports to only FakeTensorQuantTester; removed tests extending the deleted testers.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant PY as Python (scaled_e4m3_impl)
  participant EG as fp8_eager (CPU/Python)
  participant EXT as CUDA extension (pybind11)
  participant CU as CUDA host/kernels

  PY->>PY: receive inputs, optional amax
  alt CPU / fallback / no CUDA ext
    PY->>EG: fp8_eager(inputs, amax)
    EG-->>PY: quantized outputs
  else CUDA & scalar amax
    PY->>EXT: fake_e4m3fy(inputs, amax)
    EXT->>CU: launch scalar CUDA host (scale/inv_scale)
    CU->>CU: elementwise kernel applies scalar scale
    CU-->>PY: quantized outputs
  else CUDA & 1‑D amax (per-axis)
    PY->>EXT: fake_e4m3fy_with_axis(inputs, amax, axis)
    EXT->>CU: launch axis-aware host (per-axis scale/inv_scale)
    CU->>CU: kernel computes axis_id, applies per‑axis scale
    CU-->>PY: quantized outputs
  end
  PY-->>PY: finalize/cast outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble code like clover sprigs,
Kernels learn their axis digs.
CPU crumbs swept from the glade,
Streams hum while outputs fade.
FP8 hops home — a tidy upgrade. 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 75.00%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title clearly and concisely summarizes the primary bug fix of applying fp32 amax scaling within the FP8 fake quantization kernel and the addition of per-channel quantization support, which aligns with the main objectives and high-level changes described in the pull request. It remains specific to the core feature and fix without extraneous information or vague terminology.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 56f5a08 and 2def489.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/tensor_quant.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-101)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/quantization/tensor_quant.py (6)

57-59: LGTM!

Clean public API for the eager FP8 path.


63-64: LGTM!

Optional amax parameter supports dynamic quantization without scaling, aligning with the PR objectives.


75-90: LGTM! Control flow routing is clear.

The routing logic appropriately directs non-CUDA, None amax, or multi-dimensional amax cases to the eager path, while CUDA-compatible cases use the optimized kernels.


374-375: LGTM! Simplified interface.

The refactored _tensor_quant now handles the full quantization-dequantization internally, simplifying the caller.


582-612: LGTM! FP32 QDQ implementation is correct.

The function now performs all quantization and dequantization operations in fp32 (lines 582-583, 609), then casts back to the original dtype (line 611). This aligns with the PR objective to fix FP8 fake quantization precision.


45-55: No device alignment change needed: callers of _fp8_eager already supply amax on the same device as x, so explicit in‐function alignment isn’t required.



realAsma force-pushed the asma/fp8_fakequant_fix branch 2 times, most recently from 87e9050 to 34bcee8 on September 26, 2025 17:36

codecov bot commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 30.00000% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.81%. Comparing base (c9db0ce) to head (2def489).
⚠️ Report is 2 commits behind head on main.

Files with missing lines                    | Patch % | Lines
modelopt/torch/quantization/tensor_quant.py | 30.00%  | 21 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #381      +/-   ##
==========================================
- Coverage   73.86%   73.81%   -0.05%     
==========================================
  Files         171      171              
  Lines       17629    17574      -55     
==========================================
- Hits        13021    12972      -49     
+ Misses       4608     4602       -6     

☔ View full report in Codecov by Sentry.

realAsma requested a review from cjluo-nv on September 26, 2025 21:02
coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)

45-72: Fix device-guarding and robust axis inference; ensure amax on correct device; update error text.

  • Calling torch.cuda.current_device() unconditionally will crash on CPU-only builds. Guard the CUDA context manager with inputs.is_cuda.
  • In the per-axis path, axis = amax.shape.index(amax.numel()) is fragile/incorrect; it often resolves to 0 and doesn’t consider inputs’ shape. Infer axis by matching amax.numel() to an inputs dimension (or assert axis 0 if that’s the only supported mode).
  • Ensure amax is on the same device as inputs before calling the CUDA kernel.
  • Error message still references “Fused” though that path was removed.

Proposed diff:

 def scaled_e4m3_impl(
     inputs: torch.Tensor,
     amax: torch.Tensor | None = None,
 ) -> torch.Tensor:
@@
-    with torch.cuda.device(
-        None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
-    ):
-        if amax is None or amax.numel() == 1:
-            outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
-        else:
-            if amax.squeeze().ndim > 1:
-                raise NotImplementedError(
-                    "Fused E4M3 kernel does not support multiaxis quantization."
-                )
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
-        return outputs
+    if inputs.is_cuda:
+        with torch.cuda.device(
+            None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
+        ):
+            if amax is None or amax.numel() == 1:
+                outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+            else:
+                if amax.squeeze().ndim > 1:
+                    raise NotImplementedError("E4M3 per-axis quantization supports only 1D amax.")
+                # Ensure amax device matches inputs
+                amax = amax.to(device=inputs.device)
+                # Infer axis by matching amax length to one input dimension (default to 0 if ambiguous)
+                match_axes = [i for i, d in enumerate(inputs.shape) if d == amax.numel()]
+                if not match_axes:
+                    raise ValueError("amax length does not match any input dimension.")
+                axis = match_axes[0]
+                outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+            return outputs
+    else:
+        # CPU path routes through the C++ wrapper which implements a CPU fallback.
+        if amax is None or amax.numel() == 1:
+            return cuda_ext_fp8.fake_e4m3fy(inputs, amax)
+        if amax.squeeze().ndim > 1:
+            raise NotImplementedError("E4M3 per-axis quantization supports only 1D amax.")
+        # Prefer axis 0 on CPU unless inputs dim match is unambiguous
+        match_axes = [i for i, d in enumerate(inputs.shape) if d == amax.numel()]
+        axis = match_axes[0] if match_axes else 0
+        return cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
🧹 Nitpick comments (6)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)

56-61: Axis index computation assumes contiguous layout

axis_id = (idx / stride(axis)) % size(axis) is only valid for contiguous row-major tensors. With the added contiguity normalization above, this becomes safe; otherwise results can be wrong for channels-last or non-contiguous inputs.

If you don't want to force contiguity here, rework indexing to derive multi-dimensional coordinates from idx and use them to index scale.
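
For intuition, a small illustrative check (not part of the PR) of why the flat-index formula recovers the axis coordinate only when the strides are the standard row-major ones:

import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # contiguous, row-major strides (12, 4, 1)
axis = 1
for idx in range(x.numel()):
    coords = torch.unravel_index(torch.tensor(idx), x.shape)
    axis_id = (idx // x.stride(axis)) % x.size(axis)
    assert axis_id == coords[axis].item()  # holds because flat-index order matches the strides

# For a non-contiguous view (e.g., x.transpose(0, 2)) the same flat-index arithmetic over the
# underlying storage no longer yields the view's axis coordinate, hence the .contiguous() call
# before launching the kernel.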


96-99: Ceil-div grid sizing for the per-tensor kernel

Minor efficiency fix: avoid launching an extra block.

Apply this diff:

-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
-    fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+    const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+    fake_e4m3fy_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         outputs.data_ptr<scalar_t>());
   });
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)

111-122: Harden reference against amax==0 producing NaNs

If amax is zero (e.g., all-zero channel), 0*inf leads to NaNs. Mirror the kernel-side guard by making amax safe in the reference.

Apply this diff:

 def _fp8_qdq(x, amax=None):
     dtype = x.dtype
     x = x.to(torch.float32)
     scale, scale_inv = None, None
     if amax is not None:
-        scale = 448.0 / (amax.to(torch.float32))
+        amax_safe = torch.where(amax == 0, torch.ones_like(amax, dtype=torch.float32), amax.to(torch.float32))
+        scale = 448.0 / amax_safe
         scale_inv = 1 / scale
         x = x * scale
     x = x.to(torch.float8_e4m3fn).to(torch.float32)
     if scale_inv is not None:
         x = x * scale_inv
     return x.to(dtype)

173-180: Per-channel test aligns with new kernel path

Good coverage across multiple axes; consider adding a case where one channel is all zeros to verify the NaN guard end-to-end.
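
A sketch of what such a test could look like; the entry point and exact signature are assumed from the surrounding discussion, so treat this as an outline rather than a drop-in test:

import pytest
import torch
from modelopt.torch.quantization import tensor_quant  # assumed import, as in the existing tests

@pytest.mark.parametrize("axis", [0, 1])
def test_fake_e4m3fy_per_channel_zero_amax(axis):
    # One slice along `axis` is all zeros, so its amax is 0; outputs should still be finite.
    x = torch.randn(8, 16).cuda()
    x.select(axis, 0).zero_()
    amax = x.abs().amax(dim=1 - axis, keepdim=True)
    y = tensor_quant.scaled_e4m3(x, amax)  # assumed call shape; existing tests may pass more args
    assert not torch.isnan(y).any()
    assert not torch.isinf(y).any()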

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

268-275: Use property consistently in step_size (optional)

Minor consistency: prefer self.amax over self._amax for readability and single source of truth.

Apply this diff:

-        return self._amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
+        return self.amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)

847-855: Avoid mutating internal amax in export_amax

export_amax currently modifies the buffer in-place when replacing zeros and nan. Clone before in-place ops to keep internal state intact.

Apply this diff:

-        if not hasattr(self, "_amax_shape_for_export"):
-            amax = self.amax
-        else:
-            amax = self.amax.reshape(self._amax_shape_for_export)
+        if not hasattr(self, "_amax_shape_for_export"):
+            amax = self.amax.clone()
+        else:
+            amax = self.amax.reshape(self._amax_shape_for_export).clone()
         amax[amax == 0] = self.maxbound
         amax = torch.nan_to_num(amax, nan=self.maxbound)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ad091e8 and ee03094.

📒 Files selected for processing (8)
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4 hunks)
  • modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
  • tests/_test_utils/torch_quantization/tensor_quant_common.py (1 hunks)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
  • tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/src/tensor_quant.h
🧰 Additional context used
🧬 Code graph analysis (6)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (231-236)
  • amax (239-250)
  • axis (277-279)
  • axis (282-284)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (125-135)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)
tests/_test_utils/torch_misc.py (1)
  • set_seed (33-40)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (231-236)
  • amax (239-250)
  • axis (277-279)
  • axis (282-284)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
  • fake_e4m3fy (26-48)
  • fake_e4m3fy (26-26)
  • fake_e4m3fy_with_axis (50-53)
  • fake_e4m3fy_with_axis (50-50)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (231-236)
  • amax (239-250)
  • axis (277-279)
  • axis (282-284)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy_cuda (84-101)
  • fake_e4m3fy_cuda (84-84)
  • fake_e4m3fy_cuda_with_axis (66-82)
  • fake_e4m3fy_cuda_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)

135-135: LGTM: stability guard added

Asserting no inf/nan in fp16 path is good coverage for overflow regressions.

tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)

22-23: LGTM: imports and test surface aligned

CPU tests now only depend on FakeTensorQuantTester, consistent with the utility refactor.

tests/gpu/torch/quantization/test_tensor_quant_cuda.py (1)

20-24: Addressed prior ask: references now use pure torch

Switching to a pure torch reference (_fp8_qdq) resolves the earlier concern about reference correctness.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)

244-251: Forcing amax to float32 is correct

Standardizing amax dtype to float32 simplifies CUDA bindings and distributed sync.

modelopt/torch/quantization/tensor_quant.py (1)

564-566: LGTM on dtype handling in INT path.

Casting inputs/amax to float and restoring outputs to input dtype is correct and avoids fp16/bf16 overflow. Safe given quantized range <= 2^7 fits fp16/bf16.

Also applies to: 592-593

modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1)

55-60: Bindings look good; optional amax default is correct.

API surface matches the Python usage and simplifies the call sites. Nice cleanup.

Please re-run unit tests that pass amax from CPU to ensure no hidden device mismatches remain after the fix.

realAsma force-pushed the asma/fp8_fakequant_fix branch from ee03094 to dfca83e on September 26, 2025 21:44
realAsma requested a review from meenchen on September 26, 2025 21:44
coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)

130-135: Fix wrong positional args in test; pass bias explicitly or use keywords.

You're passing 8 as bias and False as num_bits. This breaks the test intent.

Apply:

-        quant_x_test = tensor_quant.fake_tensor_quant(
-            x, torch.tensor(1e-4).to(self.device).half(), 8, False
-        )
+        quant_x_test = tensor_quant.fake_tensor_quant(
+            x,
+            torch.tensor(1e-4, device=self.device).to(x.dtype),
+            bias=None,
+            num_bits=8,
+            unsigned=False,
+        )
🧹 Nitpick comments (3)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)

95-99: Use ceil-div for grid sizing in per-tensor path (consistency).

Minor efficiency and consistency fix.

   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
-    fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+    const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+    fake_e4m3fy_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         outputs.data_ptr<scalar_t>());
   });
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)

104-109: Prefer 1D amax for with-axis CUDA API (avoid relying on broadcast).

Simplifies expectations and mirrors other tests.

-        amax = torch.tensor([1.0, 1.0e-26, 1.0]).cuda().unsqueeze(-1).unsqueeze(1)
-        quant_x = get_cuda_ext().fake_tensor_quant_with_axis(x, amax, axis=1)
+        amax = torch.tensor([1.0, 1.0e-26, 1.0], device=x.device)
+        quant_x = get_cuda_ext().fake_tensor_quant_with_axis(x, amax, axis=1)

157-164: Good per-channel FP8 coverage. Consider adding a non-contiguous variant.

Transpose along a non-axis dim to exercise the dispatcher’s contiguity path.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ee03094 and dfca83e.

📒 Files selected for processing (7)
  • modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
  • tests/_test_utils/torch_quantization/tensor_quant_common.py (1 hunks)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
  • tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/src/tensor_quant.h
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
  • fake_e4m3fy (26-48)
  • fake_e4m3fy (26-26)
  • fake_e4m3fy_with_axis (50-53)
  • fake_e4m3fy_with_axis (50-50)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (125-135)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (5)
tests/_test_utils/torch_misc.py (1)
  • set_seed (33-40)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (125-135)
modelopt/torch/quantization/extensions.py (2)
  • get_cuda_ext (27-36)
  • get_cuda_ext_mx (55-71)
modelopt/torch/quantization/tensor_quant.py (1)
  • fp8_eager (56-60)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (4)
modelopt/torch/quantization/tensor_quant.py (1)

579-581: Compute in FP32 and restore input dtype — LGTM.

This prevents overflow/rounding artifacts in half/bfloat16.

Also applies to: 607-607

tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)

22-22: Import consolidation LGTM—no stale tester references found
Verified no occurrences of TensorQuantTester or FakeAffineTensorQuantTester in the test suite.

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)

75-79: Launch on current CUDA stream and use ceil-div for grid.

Avoid default stream usage and over-provisioned grids.

-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
-    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
-        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
-        axis_size, outer_size, outputs.data_ptr<scalar_t>());
-  });
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+    fake_e4m3fy_with_axis_cuda_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
+  });

66-71: Normalize axis and validate bounds before size/stride.

Handle negative axes and guard against OOB axis to avoid mis-indexing.

 at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
   auto outputs = torch::empty_like(inputs);
-  size_t numel = inputs.numel();
-  int axis_size = inputs.size(axis);
-  int outer_size = inputs.stride(axis);
+  size_t numel = inputs.numel();
+  if (axis < 0) {
+    axis += inputs.dim();
+  }
+  TORCH_CHECK(axis >= 0 && axis < inputs.dim(), "fake_e4m3fy_cuda_with_axis: axis out of range");
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);

realAsma force-pushed the asma/fp8_fakequant_fix branch from dfca83e to ddbd66d on September 26, 2025 22:03
coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)

443-451: quantize_op schema requires Tensor amax; None will break. Route None to eager path.

scaled_e4m3 now allows amax=None, but torch.library quantize_op is defined with Tensor amax. Passing None can error when the library path is active.

Apply this diff to handle None without touching the op schema:

-        outputs = quantize_op(
-            inputs,
-            amax,
-            num_bits=8,
-            exponent_bits=4,
-            unsigned=False,
-            narrow_range=False,
-        )
+        if amax is None:
+            outputs = scaled_e4m3_impl(inputs=inputs, amax=None)
+        else:
+            outputs = quantize_op(
+                inputs,
+                amax,
+                num_bits=8,
+                exponent_bits=4,
+                unsigned=False,
+                narrow_range=False,
+            )

Alternatively, define a None-overload for quantize_op like you did for dynamic_block_quantize_op.

🧹 Nitpick comments (3)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (2)

18-24: Include CUDAStream header for portability.

c10::cuda::getCurrentCUDAStream() requires c10/cuda/CUDAStream.h on some PyTorch versions. Add the explicit include to avoid build issues.

Apply this diff:

 #include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>

130-136: Use the current stream consistently across kernels.

NF4/INT4 kernels below still launch on the default stream. For consistency and better interop, consider passing the current stream to those launches as well.

If you choose to update, mirror the pattern here:

-    NF4_dequantize_kernel<<<blocks, block_size / 2>>>(...)
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    NF4_dequantize_kernel<<<blocks, block_size / 2, 0, stream>>>(...)
modelopt/torch/quantization/tensor_quant.py (1)

52-55: Guard torch.compile usage to avoid import-time failures.

Decorating with torch.compile at import time can break older Torch builds or unsupported backends. Consider guarding or compiling lazily.

Example:

if hasattr(torch, "compile"):
    _fp8_triton = torch.compile(dynamic=True)(_fp8_eager)
else:
    _fp8_triton = _fp8_eager
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dfca83e and ddbd66d.

📒 Files selected for processing (8)
  • modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu.cu (3 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
  • tests/_test_utils/torch_quantization/tensor_quant_common.py (1 hunks)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
  • tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/src/tensor_quant.h
🧰 Additional context used
🧬 Code graph analysis (5)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (125-135)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
  • fake_e4m3fy (26-48)
  • fake_e4m3fy (26-26)
  • fake_e4m3fy_with_axis (50-53)
  • fake_e4m3fy_with_axis (50-50)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (5)
tests/_test_utils/torch_misc.py (1)
  • set_seed (33-40)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (125-135)
modelopt/torch/quantization/extensions.py (2)
  • get_cuda_ext (27-36)
  • get_cuda_ext_mx (55-71)
modelopt/torch/quantization/tensor_quant.py (1)
  • fp8_eager (56-60)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy_cuda (85-102)
  • fake_e4m3fy_cuda (85-85)
  • fake_e4m3fy_cuda_with_axis (66-83)
  • fake_e4m3fy_cuda_with_axis (66-66)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
  • void (58-72)
  • void (100-119)
  • void (145-165)
  • void (209-231)
  • void (260-277)
  • void (309-338)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (1)

77-83: Good: launch on the current CUDA stream.

Passing the current stream to the kernel improves interop with upstream stream semantics.

Confirm that callers always enter this function under the proper device context (Python code seems to guard with torch.cuda.device). Otherwise, consider guarding here with cudaSetDevice(inputs.device().index()) for safety.

tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)

22-22: LGTM: import surface simplified.

Importing only FakeTensorQuantTester aligns with the updated shared test utilities.

tests/_test_utils/torch_quantization/tensor_quant_common.py (1)

130-135: LGTM: overflow/NaN guard.

The explicit inf/NaN assertion strengthens the fp16 overflow test.

modelopt/torch/quantization/tensor_quant.py (2)

44-51: FP8 QDQ in fp32 is correct.

Doing scale/convert/scale-back in fp32 and restoring the original dtype is the right fix to avoid bf16 scaling errors.


80-85: Axis inference is fragile; fix ambiguous/mismatched cases.

amax.squeeze().ndim == 1 with amax.shape.index(amax.numel()) can mis-infer or throw when multiple dims match or keepdims differ.

Apply robust inference and fall back to eager when ambiguous:

-        elif amax.squeeze().ndim == 1:
-                axis = amax.shape.index(amax.numel())
-                outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+        elif amax.squeeze().ndim == 1:
+            amax_s = amax.squeeze()
+            axis_inferred = None
+            if amax.dim() == inputs.dim():
+                candidates = [
+                    i
+                    for i, (asize, isize) in enumerate(zip(amax.shape, inputs.shape))
+                    if asize == isize and asize > 1
+                ]
+                if len(candidates) == 1:
+                    axis_inferred = candidates[0]
+            else:
+                candidates = [i for i, isize in enumerate(inputs.shape) if isize == amax_s.numel()]
+                if len(candidates) == 1:
+                    axis_inferred = candidates[0]
+            if axis_inferred is not None:
+                outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax_s.reshape(-1), axis_inferred)
+            else:
+                outputs = fp8_eager(inputs, amax)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)

113-116: LGTM: runtime FP8 reference using fp8_eager.

This validates scaled_e4m3 against the eager reference and covers both CUDA and CPU.


157-164: Per-channel FP8 coverage looks good.

Axis-parameterized test aligns with the new per-channel path and eager fallback behavior.

modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)

26-33: LGTM! Proper optional amax handling added.

The function signature has been updated to accept an optional amax parameter and correctly handles both cases (with and without amax). The scaling factor of 448 is appropriate for FP8 E4M3 format which can store values up to +/-448.
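
The 448 bound can be confirmed directly against PyTorch's float8 dtype (illustrative check, not part of the PR):

import torch

# The largest finite magnitude representable in float8_e4m3fn is 448,
# which is why the scale is computed as 448 / amax.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e4m3fn).min == -448.0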


34-48: LGTM! CPU path correctly implements FP8 quantization.

The CPU implementation properly calculates scale and inv_scale, and the quantization logic correctly converts to FP8 E4M3 and back to maintain the same precision as the CUDA kernel.


50-53: Device mismatch not addressed despite past review.

Previous review flagged that amax may reside on CPU while inputs are on CUDA, causing device mismatch in the CUDA kernel. The author mentioned this is "fine" but the issue remains unresolved.


55-60: LGTM! Module bindings correctly updated for new API.

The Python bindings properly expose both the regular fake_e4m3fy with optional amax and the new axis-aware variant fake_e4m3fy_with_axis.

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)

35-45: LGTM! Kernel signature updated for fused QDQ scaling.

The kernel signature properly accepts scale and inv_scale pointers, enabling fp32 scaling computation as intended by the PR objectives. The loop correctly processes 4 elements per thread for vectorization.


47-64: LGTM! Axis-aware kernel correctly implements per-channel quantization.

The kernel properly computes axis_id using stride-based indexing and applies per-axis scale/inv_scale values. This enables the FP8 per-channel quantization along axis-0 as described in the PR objectives.


66-83: Grid sizing and contiguity concerns from previous review unaddressed.

Previous reviews flagged:

  1. Over-provisioned grid size using + 1 instead of proper ceil-div
  2. Missing contiguity checks for stride-based axis indexing
  3. Missing axis normalization for negative values

The author indicated contiguity is checked in the C++ dispatcher, but grid sizing remains suboptimal.


85-102: LGTM! Per-tensor kernel properly handles optional amax.

The function correctly computes scale from amax when provided, defaults to unity scale otherwise, and passes the scale/inv_scale pointers to the kernel. The implementation aligns with the PR objective of fusing QDQ scaling into the kernel.

coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (3)
modelopt/torch/quantization/tensor_quant.py (1)

45-55: Guard against non-positive amax and clarify broadcasting expectations.

Division by zero/negatives will produce NaNs or invalid results; also per-channel eager path relies on caller-provided broadcastable shape.

Apply:

 def _fp8_eager(x, amax=None):
     dtype = x.dtype
     if amax is not None:
+        if torch.any(amax <= 0):
+            raise ValueError("amax must be > 0 for FP8 scaling.")
         scale = 448.0 / (amax.to(torch.float32))
         scale_inv = 1 / scale
         x = x.to(torch.float32) * scale

Please verify that CPU callers pass amax with a broadcastable shape for the intended axis (e.g., keepdim with 1s in other dims). Otherwise, broadcasting may target the last dimension by default.

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)

87-95: Per-tensor path: enforce scalar amax.

Prevent accidental multi-element amax on the scalar kernel.

Apply:

 at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
   inputs = inputs.contiguous();
   size_t numel = inputs.numel();
-  at::Tensor scale = 448.f / amax;
+  TORCH_CHECK(amax.numel() == 1,
+              "fake_e4m3fy expects scalar amax; use fake_e4m3fy_with_axis for per-channel.");
+  at::Tensor scale = 448.f / amax;
   auto inv_scale = 1.f / scale;

77-82: Optional: ceil-div grid sizing.

Slightly cleaner and consistent with other kernels.

Apply:

-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
-    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
+    const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+    fake_e4m3fy_with_axis_cuda_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ddbd66d and 97dc2ef.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/extensions.py (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-100)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
  • void (58-72)
  • void (100-119)
  • void (145-165)
  • void (209-231)
  • void (260-277)
  • void (309-338)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/extensions.py (1)

44-45: FP8 CUDA-only source switch looks good.

Dropping the CPU source here is consistent with the new CUDA-only FP8 path. No issues.

modelopt/torch/quantization/tensor_quant.py (2)

383-386: Do the dequant division in fp32 to preserve fp32 QDQ intent.

Prevents bf16/half rounding in the legacy path.

Apply:

-            return outputs / scale.to(inputs.dtype)
+            return (outputs.to(torch.float32) / scale.to(torch.float32)).to(inputs.dtype)

93-99: Axis inference can be wrong; at least validate or fall back.

Current axis = amax.shape.index(amax.numel()) resolves to 0 for 1D amax and can silently mismatch. Prefer robust inference; otherwise validate and fall back.

Apply:

-        elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+        elif amax.squeeze().ndim == 1:
+            axis_len = amax.numel()
+            # Try to infer a unique matching axis from inputs' shape.
+            candidates = [i for i, s in enumerate(inputs.shape) if s == axis_len]
+            if len(candidates) == 1:
+                axis = candidates[0]
+                outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.reshape(-1), axis)
+            else:
+                # Ambiguous or no match; use eager path to respect broadcasting semantics.
+                outputs = fp8_eager(inputs, amax)

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/quantization/tensor_quant.py (1)

57-60: Optionally guard torch.compile for wider PyTorch compatibility.

Unconditional decorator can break on older torch. Gate it and fallback to eager.

Apply this diff:

-@torch.compile(dynamic=True)
-def _fp8_triton(x, amax):
-    return _fp8_eager(x, amax)
+try:
+    _fp8_triton = torch.compile(dynamic=True)(_fp8_eager)  # type: ignore[attr-defined]
+except AttributeError:
+    _fp8_triton = _fp8_eager
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 97dc2ef and c4c32ba.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/tensor_quant.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-100)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/quantization/tensor_quant.py (6)

380-384: Do the legacy dequant division in fp32 to preserve fp32 QDQ intent.

Prevents bf16/fp16 rounding when dequantizing on CPU/fallback paths.

Apply this diff:

-            outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-            return outputs / scale.to(inputs.dtype)
+            outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
+            return (outputs.to(torch.float32) / scale.to(torch.float32)).to(inputs.dtype)

Based on learnings


590-592: LGTM: compute in fp32 to avoid overflow/rounding.

Casting inputs and amax to float improves numerical stability and aligns with the PR goal.


83-85: Fix: replace inputs.is_cpu with a supported check.

torch.Tensor has is_cuda but not is_cpu. This will raise AttributeError on GPU tensors. Use not inputs.is_cuda instead.

Apply this diff:

-    if inputs.is_cpu or amax is None or amax.squeeze().ndim > 1:
+    if not inputs.is_cuda or amax is None or amax.squeeze().ndim > 1:
         return fp8_eager(inputs, amax)

Based on learnings


86-89: Ensure amax is on the same CUDA device and in fp32 before calling the CUDA kernel.

If amax is on CPU or not fp32, the extension can error or incur unintended casts. Move it to inputs.device and cast to fp32.

Apply this diff:

     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
     if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)
 
+    # Align amax device/dtype for CUDA path
+    if amax.device != inputs.device or amax.dtype != torch.float32:
+        amax = amax.to(device=inputs.device, dtype=torch.float32, non_blocking=True)

Based on learnings


95-98: Robust axis inference; fall back to eager when ambiguous.

axis = amax.shape.index(amax.numel()) is fragile and can infer a wrong axis or fail when multiple dims match. Infer axis by comparing shapes; if ambiguous, use the eager path.

Apply this diff:

-        elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+        elif amax.squeeze().ndim == 1:
+            amax_s = amax.squeeze()
+            axis = None
+            if amax.dim() == inputs.dim():
+                candidates = [
+                    i for i, (asize, isize) in enumerate(zip(amax.shape, inputs.shape))
+                    if asize == isize and asize > 1
+                ]
+                if len(candidates) == 1:
+                    axis = candidates[0]
+            else:
+                candidates = [i for i, isize in enumerate(inputs.shape) if isize == amax_s.numel()]
+                if len(candidates) == 1:
+                    axis = candidates[0]
+            if axis is None:
+                return fp8_eager(inputs, amax)
+            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax_s, axis)

83-98: No additional is_cpu checks or FP8 callsites with ambiguous 1D amax found.

mxinO left a comment

LGTM! Thanks. Have you benchmarked Triton vs. CUDA on Hopper+? We can use Triton by default if it's faster.

realAsma force-pushed the asma/fp8_fakequant_fix branch 2 times, most recently from d920b7f to ce384be on September 27, 2025 13:35
realAsma force-pushed the asma/fp8_fakequant_fix branch from a8919e9 to 1ae9302 on September 29, 2025 19:13
coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)

66-85: Consider validating amax length matches axis size.

The function correctly makes inputs and amax contiguous and uses the current CUDA stream. However, consider adding a validation to ensure amax.numel() == axis_size to catch shape mismatches early and provide a clear error message.

 at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
   inputs = inputs.contiguous();
   amax = amax.contiguous().to(at::kFloat);
   auto outputs = torch::empty_like(inputs);
   size_t numel = inputs.numel();
   int axis_size = inputs.size(axis);
   int outer_size = inputs.stride(axis);
 
+  TORCH_CHECK(
+      amax.numel() == axis_size,
+      "fake_e4m3fy_with_axis: amax.numel() must equal axis length. Expected ",
+      axis_size, ", got ", amax.numel()
+  );
+
   auto scale = 448.f / amax;
   auto inv_scale = 1.f / scale;

Based on past review comments.


87-101: Consider validating amax is scalar for per-tensor path.

The function correctly uses the current stream and makes inputs contiguous. However, the per-tensor kernel reads scale[0] only. Consider adding a validation to ensure amax.numel() == 1 to catch misuse early.

 at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
   inputs = inputs.contiguous();
   amax = amax.view(-1).to(at::kFloat);
+  TORCH_CHECK(
+      amax.numel() == 1,
+      "fake_e4m3fy expects scalar amax; use fake_e4m3fy_with_axis for per-channel."
+  );
   size_t numel = inputs.numel();
   at::Tensor scale = 448.f / amax;

Based on past review comments.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ce384be and 1ae9302.

📒 Files selected for processing (9)
  • modelopt/torch/quantization/extensions.py (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu.cu (4 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (5 hunks)
  • tests/_test_utils/torch_quantization/tensor_quant_common.py (0 hunks)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
  • tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (3)
  • modelopt/torch/quantization/src/tensor_quant.h
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp
  • tests/_test_utils/torch_quantization/tensor_quant_common.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/extensions.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T20:14:34.768Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.768Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.

Applied to files:

  • modelopt/torch/quantization/tensor_quant.py
🧬 Code graph analysis (4)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (126-136)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
  • void (59-73)
  • void (101-120)
  • void (147-167)
  • void (211-233)
  • void (262-279)
  • void (311-340)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (4)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (126-136)
modelopt/torch/quantization/extensions.py (2)
  • get_cuda_ext (27-36)
  • get_cuda_ext_mx (55-71)
modelopt/torch/quantization/tensor_quant.py (1)
  • fp8_eager (61-66)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-101)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (20)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (4)

19-19: LGTM: Stream header inclusion is appropriate.

The inclusion of <c10/cuda/CUDAStream.h> is necessary for the stream-aware kernel launches below.


78-82: LGTM: Stream-aware kernel launch prevents synchronization issues.

The kernel now correctly executes on the current CUDA stream rather than the default stream, which is important for correctness in multi-stream scenarios.


90-94: LGTM: Consistent stream-aware pattern.

The fake_tensor_quant_cuda function correctly uses the current stream, matching the pattern in the inplace variant.


131-137: LGTM: Per-axis variant now stream-aware.

The with-axis kernel launch now correctly uses the current CUDA stream, completing the stream-awareness migration across all quantization kernels.

tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)

22-22: LGTM: Import simplified to reflect updated test surface.

The import correctly narrows to FakeTensorQuantTester only, aligning with the removal of TensorQuantTester and FakeAffineTensorQuantTester from test utilities as part of the broader test cleanup.

tests/gpu/torch/quantization/test_tensor_quant_cuda.py (6)

21-21: LGTM: Import simplified to reflect updated test utilities.

The import correctly narrows to FakeTensorQuantTester, consistent with the test utilities cleanup across the PR.


26-26: LGTM: Extension imports updated to reflect new FP8 path.

The removal of get_cuda_ext_fp8 from imports aligns with the PR's refactoring where FP8 functionality is now accessed through different pathways (e.g., via scaled_e4m3 which internally calls get_cuda_ext_fp8).


113-127: LGTM: FP8 tests now use pure PyTorch reference implementation.

The tests correctly use fp8_eager as the reference, addressing the past review feedback to use pure torch implementation. The parameterization across devices and dtypes improves coverage, and the CPU skip for non-float32 is appropriate.


129-134: LGTM: Incontiguous test updated with pure torch reference.

Consistent with the other FP8 test updates, ensuring non-contiguous tensors are validated against the pure PyTorch implementation.


148-156: LGTM: Multi-GPU test validates cross-device correctness.

The test correctly validates that FP8 quantization works when the tensor is on a non-current GPU device, using the pure torch reference for comparison.


157-163: LGTM: Per-channel FP8 test validates new axis-aware functionality.

This new test correctly validates the per-channel FP8 quantization feature introduced in this PR, testing across multiple axes and using the pure PyTorch reference for validation.
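
A hedged sketch of what such a per-channel check can look like in pure PyTorch; quant_fn is a placeholder for the CUDA path under test, and the amax reduction below uses plain torch rather than the project's reduce_amax helper.

import torch

def check_per_channel_fp8(quant_fn, axis=0, e4m3_max=448.0):
    """Sketch: compare a per-channel FP8 path against a broadcast fp32 reference."""
    x = torch.randn(4, 8, 16, dtype=torch.bfloat16, device="cuda")
    reduce_dims = [d for d in range(x.ndim) if d != axis]
    amax = x.abs().amax(dim=reduce_dims, keepdim=True).float()  # e.g. shape (8, 1, 1)

    # Broadcast reference: quantize-dequantize each channel with fp32 scaling.
    scale = e4m3_max / amax
    ref = ((x.float() * scale).to(torch.float8_e4m3fn).float() / scale).to(x.dtype)

    out = quant_fn(x, amax)  # placeholder for the kernel-backed path
    assert torch.allclose(out.float(), ref.float(), atol=1e-2)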

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)

21-21: Note: <optional> included but not visibly used in this file.

The <optional> header is included but does not appear to be used in the code shown. It may be intended for future use, or it may be used in code outside the scope of this review.


35-45: LGTM: Per-tensor FP8 kernel correctly fuses scale/inv_scale in fp32.

The kernel correctly implements the PR objective by applying scale and inv_scale in fp32 before/after the FP8 conversion, preventing scaling from being applied in bf16 when inputs are bf16. The 4-element per-thread loop improves memory coalescing.
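
For illustration, the pure-PyTorch sketch below contrasts scaling in the input dtype (the old behavior) with scaling in fp32 (what the fused kernel does); the function names are made up for this example, 448.0 is the E4M3 maximum used elsewhere in this PR, and a PyTorch build with torch.float8_e4m3fn support is assumed.

import torch

E4M3_MAX = 448.0

def qdq_scaled_in_input_dtype(x, amax):
    # Old behavior: scale applied outside the kernel in the input dtype (e.g. bf16).
    scale = (E4M3_MAX / amax).to(x.dtype)
    xq = (x * scale).to(torch.float8_e4m3fn)
    return xq.to(x.dtype) / scale

def qdq_scaled_in_fp32(x, amax):
    # New behavior: scale and inv_scale applied in fp32 around the FP8 cast.
    scale = E4M3_MAX / amax.float()
    xq = (x.float() * scale).to(torch.float8_e4m3fn)
    return (xq.float() / scale).to(x.dtype)

if __name__ == "__main__":
    x = torch.randn(4096, dtype=torch.bfloat16)
    amax = x.abs().amax()
    a = qdq_scaled_in_input_dtype(x, amax)
    b = qdq_scaled_in_fp32(x, amax)
    # The two paths can diverge slightly; the fp32 path matches real quantization.
    print((a - b).abs().max())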


47-64: LGTM: Axis-aware FP8 kernel correctly implements per-channel quantization.

The kernel correctly computes axis_id from flattened indices and applies per-axis scale/inv_scale, enabling per-channel FP8 quantization. The indexing logic is consistent with the integer per-axis kernel in tensor_quant_gpu.cu (lines 115-116).
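
For readers unfamiliar with the flat-index trick, a small PyTorch equivalent of the indexing is sketched below; it assumes a contiguous row-major tensor and is illustrative rather than the kernel's actual code.

import torch

def per_axis_qdq_reference(x, amax_1d, axis, e4m3_max=448.0):
    """Sketch: per-channel FP8 QDQ using the same axis_id mapping as a flat kernel.

    For a contiguous tensor, outer_size is the product of the dimensions after
    `axis`, so axis_id = (flat_idx // outer_size) % x.shape[axis].
    """
    outer_size = 1
    for d in x.shape[axis + 1:]:
        outer_size *= d
    flat = x.reshape(-1).float()
    idx = torch.arange(flat.numel(), device=x.device)
    axis_id = (idx // outer_size) % x.shape[axis]
    scale = e4m3_max / amax_1d.float()                 # one fp32 scale per channel
    q = (flat * scale[axis_id]).to(torch.float8_e4m3fn)
    out = (q.float() / scale[axis_id]).reshape(x.shape)
    return out.to(x.dtype)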


103-108: LGTM: PyBind11 module correctly exposes FP8 functions.

The module definition correctly binds both per-tensor and per-axis FP8 quantization functions with clear argument names, making them accessible from Python.

modelopt/torch/quantization/tensor_quant.py (5)

45-66: LGTM: FP8 eager implementation provides pure PyTorch reference.

The implementation correctly provides a pure PyTorch reference for FP8 quantization, addressing past review feedback; a short sketch of the pattern follows the list below. Key highlights:

  • Scaling applied in fp32 (lines 48-50) prevents bf16 rounding issues
  • amax=None case enables dynamic quantization without scaling
  • Triton optimization via torch.compile improves performance when available
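
A minimal sketch of such an eager-mode reference, assuming amax (when given) broadcasts against the input; this is illustrative and not the exact function merged in the PR.

import torch

def fp8_eager_sketch(x, amax=None):
    """Hedged reference: fake-quantize to E4M3 and back, doing the math in fp32."""
    original_dtype = x.dtype
    if amax is not None:
        scale = 448.0 / amax.float()       # fp32 scale avoids bf16 rounding
        x = x.float() * scale
    xq = x.to(torch.float8_e4m3fn)         # round to E4M3
    x = xq.float()
    if amax is not None:
        x = x / scale                      # fp32 dequant
    return x.to(original_dtype)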

82-97: LGTM: Robust fallback logic for FP8 routing.

The routing correctly falls back to eager mode for non-CUDA tensors, missing amax, or unsupported amax shapes (line 82), addressing past review feedback. The CUDA path appropriately routes between per-tensor and per-axis kernels based on amax shape.

Note: Axis inference via amax.shape.index(amax.numel()) (line 95) assumes unique matching dimensions. Past review comments noted this could be fragile, but the maintainer confirmed the current approach is acceptable.
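
A hedged sketch of the routing described above; fake_e4m3fy and fake_e4m3fy_with_axis follow the extension bindings listed in this review, while cuda_ext_fp8 and fp8_eager_ref are stand-ins for whatever the module actually passes in.

def scaled_e4m3_dispatch(inputs, amax, cuda_ext_fp8, fp8_eager_ref):
    """Sketch: route between the eager fallback and the CUDA kernels."""
    # Fall back to pure PyTorch for CPU tensors, missing amax, or amax shapes
    # the kernels do not handle (e.g. multi-axis amax).
    if not inputs.is_cuda or amax is None or amax.squeeze().ndim > 1:
        return fp8_eager_ref(inputs, amax)
    if amax.numel() == 1:
        # Per-tensor path.
        return cuda_ext_fp8.fake_e4m3fy(inputs, amax)
    # Per-channel path: amax has a keepdims shape like (1, C, 1, 1), so exactly
    # one dimension equals amax.numel(); its position is the quantization axis.
    axis = amax.shape.index(amax.numel())  # e.g. (1, 8, 1).index(8) == 1
    return cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)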


379-382: LGTM: Legacy quant function simplified.

The function correctly uses the updated _tensor_quant which now returns only outputs (not a tuple), simplifying the call site.


589-590: LGTM: Explicit float casting ensures fp32 computation.

The explicit casting of inputs and amax to float (lines 589-590) ensures all quantization math happens in fp32, preventing precision issues when inputs are bf16 or fp16.


616-619: LGTM: Dequantization preserves fp32 precision before dtype conversion.

The division by scale is performed in float (line 616), and the result is cast back to the original dtype (line 618), addressing past review feedback about preserving fp32 precision in the QDQ path.
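
As a hypothetical illustration of that cast pattern (not the actual _tensor_quant body), an integer fake-quant step that keeps all math in fp32 might look like this; the exact clamping and rounding conventions may differ from the library's.

import torch

def int_fake_quant_fp32(x, amax, num_bits=8):
    """Sketch: integer fake quantization with all scaling done in fp32."""
    original_dtype = x.dtype
    x32 = x.float()
    amax32 = amax.float()
    bound = 2.0 ** (num_bits - 1) - 1.0     # 127 for 8 bits
    scale = bound / amax32
    q = torch.clamp((x32 * scale).round(), -bound, bound)
    out = q / scale                          # dequant stays in fp32
    return out.to(original_dtype)            # cast back only at the end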

@realAsma
Contributor Author

realAsma commented Sep 30, 2025

@mxinO I am running into issues with torch.compile of FP8 eager mode. Let us implement torch.compile + FP8 eager mode in a separate PR. Let me merge the current PR to fix the FP8 QAT bug on the main branch.

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

Added FP8 eager mode with triton for multi-axis FP8 fake quant

minor

Signed-off-by: realAsma <[email protected]>

Use fp8_eager for cpu or if amax is None

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

Removed FP8 eager with torch.compile

Signed-off-by: realAsma <[email protected]>
@realAsma realAsma force-pushed the asma/fp8_fakequant_fix branch from dc12e82 to 56f5a08 on September 30, 2025 00:21

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/quantization/tensor_quant.py (1)

56-58: Consider allowing amax=None in signature for consistency.

The wrapped _fp8_eager function accepts amax=None for dynamic quantization without scaling, but this wrapper requires amax. Consider adding a default parameter amax=None to match the underlying implementation and support the full range of use cases.

Apply this diff:

-def fp8_eager(x, amax):
+def fp8_eager(x, amax=None):
     """Eager mode implementation of FP8 quantization."""
     return _fp8_eager(x, amax)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ae9302 and dc12e82.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/tensor_quant.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-101)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/quantization/tensor_quant.py (4)

371-374: LGTM! Simplified to match _tensor_quant's new return signature.

The removal of separate scale handling aligns with _tensor_quant now returning fully scaled outputs directly.


573-611: LGTM! FP32 precision preserved throughout quantization.

The refactored _tensor_quant correctly performs all scaling operations in fp32 before casting back to the input dtype, addressing the precision concerns from the PR objectives. The zero-amax handling logic at lines 599-607 is also correct.


45-54: No amax device alignment issues found. All existing callers supply amax as a tensor on the same device as x, so no additional checks are required.


84-88: Fix: axis inference crashes when amax has size-1 dimensions.

When amax has a shape like [1, 128, 1, 1] for per-channel quantization, line 87's amax.shape.index(amax.numel()) will raise ValueError because 128 is not present in the shape tuple (1, 128, 1, 1).

Apply this diff to fix the axis inference:

         if amax.numel() == 1:
             outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
         elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
+            # Find the axis where amax size matches (non-singleton dimension)
+            axis = next(i for i, s in enumerate(amax.shape) if s == amax.numel())
             outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)

Alternative robust approach:

         if amax.numel() == 1:
             outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
         elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
+            # Find axis by matching non-singleton dimensions
+            candidates = [i for i, s in enumerate(amax.shape) if s > 1]
+            if len(candidates) != 1:
+                return fp8_eager(inputs, amax)  # Ambiguous, use eager path
+            axis = candidates[0]
             outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)

Likely an incorrect or invalid review comment.

@realAsma realAsma enabled auto-merge (squash) September 30, 2025 00:22

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)

21-21: Remove unused include.

The <optional> header is included but never used in this file. No std::optional types are present in the code.

Apply this diff:

-#include <optional>
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc12e82 and 56f5a08.

📒 Files selected for processing (9)
  • modelopt/torch/quantization/extensions.py (1 hunks)
  • modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp (0 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu.cu (4 hunks)
  • modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
  • modelopt/torch/quantization/tensor_quant.py (5 hunks)
  • tests/_test_utils/torch_quantization/tensor_quant_common.py (0 hunks)
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
  • tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (3)
  • modelopt/torch/quantization/src/tensor_quant.h
  • tests/_test_utils/torch_quantization/tensor_quant_common.py
  • modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/quantization/extensions.py
  • modelopt/torch/quantization/src/tensor_quant_gpu.cu
🧰 Additional context used
🧬 Code graph analysis (4)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (126-136)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
  • void (59-73)
  • void (101-120)
  • void (147-167)
  • void (211-233)
  • void (262-279)
  • void (311-340)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (4)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
  • FakeTensorQuantTester (126-136)
modelopt/torch/quantization/extensions.py (2)
  • get_cuda_ext (27-36)
  • get_cuda_ext_mx (55-71)
modelopt/torch/quantization/tensor_quant.py (1)
  • fp8_eager (56-58)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • amax (236-241)
  • amax (244-255)
  • axis (283-285)
  • axis (288-290)
modelopt/torch/quantization/extensions.py (1)
  • get_cuda_ext_fp8 (39-52)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
  • fake_e4m3fy (87-101)
  • fake_e4m3fy (87-87)
  • fake_e4m3fy_with_axis (66-85)
  • fake_e4m3fy_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)

47-64: LGTM! Axis-aware FP8 quantization kernel is correctly implemented.

The kernel correctly computes per-axis scale indices and applies FP32 scaling, which addresses the bf16 scaling issue mentioned in the PR objectives. The logic aligns with the integer quantization pattern from tensor_quant_gpu.cu.


66-85: LGTM! Host function correctly prepares tensors for axis-aware quantization.

The function ensures contiguity, casts amax to FP32, and uses the current CUDA stream. Based on past review comments, input validation (axis normalization, device matching, shape validation) is handled in the C++ dispatcher fake_e4m3fy_with_axis, which is appropriate for separation of concerns.

Based on past review comments.


87-101: LGTM! Per-tensor FP8 quantization host function is correct.

The function properly handles per-tensor quantization with FP32 scaling. The view(-1) call flattens amax, and shape validation is handled in the Python dispatcher scaled_e4m3_impl (as confirmed by past review comments), which is appropriate.

Based on past review comments.


103-108: LGTM! PyBind11 bindings correctly expose the new FP8 API.

The module exports both per-tensor and per-axis FP8 quantization functions with appropriate signatures and named arguments, aligning with the PR's goal to provide axis-aware FP8 quantization.

tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)

22-22: LGTM! Import simplified to match updated test utilities.

The import change correctly reflects the removal of TensorQuantTester and FakeAffineTensorQuantTester base classes from the test utilities, aligning with the PR's simplification of the quantization test surface.

tests/gpu/torch/quantization/test_tensor_quant_cuda.py (6)

21-26: LGTM! Import surface updated to reflect removed FP8 CPU extension.

The removal of get_cuda_ext_fp8 import aligns with the PR's removal of CPU-side FP8 implementations and the consolidation of FP8 paths to CUDA-only or eager mode.


112-116: LGTM! Test correctly uses eager mode as reference.

The test now uses fp8_eager as the ground truth reference, which aligns with the PR's goal to validate that the CUDA kernel produces the same results as the eager implementation.


118-127: LGTM! Enhanced test coverage with dtype parameterization.

The test now validates FP8 quantization across multiple dtypes (bfloat16, float16, float32) with appropriate CPU skip logic, and uses fp8_eager as the reference. This ensures the FP32 scaling fix works correctly across different input precisions.


129-134: LGTM! Non-contiguous tensor handling verified.

The test correctly validates that non-contiguous tensors are handled properly by the quantization path, with fp8_eager as the reference.


148-156: LGTM! Multi-GPU handling correctly tested.

The test validates that FP8 quantization works correctly when tensors are on a non-current GPU device, with proper device synchronization.


157-163: LGTM! New per-channel quantization test validates axis-aware functionality.

This new test validates the core feature of the PR: per-channel FP8 quantization with axis support. The test correctly computes per-axis amax and validates against the eager reference across different axes (0, 1, 2).

modelopt/torch/quantization/tensor_quant.py (4)

45-59: LGTM! FP8 eager implementation correctly applies FP32 scaling.

The eager mode implementation correctly:

  • Preserves the original dtype throughout the operation
  • Applies scaling in FP32 (448.0 / amax), addressing the bf16 scaling issue mentioned in PR objectives
  • Handles the amax=None case for unscaled quantization (useful for dynamic quantization)
  • Uses native PyTorch FP8 dtype (torch.float8_e4m3fn)

573-611: LGTM! Quantization math now performed in FP32.

The function correctly:

  • Casts inputs and amax to float32 before computation (lines 581-582)
  • Performs all quantization math in FP32, including the dequantization division (line 608)
  • Casts back to the original input dtype at the end (line 610)

This addresses the bf16 scaling issue mentioned in the PR objectives and past review comments.

Based on past review comments.


371-374: LGTM! Legacy quantization function updated to match new _tensor_quant signature.

The function correctly calls the updated _tensor_quant which now returns only outputs (not a tuple) and performs all computation in FP32.


74-89: Ignore axis inference ambiguity concern
The amax.squeeze().ndim > 1 guard ensures axis routing only runs on 1-D keepdims outputs, where exactly one dimension matches amax.numel(), so shape.index(...) is unambiguous.

Likely an incorrect or invalid review comment.

Signed-off-by: realAsma <[email protected]>
@realAsma realAsma merged commit 83d7623 into main Sep 30, 2025
27 checks passed
@realAsma realAsma deleted the asma/fp8_fakequant_fix branch September 30, 2025 01:56
kevalmorabia97 pushed a commit that referenced this pull request Oct 2, 2025
…or FP8 per-channel quantization (#381)

Signed-off-by: realAsma <[email protected]>