Commit 34bcee8

Fixed FP8 fake quantization to use fp32 amax scaling
Signed-off-by: realAsma <[email protected]>
1 parent ad091e8 commit 34bcee8
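
For context, fake E4M3 quantization with amax scaling maps the observed absolute maximum onto the E4M3 max-normal value 448, rounds through FP8, and rescales back; this commit moves that scale computation to fp32. Below is a minimal reference sketch of the intended numerics (illustration only, not the CUDA extension; it assumes a PyTorch build that provides torch.float8_e4m3fn):

import torch

def fake_e4m3_reference(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Scale computed in fp32 so that amax maps onto the E4M3 max-normal value 448.
    scale = 448.0 / amax.float()
    # Round-trip through FP8 E4M3, then undo the scaling.
    xq = (x.float() * scale).to(torch.float8_e4m3fn)
    return (xq.float() / scale).to(x.dtype)

x = torch.randn(4, 8)
y = fake_e4m3_reference(x, x.abs().amax())  # per-tensor amax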

File tree: 5 files changed (+66 / -115 lines)

  modelopt/torch/quantization/src/tensor_quant.h
  modelopt/torch/quantization/src/tensor_quant_fp8.cpp
  modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
  modelopt/torch/quantization/tensor_quant.py
  tests/gpu/torch/quantization/test_tensor_quant_cuda.py

modelopt/torch/quantization/src/tensor_quant.h

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool);
 float bits_to_bound(int, int);
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
 
 // Dequantizes data using NF4 quantization scheme and per-block scaling factors.
 //

modelopt/torch/quantization/src/tensor_quant_fp8.cpp

Lines changed: 18 additions & 11 deletions
@@ -19,30 +19,37 @@
 #include <cuda_fp8.h>
 #include <torch/extension.h>
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold);
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax);
+at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis);
 
-at::Tensor fake_e4m3fy(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  TORCH_CHECK(amax.numel(), 1);
+  inputs = inputs.contiguous();
+  auto amax_view = amax.view(-1).to(at::kFloat);
   if (inputs.is_cuda()) {
-    return fake_e4m3fy_cuda(inputs.contiguous());
+    return fake_e4m3fy_cuda(inputs, amax_view);
   } else {
     TORCH_CHECK(inputs.dtype() == at::ScalarType::Float);
-    TORCH_CHECK(inputs.is_contiguous());
+    float scale = 448.f / amax_view[0].item<float>();
+    float inv_scale = 1.f / scale;
     auto out = at::zeros_like(inputs);
     for (int i = 0; i < inputs.numel(); ++i) {
       out.data_ptr<float>()[i] =
-          static_cast<float>(static_cast<__nv_fp8_e4m3>(inputs.data_ptr<float>()[i]));
+          static_cast<float>(static_cast<__nv_fp8_e4m3>(inputs.data_ptr<float>()[i] * scale)) *
+          inv_scale;
     }
     return out;
   }
 }
 
-at::Tensor fused_fake_e4m3fy(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  return fused_fake_e4m3fy_cuda(inputs.contiguous(), amax, zero_threshold);
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  TORCH_CHECK(inputs.is_cuda());
+  return fake_e4m3fy_cuda_with_axis(inputs.contiguous(), amax.contiguous().to(at::kFloat), axis);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"));
-  m.def("fused_fake_e4m3fy", &fused_fake_e4m3fy, "Reduce precision to E4M3 (fused)",
-        py::arg("inputs"), py::arg("amax"), py::arg("zero_threshold"));
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
 }
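
Both rebound entry points now take an explicit amax. A rough usage sketch from Python (assuming the extension is loaded through the repo's get_cuda_ext_fp8 helper, as tensor_quant.py does below):

import torch
from modelopt.torch.quantization.extensions import get_cuda_ext_fp8

cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=True)
x = torch.randn(16, 32).cuda()

# Per-tensor: amax has a single element.
y = cuda_ext_fp8.fake_e4m3fy(x, x.abs().amax())

# Per-channel along dim 1: a 1-D amax plus the axis index (CUDA inputs only).
amax = x.abs().amax(dim=0)
y_axis = cuda_ext_fp8.fake_e4m3fy_with_axis(x, amax, 1)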

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

Lines changed: 27 additions & 54 deletions
@@ -31,92 +31,65 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
 
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }
 
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //    inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //    inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //    inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
-  auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax) {
   size_t numel = inputs.numel();
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
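
The per-axis kernel recovers each element's channel index from its flat offset as (idx / stride(axis)) % size(axis), which is why the C++ wrapper calls .contiguous() and passes inputs.stride(axis) as outer_size. A small Python check of that index arithmetic (illustration only, for a contiguous tensor):

import numpy as np
import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # contiguous, strides (12, 4, 1)
axis = 1
axis_size = x.size(axis)     # 3
outer_size = x.stride(axis)  # 4

for idx in range(x.numel()):
    axis_id = (idx // outer_size) % axis_size  # same arithmetic as the kernel
    assert axis_id == np.unravel_index(idx, tuple(x.shape))[axis]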

modelopt/torch/quantization/tensor_quant.py

Lines changed: 12 additions & 34 deletions
@@ -43,9 +43,8 @@
 
 
 def scaled_e4m3_impl(
-    inputs: torch.Tensor,  # TODO: check support for multiple inputs
-    amax: torch.Tensor,
-    disable_fused_kernel=True,
+    inputs: torch.Tensor,
+    amax: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Implementation of fake quantizing input to FP8.
 
@@ -58,41 +57,20 @@ def scaled_e4m3_impl(
     """
     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=True)
 
-    def is_fusable():
-        # ignore no scaling and shape([]) cases
-        if amax is None or len(amax.shape) == 0:
-            return False
-        else:
-            # can't have amax.shape = [1, 1, 4, 1] and the like
-            amax_last_dim_only = amax.numel() == amax.shape[-1]
-            # must be cuda
-            all_cuda = inputs.is_cuda and amax.is_cuda
-
-            # also check explicit disable.
-            return amax_last_dim_only and all_cuda and (not disable_fused_kernel)
-
     with torch.cuda.device(
         None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
     ):
-        # differentiate between fused & unfused cases
-        if is_fusable():
-            zero_threshold = 1.0 / (1 << 24)
-            outputs = cuda_ext_fp8.fused_fake_e4m3fy(inputs, amax.float(), zero_threshold)
+        if amax is None:
+            amax = torch.tensor(448.0, device=inputs.device, dtype=inputs.dtype)
+        if amax.numel() == 1:
+            outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
         else:
-            zero_mask = inputs.abs() < 1.0 / (1 << 24)
-
-            if amax is None:
-                outputs = cuda_ext_fp8.fake_e4m3fy(inputs)
-            else:
-                scale = 448.0 / amax
-                outputs = cuda_ext_fp8.fake_e4m3fy(inputs * scale) / scale
-
-            # Zero out values that are tiny.
-            # Tiny values could lead to tiny amax and then large scale which cause overflow/saturation
-            # and won't go back to normal value after dividing by scale. The right behavior is to mark them
-            # as zero which also get rid of inf/nan
-            outputs[zero_mask] = 0.0
-
+            if amax.squeeze().ndim > 1:
+                raise NotImplementedError(
+                    "Fused E4M3 kernel does not support multiaxis quantization."
+                )
+            axis = amax.shape.index(amax.numel())
+            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
     return outputs
 
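The new dispatch selects the kernel from the amax shape: a single-element amax uses the per-tensor path, while an amax with exactly one non-unit dimension (e.g. shape (1, C, 1)) maps to fake_e4m3fy_with_axis with axis = amax.shape.index(amax.numel()). A standalone sketch of that shape-to-axis logic (pick_axis is a hypothetical helper mirroring the branch above, not part of the module):

import torch

def pick_axis(amax: torch.Tensor) -> int | None:
    """Return None for a per-tensor amax, else the single non-unit axis."""
    if amax.numel() == 1:
        return None
    if amax.squeeze().ndim > 1:
        raise NotImplementedError("multi-axis amax is not supported")
    return amax.shape.index(amax.numel())

x = torch.randn(8, 16, 4)
assert pick_axis(x.abs().amax()) is None                       # per-tensor
assert pick_axis(x.abs().amax(dim=(0, 2), keepdim=True)) == 1  # shape (1, 16, 1) -> axis 1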

tests/gpu/torch/quantization/test_tensor_quant_cuda.py

Lines changed: 9 additions & 15 deletions
@@ -27,7 +27,7 @@
 import modelopt.torch.quantization.triton as triton_kernel
 import modelopt.torch.quantization.utils as quant_utils
 from modelopt.torch.quantization import tensor_quant
-from modelopt.torch.quantization.extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx
+from modelopt.torch.quantization.extensions import get_cuda_ext, get_cuda_ext_mx
 from modelopt.torch.quantization.tensor_quant import mx_format_map
 
 
@@ -187,20 +187,14 @@ def test_non_current_gpu(self, need_2_gpus):
         quant_x = tensor_quant.scaled_e4m3(x, x.amax(), None, 4, 3)
         assert torch.allclose(quant_x, quant_x_ref.cuda(device))
 
-    def test_fused_e4m3_kernel(self):
-        cuda_ext_fp8 = get_cuda_ext_fp8()
-        x = torch.tensor(TestScaledE4M3.x).cuda()
-        xq_ref = torch.tensor(TestScaledE4M3.xq_scaled).cuda()
-        amax = torch.ones(1, x.shape[-1]).cuda() * x.abs().amax()
-        e4m3_x = cuda_ext_fp8.fused_fake_e4m3fy(x, amax.float(), 1.0 / (1 << 24))
-        assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4)
-
-    def test_e4m3_kernel_non_last_axis(self):
-        x = torch.tensor(TestScaledE4M3.x).cuda()
-        xq_ref = torch.tensor(TestScaledE4M3.xq_scaled).cuda()
-        amax = torch.ones(x.shape[0], 1).cuda() * x.abs().amax()
-        e4m3_x = tensor_quant.scaled_e4m3(x, amax, None, 4, 3)
-        assert torch.allclose(e4m3_x, xq_ref, atol=1e-4, rtol=1e-4)
+    @pytest.mark.parametrize("axis", [0, 1, 2])
+    def test_e4m3_per_channel(self, axis):
+        x = torch.randn(4, 4, 4, dtype=torch.float32).cuda()
+        amax = x.abs().amax(dim=[ax for ax in range(x.ndim) if ax != axis], keepdim=True)
+        scale = 448.0 / amax
+        xq_ref = tensor_quant.scaled_e4m3(x * scale, None, None, 4, 3) / scale
+        xq_test = tensor_quant.scaled_e4m3(x, amax.float(), None, 4, 3)
+        assert torch.allclose(xq_test, xq_ref)
 
 
 class Testfp4:
