Commit ddbd66d

committed
Fixed FP8 fake quantization to use fp32 amax scaling
Signed-off-by: realAsma <[email protected]>
minor
Signed-off-by: realAsma <[email protected]>
minor
Signed-off-by: realAsma <[email protected]>
minor
Signed-off-by: realAsma <[email protected]>
Added FP8 eager mode with triton for multi-axis FP8 fake quant
minor
Signed-off-by: realAsma <[email protected]>
1 parent 615f3c0 commit ddbd66d

File tree

8 files changed: +118 -460 lines changed


modelopt/torch/quantization/src/tensor_quant.h

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool);
 float bits_to_bound(int, int);
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
 
 // Dequantizes data using NF4 quantization scheme and per-block scaling factors.
 //

modelopt/torch/quantization/src/tensor_quant_fp8.cpp

Lines changed: 23 additions & 11 deletions
@@ -18,31 +18,43 @@
 #include <ATen/ATen.h>
 #include <cuda_fp8.h>
 #include <torch/extension.h>
+#include <optional>
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold);
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax);
+at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis);
 
-at::Tensor fake_e4m3fy(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, std::optional<at::Tensor> amax) {
+  inputs = inputs.contiguous();
+  if (amax.has_value()) {
+    amax = amax.value().view(-1).to(at::kFloat);
+  }
   if (inputs.is_cuda()) {
-    return fake_e4m3fy_cuda(inputs.contiguous());
+    return fake_e4m3fy_cuda(inputs, amax);
   } else {
     TORCH_CHECK(inputs.dtype() == at::ScalarType::Float);
-    TORCH_CHECK(inputs.is_contiguous());
+    float scale = 1.f;
+    if (amax.has_value()) {
+      scale = 448.f / amax.value()[0].item<float>();
+    }
+    float inv_scale = 1.f / scale;
     auto out = at::zeros_like(inputs);
     for (int i = 0; i < inputs.numel(); ++i) {
       out.data_ptr<float>()[i] =
-          static_cast<float>(static_cast<__nv_fp8_e4m3>(inputs.data_ptr<float>()[i]));
+          static_cast<float>(static_cast<__nv_fp8_e4m3>(inputs.data_ptr<float>()[i] * scale)) *
+          inv_scale;
     }
     return out;
   }
 }
 
-at::Tensor fused_fake_e4m3fy(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  return fused_fake_e4m3fy_cuda(inputs.contiguous(), amax, zero_threshold);
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  TORCH_CHECK(inputs.is_cuda());
+  return fake_e4m3fy_cuda_with_axis(inputs.contiguous(), amax.contiguous().to(at::kFloat), axis);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"));
-  m.def("fused_fake_e4m3fy", &fused_fake_e4m3fy, "Reduce precision to E4M3 (fused)",
-        py::arg("inputs"), py::arg("amax"), py::arg("zero_threshold"));
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax") = py::none());
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
 }
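
Note on the new CPU path: per-tensor FP8 fake quantization now scales by 448/amax (448 is the largest finite E4M3 magnitude) before the E4M3 round trip, then rescales. The standalone sketch below is not part of the commit (file and function names are hypothetical) and assumes a CUDA toolkit whose cuda_fp8.h provides host-side __nv_fp8_e4m3 conversions; it reproduces the same math as the CPU fallback above.

// fp8_fake_quant_sketch.cu -- illustrative reference only, compile with nvcc.
#include <cuda_fp8.h>
#include <cstdio>
#include <vector>

// Per-tensor E4M3 fake quantization with fp32 amax scaling:
// scale = 448 / amax, round-trip x * scale through E4M3, then divide back.
std::vector<float> fake_e4m3fy_reference(const std::vector<float> &x, float amax) {
  float scale = 448.f / amax;
  float inv_scale = 1.f / scale;
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = static_cast<float>(static_cast<__nv_fp8_e4m3>(x[i] * scale)) * inv_scale;
  }
  return out;
}

int main() {
  std::vector<float> x = {0.001f, -0.25f, 0.37f, 1.0f};
  auto y = fake_e4m3fy_reference(x, /*amax=*/1.0f);
  for (size_t i = 0; i < x.size(); ++i) {
    std::printf("%f -> %f\n", x[i], y[i]);
  }
  return 0;
}

With amax taken as the tensor's observed absolute maximum, the scaled values span the representable E4M3 range, which is the point of replacing the plain E4M3 cast with amax-based scaling.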

modelopt/torch/quantization/src/tensor_quant_gpu.cu

Lines changed: 6 additions & 3 deletions
@@ -74,8 +74,9 @@ __global__ void fake_tensor_quant_kernel(const T *inputs, size_t n, T *outputs,
 void fake_tensor_quant_cuda_inplace(at::Tensor inputs, at::Tensor amax, int num_bits = 8,
                                     bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_inplace", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, inputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -85,8 +86,9 @@ at::Tensor fake_tensor_quant_cuda(at::Tensor inputs, at::Tensor amax, int num_bi
                                   bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
   auto outputs = torch::empty_like(inputs);
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -125,8 +127,9 @@ at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor inputs, at::Tensor amax,
 
   int outer_size = inputs.stride(axis);
 
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_with_axis", [&] {
-    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), axis_size, outer_size, num_bits,
         is_unsigned, narrow_range);
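
Note on this file: the only behavioral change is that each kernel is now launched on PyTorch's current CUDA stream (passed as the fourth launch-configuration argument) rather than the legacy default stream, so the kernels stay ordered with surrounding PyTorch work and respect user stream contexts. A minimal sketch of that pattern follows; the kernel and wrapper names are hypothetical and not from this commit.

// stream_launch_sketch.cu -- illustrative only; scale_kernel / scale_on_current_stream are hypothetical.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

__global__ void scale_kernel(const float *in, float *out, float s, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * s;
}

at::Tensor scale_on_current_stream(at::Tensor inputs, float s) {
  auto outputs = at::empty_like(inputs);
  size_t numel = inputs.numel();
  constexpr int kBlock = 128;
  // Fetch the stream PyTorch is currently using and pass it as the 4th <<<>>> argument,
  // mirroring the change made to the launches in this file.
  auto stream = c10::cuda::getCurrentCUDAStream();
  scale_kernel<<<(numel + kBlock - 1) / kBlock, kBlock, 0, stream>>>(
      inputs.data_ptr<float>(), outputs.data_ptr<float>(), s, numel);
  return outputs;
}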

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

Lines changed: 33 additions & 53 deletions
@@ -19,6 +19,7 @@
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
 #include <torch/extension.h>
+#include <optional>
 
 #define BLOCK_SIZE 128
 
@@ -31,92 +32,71 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
 
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }
 
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //    inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //    inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //    inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax_opt) {
   size_t numel = inputs.numel();
+  at::Tensor scale;
+  if (amax_opt.has_value()) {
+    scale = 448.f / amax_opt.value();
+  } else {
+    scale = at::ones({1}, inputs.options().dtype(at::kFloat));
+  }
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
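
Note on the per-axis kernel: each element selects its scaling factor with axis_id = (idx / outer_size) % axis_size, where axis_size = inputs.size(axis) and outer_size = inputs.stride(axis). This mapping assumes a contiguous input, which the binding guarantees by calling .contiguous(). Below is a small host-only sketch (not part of the commit) that checks the mapping against explicit row-major coordinates for a fixed example shape.

// axis_index_sketch.cpp -- illustrative check of the flat-index -> axis-index mapping
// used by fake_e4m3fy_with_axis_cuda_kernel, for a contiguous [2, 3, 4] tensor with axis = 1.
#include <cassert>
#include <cstddef>

int main() {
  const int shape[3] = {2, 3, 4};
  const int axis_size = shape[1];   // inputs.size(1)
  const int outer_size = shape[2];  // inputs.stride(1) == 4 for a contiguous [2, 3, 4] tensor

  for (int i0 = 0; i0 < shape[0]; ++i0) {
    for (int i1 = 0; i1 < shape[1]; ++i1) {
      for (int i2 = 0; i2 < shape[2]; ++i2) {
        // Row-major flat index of element (i0, i1, i2).
        size_t idx = (static_cast<size_t>(i0) * shape[1] + i1) * shape[2] + i2;
        // The kernel's mapping from flat index to the per-axis scale slot.
        int axis_id = static_cast<int>((idx / outer_size) % axis_size);
        assert(axis_id == i1);  // recovers the coordinate along the quantization axis
      }
    }
  }
  return 0;
}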
