Fixed FP8 fake quantization to use fp32 amax scaling; Added support for FP8 per-channel quantization #381
Conversation
Walkthrough
Removed CPU FP8 sources and header declaration; introduced axis-aware FP8 CUDA kernels and host APIs; Python FP8 routing gained eager and axis-aware paths; CUDA kernel launches now use the current stream; tests and test utilities were simplified.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant PY as Python (scaled_e4m3_impl)
participant EG as fp8_eager (CPU/Python)
participant EXT as CUDA extension (pybind11)
participant CU as CUDA host/kernels
PY->>PY: receive inputs, optional amax
alt CPU / fallback / no CUDA ext
PY->>EG: fp8_eager(inputs, amax)
EG-->>PY: quantized outputs
else CUDA & scalar amax
PY->>EXT: fake_e4m3fy(inputs, amax)
EXT->>CU: launch scalar CUDA host (scale/inv_scale)
CU->>CU: elementwise kernel applies scalar scale
CU-->>PY: quantized outputs
else CUDA & 1‑D amax (per-axis)
PY->>EXT: fake_e4m3fy_with_axis(inputs, amax, axis)
EXT->>CU: launch axis-aware host (per-axis scale/inv_scale)
CU->>CU: kernel computes axis_id, applies per‑axis scale
CU-->>PY: quantized outputs
end
PY-->>PY: finalize/cast outputs
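For orientation, the routing the diagram describes can be condensed into a short sketch. This is illustrative only: fp8_fake_quant_sketch, its inline eager helper, and the duck-typed cuda_ext_fp8 handle are stand-ins for the PR's scaled_e4m3_impl, fp8_eager, and pybind11 extension.

import torch

def fp8_fake_quant_sketch(inputs: torch.Tensor, amax: torch.Tensor | None = None, cuda_ext_fp8=None):
    """Sketch of the dispatch order: eager fallback, per-tensor kernel, per-axis kernel."""
    def eager(x, a):
        # pure-PyTorch QDQ: scale in fp32, round-trip through float8_e4m3fn, restore dtype
        dtype = x.dtype
        x = x.to(torch.float32)
        scale = None
        if a is not None:
            scale = 448.0 / a.to(torch.float32)
            x = x * scale
        x = x.to(torch.float8_e4m3fn).to(torch.float32)
        if scale is not None:
            x = x / scale
        return x.to(dtype)

    if not inputs.is_cuda or cuda_ext_fp8 is None or amax is None or amax.squeeze().ndim > 1:
        return eager(inputs, amax)                     # CPU / no extension / unsupported amax shape
    if amax.numel() == 1:
        return cuda_ext_fp8.fake_e4m3fy(inputs, amax)  # scalar (per-tensor) amax
    axis = amax.shape.index(amax.numel())              # 1-D amax: locate the quantization axis
    return cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)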
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI | Review profile: CHILL | Plan: Pro
📒 Files selected for processing (1)
🧰 Additional context used
🧬 Code graph analysis (1): modelopt/torch/quantization/tensor_quant.py (3)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
🔇 Additional comments (6)
87e9050 to 34bcee8
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #381 +/- ##
==========================================
- Coverage 73.86% 73.81% -0.05%
==========================================
Files 171 171
Lines 17629 17574 -55
==========================================
- Hits 13021 12972 -49
+ Misses 4608 4602 -6
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
45-72: Fix device-guarding and robust axis inference; ensure amax is on the correct device; update the error text.
- Calling torch.cuda.current_device() unconditionally will crash on CPU-only builds. Guard the CUDA context manager with inputs.is_cuda.
- In the per-axis path, axis = amax.shape.index(amax.numel()) is fragile/incorrect; it often resolves to 0 and doesn’t consider inputs’ shape. Infer axis by matching amax.numel() to an inputs dimension (or assert axis 0 if that’s the only supported mode).
- Ensure amax is on the same device as inputs before calling the CUDA kernel.
- Error message still references “Fused” though that path was removed.
Proposed diff:
def scaled_e4m3_impl( inputs: torch.Tensor, amax: torch.Tensor | None = None, ) -> torch.Tensor: @@ - with torch.cuda.device( - None if inputs.device.index == torch.cuda.current_device() else inputs.device.index - ): - if amax is None or amax.numel() == 1: - outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) - else: - if amax.squeeze().ndim > 1: - raise NotImplementedError( - "Fused E4M3 kernel does not support multiaxis quantization." - ) - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) - return outputs + if inputs.is_cuda: + with torch.cuda.device( + None if inputs.device.index == torch.cuda.current_device() else inputs.device.index + ): + if amax is None or amax.numel() == 1: + outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) + else: + if amax.squeeze().ndim > 1: + raise NotImplementedError("E4M3 per-axis quantization supports only 1D amax.") + # Ensure amax device matches inputs + amax = amax.to(device=inputs.device) + # Infer axis by matching amax length to one input dimension (default to 0 if ambiguous) + match_axes = [i for i, d in enumerate(inputs.shape) if d == amax.numel()] + if not match_axes: + raise ValueError("amax length does not match any input dimension.") + axis = match_axes[0] + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + return outputs + else: + # CPU path routes through the C++ wrapper which implements a CPU fallback. + if amax is None or amax.numel() == 1: + return cuda_ext_fp8.fake_e4m3fy(inputs, amax) + if amax.squeeze().ndim > 1: + raise NotImplementedError("E4M3 per-axis quantization supports only 1D amax.") + # Prefer axis 0 on CPU unless inputs dim match is unambiguous + match_axes = [i for i, d in enumerate(inputs.shape) if d == amax.numel()] + axis = match_axes[0] if match_axes else 0 + return cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
🧹 Nitpick comments (6)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)
56-61: Axis index computation assumes contiguous layout
axis_id = (idx / stride(axis)) % size(axis) is only valid for contiguous row-major tensors. With the added contiguity normalization above, this becomes safe; otherwise results can be wrong for channels-last or non-contiguous inputs.
If you don't want to force contiguity here, rework indexing to derive multi-dimensional coordinates from idx and use them to index scale.
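A quick standalone check (an illustration, not project code) of why the formula only holds for contiguous row-major layouts: it matches coordinates unraveled from the shape exactly when the flat storage order follows the strides.

import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)       # contiguous row-major tensor
axis = 1
flat = torch.arange(x.numel())
axis_id = (flat // x.stride(axis)) % x.size(axis)  # the kernel's index formula
ref = torch.unravel_index(flat, x.shape)[axis]     # coordinates derived from the shape alone
assert torch.equal(axis_id, ref)                   # holds here; breaks for non-contiguous layouts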
96-99: Ceil-div grid sizing for the per-tensor kernel
Minor efficiency fix: avoid launching an extra block.
Apply this diff:
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
-    fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+    const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+    fake_e4m3fy_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
        outputs.data_ptr<scalar_t>());
  });
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)
111-122: Harden reference against amax==0 producing NaNs
If amax is zero (e.g., all-zero channel), 0*inf leads to NaNs. Mirror the kernel-side guard by making amax safe in the reference.
Apply this diff:
 def _fp8_qdq(x, amax=None):
     dtype = x.dtype
     x = x.to(torch.float32)
     scale, scale_inv = None, None
     if amax is not None:
-        scale = 448.0 / (amax.to(torch.float32))
+        amax_safe = torch.where(amax == 0, torch.ones_like(amax, dtype=torch.float32), amax.to(torch.float32))
+        scale = 448.0 / amax_safe
         scale_inv = 1 / scale
         x = x * scale
     x = x.to(torch.float8_e4m3fn).to(torch.float32)
     if scale_inv is not None:
         x = x * scale_inv
     return x.to(dtype)
173-180: Per-channel test aligns with new kernel path
Good coverage across multiple axes; consider adding a case where one channel is all zeros to verify the NaN guard end-to-end, as sketched below.
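For example, a per-axis amax with one all-zero channel (to exercise the NaN guard) could be built like this; a hypothetical sketch, not the test's actual code:

import torch

x = torch.randn(4, 16, 8)
x[:, 2, :] = 0.0                                        # force one all-zero channel along axis 1
axis = 1
reduce_dims = tuple(d for d in range(x.ndim) if d != axis)
amax = x.abs().amax(dim=reduce_dims, keepdim=True)      # shape (1, 16, 1); channel 2 has amax == 0
amax_safe = torch.where(amax == 0, torch.ones_like(amax), amax)  # mirrors the kernel-side guard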
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
268-275: Use property consistently in step_size (optional)
Minor consistency: prefer self.amax over self._amax for readability and single source of truth.
Apply this diff:
- return self._amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
+ return self.amax / (2.0 ** (self._num_bits - 1 + int(self._unsigned)) - 1.0)
847-855: Avoid mutating internal amax in export_amax
export_amax currently modifies the buffer in-place when replacing zeros and nan. Clone before in-place ops to keep internal state intact.
Apply this diff:
-        if not hasattr(self, "_amax_shape_for_export"):
-            amax = self.amax
-        else:
-            amax = self.amax.reshape(self._amax_shape_for_export)
+        if not hasattr(self, "_amax_shape_for_export"):
+            amax = self.amax.clone()
+        else:
+            amax = self.amax.reshape(self._amax_shape_for_export).clone()
         amax[amax == 0] = self.maxbound
         amax = torch.nan_to_num(amax, nan=self.maxbound)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4 hunks)
modelopt/torch/quantization/src/tensor_quant.h (0 hunks)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1 hunks)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2 hunks)
modelopt/torch/quantization/tensor_quant.py (4 hunks)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1 hunks)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (3 hunks)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1 hunks)
💤 Files with no reviewable changes (1)
- modelopt/torch/quantization/src/tensor_quant.h
🧰 Additional context used
🧬 Code graph analysis (6)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): amax (231-236), amax (239-250), axis (277-279), axis (282-284)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
- tests/_test_utils/torch_quantization/tensor_quant_common.py (1): FakeTensorQuantTester (125-135)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)
- tests/_test_utils/torch_misc.py (1): set_seed (33-40)
- modelopt/torch/quantization/utils.py (1): reduce_amax (148-180)
modelopt/torch/quantization/tensor_quant.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): amax (231-236), amax (239-250), axis (277-279), axis (282-284)
- modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4): fake_e4m3fy (26-48), fake_e4m3fy (26-26), fake_e4m3fy_with_axis (50-53), fake_e4m3fy_with_axis (50-50)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
- modelopt/torch/quantization/qtensor/base_qtensor.py (1): to (115-123)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): amax (231-236), amax (239-250), axis (277-279), axis (282-284)
- modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4): fake_e4m3fy_cuda (84-101), fake_e4m3fy_cuda (84-84), fake_e4m3fy_cuda_with_axis (66-82), fake_e4m3fy_cuda_with_axis (66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
135-135: LGTM: stability guard added
Asserting no inf/nan in fp16 path is good coverage for overflow regressions.
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
22-23: LGTM: imports and test surface aligned
CPU tests now only depend on FakeTensorQuantTester, consistent with the utility refactor.
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (1)
20-24: Addressed prior ask: references now use pure torch
Switching to a pure torch reference (_fp8_qdq) resolves the earlier concern about reference correctness.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
244-251: Forcing amax to float32 is correct
Standardizing amax dtype to float32 simplifies CUDA bindings and distributed sync.
modelopt/torch/quantization/tensor_quant.py (1)
564-566: LGTM on dtype handling in INT path.
Casting inputs/amax to float and restoring outputs to input dtype is correct and avoids fp16/bf16 overflow. Safe given quantized range <= 2^7 fits fp16/bf16.
Also applies to: 592-593
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1)
55-60: Bindings look good; optional amax default is correct.
API surface matches the Python usage and simplifies the call sites. Nice cleanup.
Please re-run unit tests that pass amax from CPU to ensure no hidden device mismatches remain after the fix.
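A minimal sketch of such a device check on the Python side (an illustrative assumption, not the wrapper's actual code):

import torch

def align_amax(inputs: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Move a CPU-resident amax onto the input's CUDA device (and to fp32) before the kernel call.
    if inputs.is_cuda and (amax.device != inputs.device or amax.dtype != torch.float32):
        amax = amax.to(device=inputs.device, dtype=torch.float32, non_blocking=True)
    return amax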
ee03094 to dfca83e
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
130-135: Fix wrong positional args in test; pass bias explicitly or use keywords.
You're passing 8 as bias and False as num_bits. This breaks the test intent.
Apply:
-        quant_x_test = tensor_quant.fake_tensor_quant(
-            x, torch.tensor(1e-4).to(self.device).half(), 8, False
-        )
+        quant_x_test = tensor_quant.fake_tensor_quant(
+            x,
+            torch.tensor(1e-4, device=self.device).to(x.dtype),
+            bias=None,
+            num_bits=8,
+            unsigned=False,
+        )
🧹 Nitpick comments (3)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
95-99: Use ceil-div for grid sizing in per-tensor path (consistency).
Minor efficiency and consistency fix.
  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
-   fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+   const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4);
+   fake_e4m3fy_kernel<<<grid, BLOCK_SIZE, 0, stream>>>(
      inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
      outputs.data_ptr<scalar_t>());
  });
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)
104-109: Prefer 1D amax for with-axis CUDA API (avoid relying on broadcast).
Simplifies expectations and mirrors other tests.
- amax = torch.tensor([1.0, 1.0e-26, 1.0]).cuda().unsqueeze(-1).unsqueeze(1)
- quant_x = get_cuda_ext().fake_tensor_quant_with_axis(x, amax, axis=1)
+ amax = torch.tensor([1.0, 1.0e-26, 1.0], device=x.device)
+ quant_x = get_cuda_ext().fake_tensor_quant_with_axis(x, amax, axis=1)
157-164: Good per-channel FP8 coverage. Consider adding a non-contiguous variant.
Transpose along a non-axis dim to exercise the dispatcher’s contiguity path.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
modelopt/torch/quantization/src/tensor_quant.h
(0 hunks)modelopt/torch/quantization/src/tensor_quant_fp8.cpp
(1 hunks)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
(2 hunks)modelopt/torch/quantization/tensor_quant.py
(4 hunks)tests/_test_utils/torch_quantization/tensor_quant_common.py
(1 hunks)tests/gpu/torch/quantization/test_tensor_quant_cuda.py
(3 hunks)tests/unit/torch/quantization/test_tensor_quant_cpu.py
(1 hunks)
💤 Files with no reviewable changes (1)
- modelopt/torch/quantization/src/tensor_quant.h
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
fake_e4m3fy
(26-48)fake_e4m3fy
(26-26)fake_e4m3fy_with_axis
(50-53)fake_e4m3fy_with_axis
(50-50)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(125-135)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (5)
tests/_test_utils/torch_misc.py (1)
set_seed
(33-40)tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(125-135)modelopt/torch/quantization/extensions.py (2)
get_cuda_ext
(27-36)get_cuda_ext_mx
(55-71)modelopt/torch/quantization/tensor_quant.py (1)
fp8_eager
(56-60)modelopt/torch/quantization/utils.py (1)
reduce_amax
(148-180)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (4)
modelopt/torch/quantization/tensor_quant.py (1)
579-581
: Compute in FP32 and restore input dtype — LGTM.This prevents overflow/rounding artifacts in half/bfloat16.
Also applies to: 607-607
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
22-22: Import consolidation LGTM—no stale tester references found
Verified no occurrences of TensorQuantTester or FakeAffineTensorQuantTester in the test suite.
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)
75-79: Launch on current CUDA stream and use ceil-div for grid.
Avoid default stream usage and over-provisioned grids.
- AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] { - fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>( - inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(), - axis_size, outer_size, outputs.data_ptr<scalar_t>()); - }); + AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] { + auto stream = c10::cuda::getCurrentCUDAStream(); + const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4); + fake_e4m3fy_with_axis_cuda_kernel<<<grid, BLOCK_SIZE, 0, stream>>>( + inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(), + axis_size, outer_size, outputs.data_ptr<scalar_t>()); + });
66-71: Normalize axis and validate bounds before size/stride.
Handle negative axes and guard against OOB axis to avoid mis-indexing.
at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) { auto outputs = torch::empty_like(inputs); - size_t numel = inputs.numel(); - int axis_size = inputs.size(axis); - int outer_size = inputs.stride(axis); + size_t numel = inputs.numel(); + if (axis < 0) { + axis += inputs.dim(); + } + TORCH_CHECK(axis >= 0 && axis < inputs.dim(), "fake_e4m3fy_cuda_with_axis: axis out of range"); + int axis_size = inputs.size(axis); + int outer_size = inputs.stride(axis);
dfca83e to ddbd66d
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
443-451: quantize_op schema requires Tensor amax; None will break. Route None to eager path.
scaled_e4m3 now allows amax=None, but torch.library quantize_op is defined with Tensor amax. Passing None can error when the library path is active.
Apply this diff to handle None without touching the op schema:
- outputs = quantize_op( - inputs, - amax, - num_bits=8, - exponent_bits=4, - unsigned=False, - narrow_range=False, - ) + if amax is None: + outputs = scaled_e4m3_impl(inputs=inputs, amax=None) + else: + outputs = quantize_op( + inputs, + amax, + num_bits=8, + exponent_bits=4, + unsigned=False, + narrow_range=False, + )Alternatively, define a None-overload for quantize_op like you did for dynamic_block_quantize_op.
🧹 Nitpick comments (3)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (2)
18-24: Include CUDAStream header for portability.
c10::cuda::getCurrentCUDAStream() requires c10/cuda/CUDAStream.h on some PyTorch versions. Add the explicit include to avoid build issues.
Apply this diff:
 #include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
130-136: Use the current stream consistently across kernels.
NF4/INT4 kernels below still launch on the default stream. For consistency and better interop, consider passing the current stream to those launches as well.
If you choose to update, mirror the pattern here:
- NF4_dequantize_kernel<<<blocks, block_size / 2>>>(...)
+ auto stream = c10::cuda::getCurrentCUDAStream();
+ NF4_dequantize_kernel<<<blocks, block_size / 2, 0, stream>>>(...)
modelopt/torch/quantization/tensor_quant.py (1)
52-55: Guard torch.compile usage to avoid import-time failures.
Decorating with torch.compile at import time can break older Torch builds or unsupported backends. Consider guarding or compiling lazily.
Example:
if hasattr(torch, "compile"):
    _fp8_triton = torch.compile(dynamic=True)(_fp8_eager)
else:
    _fp8_triton = _fp8_eager
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
modelopt/torch/quantization/src/tensor_quant.h
(0 hunks)modelopt/torch/quantization/src/tensor_quant_fp8.cpp
(1 hunks)modelopt/torch/quantization/src/tensor_quant_gpu.cu
(3 hunks)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
(2 hunks)modelopt/torch/quantization/tensor_quant.py
(4 hunks)tests/_test_utils/torch_quantization/tensor_quant_common.py
(1 hunks)tests/gpu/torch/quantization/test_tensor_quant_cuda.py
(3 hunks)tests/unit/torch/quantization/test_tensor_quant_cpu.py
(1 hunks)
💤 Files with no reviewable changes (1)
- modelopt/torch/quantization/src/tensor_quant.h
🧰 Additional context used
🧬 Code graph analysis (5)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(125-135)
modelopt/torch/quantization/tensor_quant.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
fake_e4m3fy
(26-48)fake_e4m3fy
(26-26)fake_e4m3fy_with_axis
(50-53)fake_e4m3fy_with_axis
(50-50)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (5)
tests/_test_utils/torch_misc.py (1)
set_seed
(33-40)tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(125-135)modelopt/torch/quantization/extensions.py (2)
get_cuda_ext
(27-36)get_cuda_ext_mx
(55-71)modelopt/torch/quantization/tensor_quant.py (1)
fp8_eager
(56-60)modelopt/torch/quantization/utils.py (1)
reduce_amax
(148-180)
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (1)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy_cuda
(85-102)fake_e4m3fy_cuda
(85-85)fake_e4m3fy_cuda_with_axis
(66-83)fake_e4m3fy_cuda_with_axis
(66-66)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
void
(58-72)void
(100-119)void
(145-165)void
(209-231)void
(260-277)void
(309-338)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (1)
77-83
: Good: launch on the current CUDA stream.Passing the current stream to the kernel improves interop with upstream stream semantics.
Confirm that callers always enter this function under the proper device context (Python code seems to guard with torch.cuda.device). Otherwise, consider guarding here with cudaSetDevice(inputs.device().index()) for safety.
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
22-22
: LGTM: import surface simplified.Importing only FakeTensorQuantTester aligns with the updated shared test utilities.
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
130-135
: LGTM: overflow/NaN guard.The explicit inf/NaN assertion strengthens the fp16 overflow test.
modelopt/torch/quantization/tensor_quant.py (2)
44-51
: FP8 QDQ in fp32 is correct.Doing scale/convert/scale-back in fp32 and restoring the original dtype is the right fix to avoid bf16 scaling errors.
80-85
: Axis inference is fragile; fix ambiguous/mismatched cases.amax.squeeze().ndim == 1 with amax.shape.index(amax.numel()) can mis-infer or throw when multiple dims match or keepdims differ.
Apply robust inference and fall back to eager when ambiguous:
- elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + elif amax.squeeze().ndim == 1: + amax_s = amax.squeeze() + axis_inferred = None + if amax.dim() == inputs.dim(): + candidates = [ + i + for i, (asize, isize) in enumerate(zip(amax.shape, inputs.shape)) + if asize == isize and asize > 1 + ] + if len(candidates) == 1: + axis_inferred = candidates[0] + else: + candidates = [i for i, isize in enumerate(inputs.shape) if isize == amax_s.numel()] + if len(candidates) == 1: + axis_inferred = candidates[0] + if axis_inferred is not None: + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax_s.reshape(-1), axis_inferred) + else: + outputs = fp8_eager(inputs, amax)tests/gpu/torch/quantization/test_tensor_quant_cuda.py (2)
113-116
: LGTM: runtime FP8 reference using fp8_eager.This validates scaled_e4m3 against the eager reference and covers both CUDA and CPU.
157-164
: Per-channel FP8 coverage looks good.Axis-parameterized test aligns with the new per-channel path and eager fallback behavior.
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (4)
26-33
: LGTM! Proper optional amax handling added.The function signature has been updated to accept an optional amax parameter and correctly handles both cases (with and without amax). The scaling factor of 448 is appropriate for FP8 E4M3 format which can store values up to +/-448.
34-48
: LGTM! CPU path correctly implements FP8 quantization.The CPU implementation properly calculates scale and inv_scale, and the quantization logic correctly converts to FP8 E4M3 and back to maintain the same precision as the CUDA kernel.
50-53
: Device mismatch not addressed despite past review.Previous review flagged that amax may reside on CPU while inputs are on CUDA, causing device mismatch in the CUDA kernel. The author mentioned this is "fine" but the issue remains unresolved.
55-60
: LGTM! Module bindings correctly updated for new API.The Python bindings properly expose both the regular
fake_e4m3fy
with optional amax and the new axis-aware variantfake_e4m3fy_with_axis
.modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
35-45
: LGTM! Kernel signature updated for fused QDQ scaling.The kernel signature properly accepts scale and inv_scale pointers, enabling fp32 scaling computation as intended by the PR objectives. The loop correctly processes 4 elements per thread for vectorization.
47-64
: LGTM! Axis-aware kernel correctly implements per-channel quantization.The kernel properly computes axis_id using stride-based indexing and applies per-axis scale/inv_scale values. This enables the FP8 per-channel quantization along axis-0 as described in the PR objectives.
66-83
: Grid sizing and contiguity concerns from previous review unaddressed.Previous reviews flagged:
- Over-provisioned grid size using
+ 1
instead of proper ceil-div- Missing contiguity checks for stride-based axis indexing
- Missing axis normalization for negative values
The author indicated contiguity is checked in the C++ dispatcher, but grid sizing remains suboptimal.
85-102
: LGTM! Per-tensor kernel properly handles optional amax.The function correctly computes scale from amax when provided, defaults to unity scale otherwise, and passes the scale/inv_scale pointers to the kernel. The implementation aligns with the PR objective of fusing QDQ scaling into the kernel.
Actionable comments posted: 3
🧹 Nitpick comments (3)
modelopt/torch/quantization/tensor_quant.py (1)
45-55: Guard against non-positive amax and clarify broadcasting expectations.
Division by zero/negatives will produce NaNs or invalid results; also the per-channel eager path relies on a caller-provided broadcastable shape.
Apply:
 def _fp8_eager(x, amax=None):
     dtype = x.dtype
     if amax is not None:
+        if torch.any(amax <= 0):
+            raise ValueError("amax must be > 0 for FP8 scaling.")
         scale = 448.0 / (amax.to(torch.float32))
         scale_inv = 1 / scale
         x = x.to(torch.float32) * scale
Please verify that CPU callers pass amax with a broadcastable shape for the intended axis (e.g., keepdim with 1s in other dims). Otherwise, broadcasting may target the last dimension by default.
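Concretely, a per-channel amax intended for dim 0 of a 2-D weight should keep its reduced dims so broadcasting hits the intended axis (illustrative shapes only, not test code):

import torch

w = torch.randn(64, 128)                       # (C, K) weight
amax = w.abs().amax(dim=1, keepdim=True)       # shape (64, 1): broadcasts over dim 1 as intended
assert (w / amax).shape == w.shape
# Without keepdim, amax would have shape (64,) and align with the trailing dim instead,
# which errors here (64 != 128) or silently scales the wrong axis when the sizes coincide.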
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)
87-95: Per-tensor path: enforce scalar amax.
Prevent accidental multi-element amax on the scalar kernel.
Apply:
at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) { inputs = inputs.contiguous(); size_t numel = inputs.numel(); - at::Tensor scale = 448.f / amax; + TORCH_CHECK(amax.numel() == 1, + "fake_e4m3fy expects scalar amax; use fake_e4m3fy_with_axis for per-channel."); + at::Tensor scale = 448.f / amax; auto inv_scale = 1.f / scale;
77-82: Optional: ceil-div grid sizing.
Slightly cleaner and consistent with other kernels.
Apply:
- AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] { - fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>( + AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] { + const int grid = (numel + BLOCK_SIZE * 4 - 1) / (BLOCK_SIZE * 4); + fake_e4m3fy_with_axis_cuda_kernel<<<grid, BLOCK_SIZE, 0, stream>>>( inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(), axis_size, outer_size, outputs.data_ptr<scalar_t>()); });
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
modelopt/torch/quantization/extensions.py
(1 hunks)modelopt/torch/quantization/src/tensor_quant_fp8.cpp
(0 hunks)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
(2 hunks)modelopt/torch/quantization/tensor_quant.py
(4 hunks)
💤 Files with no reviewable changes (1)
- modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/extensions.py (1)
get_cuda_ext_fp8
(39-52)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy
(87-100)fake_e4m3fy
(87-87)fake_e4m3fy_with_axis
(66-85)fake_e4m3fy_with_axis
(66-66)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
void
(58-72)void
(100-119)void
(145-165)void
(209-231)void
(260-277)void
(309-338)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/extensions.py (1)
44-45: FP8 CUDA-only source switch looks good.
Dropping the CPU source here is consistent with the new CUDA-only FP8 path. No issues.
modelopt/torch/quantization/tensor_quant.py (2)
383-386: Do the dequant division in fp32 to preserve fp32 QDQ intent.
Prevents bf16/half rounding in the legacy path.
Apply:
- return outputs / scale.to(inputs.dtype)
+ return (outputs.to(torch.float32) / scale.to(torch.float32)).to(inputs.dtype)
93-99: Axis inference can be wrong; at least validate or fall back.
Current axis = amax.shape.index(amax.numel()) resolves to 0 for 1D amax and can silently mismatch. Prefer robust inference; otherwise validate and fall back.
Apply:
- elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + elif amax.squeeze().ndim == 1: + axis_len = amax.numel() + # Try to infer a unique matching axis from inputs' shape. + candidates = [i for i, s in enumerate(inputs.shape) if s == axis_len] + if len(candidates) == 1: + axis = candidates[0] + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.reshape(-1), axis) + else: + # Ambiguous or no match; use eager path to respect broadcasting semantics. + outputs = fp8_eager(inputs, amax)
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
57-60
: Optionally guardtorch.compile
for wider PyTorch compatibility.Unconditional decorator can break on older torch. Gate it and fallback to eager.
Apply this diff:
-@torch.compile(dynamic=True) -def _fp8_triton(x, amax): - return _fp8_eager(x, amax) +try: + _fp8_triton = torch.compile(dynamic=True)(_fp8_eager) # type: ignore[attr-defined] +except AttributeError: + _fp8_triton = _fp8_eager
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/tensor_quant.py
(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/extensions.py (1)
get_cuda_ext_fp8
(39-52)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy
(87-100)fake_e4m3fy
(87-87)fake_e4m3fy_with_axis
(66-85)fake_e4m3fy_with_axis
(66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/quantization/tensor_quant.py (6)
380-384: Do the legacy dequant division in fp32 to preserve fp32 QDQ intent.
Prevents bf16/fp16 rounding when dequantizing on CPU/fallback paths.
Apply this diff:
- outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) - return outputs / scale.to(inputs.dtype) + outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range) + return (outputs.to(torch.float32) / scale.to(torch.float32)).to(inputs.dtype)Based on learnings
590-592: LGTM: compute in fp32 to avoid overflow/rounding.
Casting inputs and amax to float improves numerical stability and aligns with the PR goal.
83-85: Fix: replace inputs.is_cpu with a supported check.
torch.Tensor has is_cuda but not is_cpu. This will raise AttributeError on GPU tensors. Use not inputs.is_cuda instead.
Apply this diff:
- if inputs.is_cpu or amax is None or amax.squeeze().ndim > 1:
+ if not inputs.is_cuda or amax is None or amax.squeeze().ndim > 1:
      return fp8_eager(inputs, amax)
Based on learnings
86-89: Ensure amax is on the same CUDA device and in fp32 before calling the CUDA kernel.
If amax is on CPU or not fp32, the extension can error or incur unintended casts. Move it to inputs.device and cast to fp32.
Apply this diff:
     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
     if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)
+    # Align amax device/dtype for CUDA path
+    if amax.device != inputs.device or amax.dtype != torch.float32:
+        amax = amax.to(device=inputs.device, dtype=torch.float32, non_blocking=True)
Based on learnings
95-98: Robust axis inference; fall back to eager when ambiguous.
axis = amax.shape.index(amax.numel()) is fragile and can infer a wrong axis or fail when multiple dims match. Infer axis by comparing shapes; if ambiguous, use the eager path.
Apply this diff:
- elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) - outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis) + elif amax.squeeze().ndim == 1: + amax_s = amax.squeeze() + axis = None + if amax.dim() == inputs.dim(): + candidates = [ + i for i, (asize, isize) in enumerate(zip(amax.shape, inputs.shape)) + if asize == isize and asize > 1 + ] + if len(candidates) == 1: + axis = candidates[0] + else: + candidates = [i for i, isize in enumerate(inputs.shape) if isize == amax_s.numel()] + if len(candidates) == 1: + axis = candidates[0] + if axis is None: + return fp8_eager(inputs, amax) + outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax_s, axis)
83-98: No additional is_cpu checks or FP8 callsites with ambiguous 1-D amax found.
LGTM! Thanks. Have you benchmarked the Triton vs CUDA kernels on Hopper+? We can use Triton by default if it's faster.
d920b7f to ce384be
a8919e9 to 1ae9302
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (2)
66-85: Consider validating amax length matches axis size.
The function correctly makes inputs and amax contiguous and uses the current CUDA stream. However, consider adding a validation to ensure amax.numel() == axis_size
to catch shape mismatches early and provide a clear error message.at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) { inputs = inputs.contiguous(); amax = amax.contiguous().to(at::kFloat); auto outputs = torch::empty_like(inputs); size_t numel = inputs.numel(); int axis_size = inputs.size(axis); int outer_size = inputs.stride(axis); + TORCH_CHECK( + amax.numel() == axis_size, + "fake_e4m3fy_with_axis: amax.numel() must equal axis length. Expected ", + axis_size, ", got ", amax.numel() + ); + auto scale = 448.f / amax; auto inv_scale = 1.f / scale;Based on past review comments.
87-101
: Consider validating amax is scalar for per-tensor path.The function correctly uses the current stream and makes inputs contiguous. However, the per-tensor kernel reads
scale[0]
only. Consider adding a validation to ensureamax.numel() == 1
to catch misuse early.at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) { inputs = inputs.contiguous(); amax = amax.view(-1).to(at::kFloat); + TORCH_CHECK( + amax.numel() == 1, + "fake_e4m3fy expects scalar amax; use fake_e4m3fy_with_axis for per-channel." + ); size_t numel = inputs.numel(); at::Tensor scale = 448.f / amax;Based on past review comments.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
modelopt/torch/quantization/extensions.py
(1 hunks)modelopt/torch/quantization/src/tensor_quant.h
(0 hunks)modelopt/torch/quantization/src/tensor_quant_fp8.cpp
(0 hunks)modelopt/torch/quantization/src/tensor_quant_gpu.cu
(4 hunks)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
(2 hunks)modelopt/torch/quantization/tensor_quant.py
(5 hunks)tests/_test_utils/torch_quantization/tensor_quant_common.py
(0 hunks)tests/gpu/torch/quantization/test_tensor_quant_cuda.py
(3 hunks)tests/unit/torch/quantization/test_tensor_quant_cpu.py
(1 hunks)
💤 Files with no reviewable changes (3)
- modelopt/torch/quantization/src/tensor_quant.h
- modelopt/torch/quantization/src/tensor_quant_fp8.cpp
- tests/_test_utils/torch_quantization/tensor_quant_common.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/extensions.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T20:14:34.768Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.768Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.
Applied to files:
modelopt/torch/quantization/tensor_quant.py
🧬 Code graph analysis (4)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(126-136)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
void
(59-73)void
(101-120)void
(147-167)void
(211-233)void
(262-279)void
(311-340)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (4)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(126-136)modelopt/torch/quantization/extensions.py (2)
get_cuda_ext
(27-36)get_cuda_ext_mx
(55-71)modelopt/torch/quantization/tensor_quant.py (1)
fp8_eager
(61-66)modelopt/torch/quantization/utils.py (1)
reduce_amax
(148-180)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/extensions.py (1)
get_cuda_ext_fp8
(39-52)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy
(87-101)fake_e4m3fy
(87-87)fake_e4m3fy_with_axis
(66-85)fake_e4m3fy_with_axis
(66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (20)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (4)
19-19
: LGTM: Stream header inclusion is appropriate.The inclusion of
<c10/cuda/CUDAStream.h>
is necessary for the stream-aware kernel launches below.
78-82
: LGTM: Stream-aware kernel launch prevents synchronization issues.The kernel now correctly executes on the current CUDA stream rather than the default stream, which is important for correctness in multi-stream scenarios.
90-94
: LGTM: Consistent stream-aware pattern.The fake_tensor_quant_cuda function correctly uses the current stream, matching the pattern in the inplace variant.
131-137
: LGTM: Per-axis variant now stream-aware.The with-axis kernel launch now correctly uses the current CUDA stream, completing the stream-awareness migration across all quantization kernels.
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
22-22
: LGTM: Import simplified to reflect updated test surface.The import correctly narrows to
FakeTensorQuantTester
only, aligning with the removal ofTensorQuantTester
andFakeAffineTensorQuantTester
from test utilities as part of the broader test cleanup.tests/gpu/torch/quantization/test_tensor_quant_cuda.py (6)
21-21
: LGTM: Import simplified to reflect updated test utilities.The import correctly narrows to
FakeTensorQuantTester
, consistent with the test utilities cleanup across the PR.
26-26
: LGTM: Extension imports updated to reflect new FP8 path.The removal of
get_cuda_ext_fp8
from imports aligns with the PR's refactoring where FP8 functionality is now accessed through different pathways (e.g., viascaled_e4m3
which internally callsget_cuda_ext_fp8
).
113-127
: LGTM: FP8 tests now use pure PyTorch reference implementation.The tests correctly use
fp8_eager
as the reference, addressing the past review feedback to use pure torch implementation. The parameterization across devices and dtypes improves coverage, and the CPU skip for non-float32 is appropriate.
129-134
: LGTM: Incontiguous test updated with pure torch reference.Consistent with the other FP8 test updates, ensuring non-contiguous tensors are validated against the pure PyTorch implementation.
148-156
: LGTM: Multi-GPU test validates cross-device correctness.The test correctly validates that FP8 quantization works when the tensor is on a non-current GPU device, using the pure torch reference for comparison.
157-163
: LGTM: Per-channel FP8 test validates new axis-aware functionality.This new test correctly validates the per-channel FP8 quantization feature introduced in this PR, testing across multiple axes and using the pure PyTorch reference for validation.
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
21-21
: Note:<optional>
included but not visibly used in this file.The
<optional>
header is included but doesn't appear to be used in the code shown. This might be for future use or used in code not shown in this review.
35-45
: LGTM: Per-tensor FP8 kernel correctly fuses scale/inv_scale in fp32.The kernel correctly implements the PR objective by applying scale and inv_scale in fp32 before/after the FP8 conversion, preventing scaling from being applied in bf16 when inputs are bf16. The 4-element per-thread loop improves memory coalescing.
47-64
: LGTM: Axis-aware FP8 kernel correctly implements per-channel quantization.The kernel correctly computes
axis_id
from flattened indices and applies per-axis scale/inv_scale, enabling per-channel FP8 quantization. The indexing logic is consistent with the integer per-axis kernel intensor_quant_gpu.cu
(lines 115-116).
103-108
: LGTM: PyBind11 module correctly exposes FP8 functions.The module definition correctly binds both per-tensor and per-axis FP8 quantization functions with clear argument names, making them accessible from Python.
modelopt/torch/quantization/tensor_quant.py (5)
45-66
: LGTM: FP8 eager implementation provides pure PyTorch reference.The implementation correctly provides a pure PyTorch reference for FP8 quantization, addressing past review feedback. Key highlights:
- Scaling applied in fp32 (lines 48-50) prevents bf16 rounding issues
amax=None
case enables dynamic quantization without scaling- Triton optimization via
torch.compile
improves performance when available
82-97
: LGTM: Robust fallback logic for FP8 routing.The routing correctly falls back to eager mode for non-CUDA tensors, missing amax, or unsupported amax shapes (line 82), addressing past review feedback. The CUDA path appropriately routes between per-tensor and per-axis kernels based on amax shape.
Note: Axis inference via
amax.shape.index(amax.numel())
(line 95) assumes unique matching dimensions. Past review comments noted this could be fragile, but the maintainer confirmed the current approach is acceptable.
379-382
: LGTM: Legacy quant function simplified.The function correctly uses the updated
_tensor_quant
which now returns only outputs (not a tuple), simplifying the call site.
589-590
: LGTM: Explicit float casting ensures fp32 computation.The explicit casting of inputs and amax to float (lines 589-590) ensures all quantization math happens in fp32, preventing precision issues when inputs are bf16 or fp16.
616-619
: LGTM: Dequantization preserves fp32 precision before dtype conversion.The division by scale is performed in float (line 616), and the result is cast back to the original dtype (line 618), addressing past review feedback about preserving fp32 precision in the QDQ path.
@mxinO I am running into issues with torch.compile of FP8 eager mode. Let us implement torch.compile + FP8 eager mode in a separate PR. Let me merge the current PR to fix the FP8 QAT bug on the main branch.
Squashed commits: several "minor" cleanups; Added FP8 eager mode with triton for multi-axis FP8 fake quant; Use fp8_eager for cpu or if amax is None; Removed FP8 eager with torch.compile. Signed-off-by: realAsma <[email protected]>
dc12e82 to 56f5a08
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/quantization/tensor_quant.py (1)
56-58
: Consider allowingamax=None
in signature for consistency.The wrapped
_fp8_eager
function acceptsamax=None
for dynamic quantization without scaling, but this wrapper requiresamax
. Consider adding a default parameteramax=None
to match the underlying implementation and support the full range of use cases.Apply this diff:
-def fp8_eager(x, amax): +def fp8_eager(x, amax=None): """Eager mode implementation of FP8 quantization.""" return _fp8_eager(x, amax)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/tensor_quant.py
(5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/extensions.py (1)
get_cuda_ext_fp8
(39-52)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy
(87-101)fake_e4m3fy
(87-87)fake_e4m3fy_with_axis
(66-85)fake_e4m3fy_with_axis
(66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/quantization/tensor_quant.py (4)
371-374
: LGTM! Simplified to match _tensor_quant's new return signature.The removal of separate scale handling aligns with
_tensor_quant
now returning fully scaled outputs directly.
573-611
: LGTM! FP32 precision preserved throughout quantization.The refactored
_tensor_quant
correctly performs all scaling operations in fp32 before casting back to the input dtype, addressing the precision concerns from the PR objectives. The zero-amax handling logic at lines 599-607 is also correct.
45-54
: No amax device alignment issues found. All existing callers supplyamax
as a tensor on the same device asx
, so no additional checks are required.
84-88
: Fix: axis inference crashes when amax has size-1 dimensions.When
amax
has a shape like[1, 128, 1, 1]
for per-channel quantization, line 87'samax.shape.index(amax.numel())
will raiseValueError
because128
is not present in the shape tuple(1, 128, 1, 1)
.Apply this diff to fix the axis inference:
if amax.numel() == 1: outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) + # Find the axis where amax size matches (non-singleton dimension) + axis = next(i for i, s in enumerate(amax.shape) if s == amax.numel()) outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)Alternative robust approach:
if amax.numel() == 1: outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax) elif amax.squeeze().ndim == 1: - axis = amax.shape.index(amax.numel()) + # Find axis by matching non-singleton dimensions + candidates = [i for i, s in enumerate(amax.shape) if s > 1] + if len(candidates) != 1: + return fp8_eager(inputs, amax) # Ambiguous, use eager path + axis = candidates[0] outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)Likely an incorrect or invalid review comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
21-21
: Remove unused include.The
<optional>
header is included but never used in this file. Nostd::optional
types are present in the code.Apply this diff:
-#include <optional>
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
modelopt/torch/quantization/extensions.py
(1 hunks)modelopt/torch/quantization/src/tensor_quant.h
(0 hunks)modelopt/torch/quantization/src/tensor_quant_fp8.cpp
(0 hunks)modelopt/torch/quantization/src/tensor_quant_gpu.cu
(4 hunks)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
(2 hunks)modelopt/torch/quantization/tensor_quant.py
(5 hunks)tests/_test_utils/torch_quantization/tensor_quant_common.py
(0 hunks)tests/gpu/torch/quantization/test_tensor_quant_cuda.py
(3 hunks)tests/unit/torch/quantization/test_tensor_quant_cpu.py
(1 hunks)
💤 Files with no reviewable changes (3)
- modelopt/torch/quantization/src/tensor_quant.h
- tests/_test_utils/torch_quantization/tensor_quant_common.py
- modelopt/torch/quantization/src/tensor_quant_fp8.cpp
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/quantization/extensions.py
- modelopt/torch/quantization/src/tensor_quant_gpu.cu
🧰 Additional context used
🧬 Code graph analysis (4)
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(126-136)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (1)
modelopt/torch/quantization/src/tensor_quant_gpu.cu (6)
void
(59-73)void
(101-120)void
(147-167)void
(211-233)void
(262-279)void
(311-340)
tests/gpu/torch/quantization/test_tensor_quant_cuda.py (4)
tests/_test_utils/torch_quantization/tensor_quant_common.py (1)
FakeTensorQuantTester
(126-136)modelopt/torch/quantization/extensions.py (2)
get_cuda_ext
(27-36)get_cuda_ext_mx
(55-71)modelopt/torch/quantization/tensor_quant.py (1)
fp8_eager
(56-58)modelopt/torch/quantization/utils.py (1)
reduce_amax
(148-180)
modelopt/torch/quantization/tensor_quant.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
amax
(236-241)amax
(244-255)axis
(283-285)axis
(288-290)modelopt/torch/quantization/extensions.py (1)
get_cuda_ext_fp8
(39-52)modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
fake_e4m3fy
(87-101)fake_e4m3fy
(87-87)fake_e4m3fy_with_axis
(66-85)fake_e4m3fy_with_axis
(66-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (15)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu (4)
47-64
: LGTM! Axis-aware FP8 quantization kernel is correctly implemented.The kernel correctly computes per-axis scale indices and applies FP32 scaling, which addresses the bf16 scaling issue mentioned in the PR objectives. The logic aligns with the integer quantization pattern from tensor_quant_gpu.cu.
66-85
: LGTM! Host function correctly prepares tensors for axis-aware quantization.The function ensures contiguity, casts amax to FP32, and uses the current CUDA stream. Based on past review comments, input validation (axis normalization, device matching, shape validation) is handled in the C++ dispatcher
fake_e4m3fy_with_axis
, which is appropriate for separation of concerns.Based on past review comments.
87-101
: LGTM! Per-tensor FP8 quantization host function is correct.The function properly handles per-tensor quantization with FP32 scaling. The
view(-1)
call flattens amax, and shape validation is handled in the Python dispatcherscaled_e4m3_impl
(as confirmed by past review comments), which is appropriate.Based on past review comments.
103-108
: LGTM! PyBind11 bindings correctly expose the new FP8 API.The module exports both per-tensor and per-axis FP8 quantization functions with appropriate signatures and named arguments, aligning with the PR's goal to provide axis-aware FP8 quantization.
tests/unit/torch/quantization/test_tensor_quant_cpu.py (1)
22-22
: LGTM! Import simplified to match updated test utilities.The import change correctly reflects the removal of
TensorQuantTester
andFakeAffineTensorQuantTester
base classes from the test utilities, aligning with the PR's simplification of the quantization test surface.tests/gpu/torch/quantization/test_tensor_quant_cuda.py (6)
21-26
: LGTM! Import surface updated to reflect removed FP8 CPU extension.The removal of
get_cuda_ext_fp8
import aligns with the PR's removal of CPU-side FP8 implementations and the consolidation of FP8 paths to CUDA-only or eager mode.
112-116
: LGTM! Test correctly uses eager mode as reference.The test now uses
fp8_eager
as the ground truth reference, which aligns with the PR's goal to validate that the CUDA kernel produces the same results as the eager implementation.
118-127
: LGTM! Enhanced test coverage with dtype parameterization.The test now validates FP8 quantization across multiple dtypes (bfloat16, float16, float32) with appropriate CPU skip logic, and uses
fp8_eager
as the reference. This ensures the FP32 scaling fix works correctly across different input precisions.
129-134
: LGTM! Non-contiguous tensor handling verified.The test correctly validates that non-contiguous tensors are handled properly by the quantization path, with
fp8_eager
as the reference.
148-156
: LGTM! Multi-GPU handling correctly tested.The test validates that FP8 quantization works correctly when tensors are on a non-current GPU device, with proper device synchronization.
157-163
: LGTM! New per-channel quantization test validates axis-aware functionality.This new test validates the core feature of the PR: per-channel FP8 quantization with axis support. The test correctly computes per-axis amax and validates against the eager reference across different axes (0, 1, 2).
modelopt/torch/quantization/tensor_quant.py (4)
45-59
: LGTM! FP8 eager implementation correctly applies FP32 scaling.The eager mode implementation correctly:
- Preserves the original dtype throughout the operation
- Applies scaling in FP32 (448.0 / amax), addressing the bf16 scaling issue mentioned in PR objectives
- Handles the amax=None case for unscaled quantization (useful for dynamic quantization)
- Uses native PyTorch FP8 dtype (torch.float8_e4m3fn)
573-611
: LGTM! Quantization math now performed in FP32.The function correctly:
- Casts inputs and amax to float32 before computation (lines 581-582)
- Performs all quantization math in FP32, including the dequantization division (line 608)
- Casts back to the original input dtype at the end (line 610)
This addresses the bf16 scaling issue mentioned in the PR objectives and past review comments.
Based on past review comments.
371-374
: LGTM! Legacy quantization function updated to match new _tensor_quant signature.The function correctly calls the updated
_tensor_quant
which now returns only outputs (not a tuple) and performs all computation in FP32.
74-89
: Ignore axis inference ambiguity concern
Theamax.squeeze().ndim > 1
guard ensures axis routing only runs on 1-D keepdims outputs, where exactly one dimension matchesamax.numel()
, soshape.index(...)
is unambiguous.Likely an incorrect or invalid review comment.
Signed-off-by: realAsma <[email protected]>
…or FP8 per-channel quantization (#381) Signed-off-by: realAsma <[email protected]>
What does this PR do?
Type of change: new feature & bug fix
Overview:
For FP8 fake quantization, the QDQ scaling was applied outside the CUDA kernel in the same precision as the input. This means that for bf16 inputs the scale was applied in bf16, which could cause fake and real quantization to diverge.
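A small constructed illustration of the effect (not a measurement from this PR): rounding the scale through bf16 shifts it relative to the fp32 scale, so the fake-quant grid can stop matching the real-quant grid.

import torch

amax = torch.tensor(0.123456)
scale_fp32 = 448.0 / amax                                          # scale kept in fp32
scale_bf16 = (448.0 / amax.to(torch.bfloat16)).to(torch.float32)   # scale rounded through bf16
print(scale_fp32.item(), scale_bf16.item())                        # the two scales differ slightly

x = torch.rand(4096) * amax                                        # random values within [0, amax)
q32 = (x * scale_fp32).to(torch.float8_e4m3fn).to(torch.float32) / scale_fp32
q16 = (x.to(torch.bfloat16) * scale_bf16.to(torch.bfloat16)).to(torch.float8_e4m3fn).to(torch.float32) / scale_bf16
print((q32 - q16).abs().max().item())  # typically nonzero: some values land on different FP8 codes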
Hence in this PR:
- fused_fp8 is removed: it only supported per-channel quantization along the -1 axis. Per-channel quantization is usually done along axis 0, so this kernel was not useful.
What was the impact of scaling in bf16 for QAT?
I have not tested the QAT impact directly. I measured the absolute quantization error of Llama 3.1 8B's first q_proj weight.
Here it is:
Usage
High level usage does not change. See the updated unit tests.
Testing
Updated unit tests.
Before your PR is "Ready for review"
The fused fp8 fake quantization is removed. However, this is not a high-level API and should not cause any impact.
Additional Information
Summary by CodeRabbit
New Features
Refactor
Tests