
Commit 83d7623

Fixed FP8 fake quantization to use fp32 amax scaling; Added support for FP8 per-channel quantization (#381)
Signed-off-by: realAsma <[email protected]>
1 parent c55bcf0 commit 83d7623
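
In short, this change makes fake FP8 (E4M3) quantization scale inputs by 448 / amax computed in fp32 before casting to E4M3, multiply back by the inverse scale afterwards, and adds a per-channel (per-axis) variant of that kernel. Below is a minimal PyTorch sketch of the per-tensor path for illustration only; it assumes torch >= 2.1 for torch.float8_e4m3fn and is not the extension's code (the CUDA __nv_fp8_e4m3 cast saturates at +-448, which torch's cast may not match exactly at the boundary):

    import torch

    def fake_e4m3fy_ref(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
        # fp32 scale derived from amax; 448 is the largest finite E4M3 magnitude
        scale = 448.0 / amax.float()
        # round-trip through E4M3, then undo the scaling
        xq = (x.float() * scale).to(torch.float8_e4m3fn)
        return (xq.float() / scale).to(x.dtype)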

9 files changed: +113 -502 lines changed

modelopt/torch/quantization/extensions.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False):
     if not hasattr(get_cuda_ext_fp8, "extension"):
         get_cuda_ext_fp8.extension = load_cpp_extension(  # type:ignore[attr-defined]
             name="modelopt_cuda_ext_fp8",
-            sources=[path / "src/tensor_quant_fp8.cpp", path / "src/tensor_quant_gpu_fp8.cu"],
+            sources=[path / "src/tensor_quant_gpu_fp8.cu"],
             cuda_version_specifiers=">=11.8",
             fail_msg=(
                 "CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated"

modelopt/torch/quantization/src/tensor_quant.h

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool);
 float bits_to_bound(int, int);
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
 
 // Dequantizes data using NF4 quantization scheme and per-block scaling factors.
 //

modelopt/torch/quantization/src/tensor_quant_fp8.cpp

Lines changed: 0 additions & 48 deletions
This file was deleted.

modelopt/torch/quantization/src/tensor_quant_gpu.cu

Lines changed: 8 additions & 3 deletions
@@ -16,6 +16,7 @@
  */
 
 #include <ATen/ATen.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
@@ -74,8 +75,9 @@ __global__ void fake_tensor_quant_kernel(const T *inputs, size_t n, T *outputs,
 void fake_tensor_quant_cuda_inplace(at::Tensor inputs, at::Tensor amax, int num_bits = 8,
                                     bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_inplace", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, inputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -85,8 +87,9 @@ at::Tensor fake_tensor_quant_cuda(at::Tensor inputs, at::Tensor amax, int num_bi
                                   bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
   auto outputs = torch::empty_like(inputs);
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -125,8 +128,10 @@ at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor inputs, at::Tensor amax,
 
   int outer_size = inputs.stride(axis);
 
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_with_axis", [&] {
-    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0,
+                                              stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), axis_size, outer_size, num_bits,
         is_unsigned, narrow_range);
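
The kernels in this file also pick up a stream fix: every launch now targets c10::cuda::getCurrentCUDAStream() rather than the legacy default stream, so the quantization kernels are ordered with whatever stream PyTorch considers current. A small illustration of the Python-side semantics this enables, using only public torch APIs (the extension call itself is omitted because its Python binding is not part of this diff):

    import torch

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        # Ops (and, after this change, the extension's kernels) issued here are
        # enqueued on `s` rather than the default stream, so they stay ordered
        # with the rest of the work queued on `s`.
        x = torch.randn(1 << 20, device="cuda")
        y = x * 2.0
    # order the default stream after the side stream before consuming `y`
    torch.cuda.current_stream().wait_stream(s)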

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

Lines changed: 40 additions & 54 deletions
@@ -18,6 +18,7 @@
 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
+#include <optional>
 #include <torch/extension.h>
 
 #define BLOCK_SIZE 128
@@ -31,92 +32,77 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
 
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }
 
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous().to(at::kFloat);
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //   inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //   inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //   inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
+  amax = amax.view(-1).to(at::kFloat);
   size_t numel = inputs.numel();
+  at::Tensor scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
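
The new PYBIND11_MODULE block is what exposes the per-tensor and per-channel entry points to Python. A hedged usage sketch, assuming the extension built successfully and that get_cuda_ext_fp8() from extensions.py (shown above) returns the loaded module; for the per-channel call, amax is expected to hold one value per slice of inputs along axis:

    import torch
    from modelopt.torch.quantization.extensions import get_cuda_ext_fp8

    ext = get_cuda_ext_fp8()
    x = torch.randn(128, 64, device="cuda")

    # per-tensor: a single amax for the whole tensor
    y = ext.fake_e4m3fy(x, x.abs().amax())

    # per-channel along axis 0: one amax per row (shape [128])
    y_pc = ext.fake_e4m3fy_with_axis(x, x.abs().amax(dim=1), axis=0)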
