
Commit 97dc2ef

Use fp8_eager for cpu or if amax is None
1 parent ddbd66d commit 97dc2ef

File tree

4 files changed (+39, -81 lines):

modelopt/torch/quantization/extensions.py
modelopt/torch/quantization/src/tensor_quant_fp8.cpp (deleted)
modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
modelopt/torch/quantization/tensor_quant.py


modelopt/torch/quantization/extensions.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False):
     if not hasattr(get_cuda_ext_fp8, "extension"):
         get_cuda_ext_fp8.extension = load_cpp_extension(  # type:ignore[attr-defined]
             name="modelopt_cuda_ext_fp8",
-            sources=[path / "src/tensor_quant_fp8.cpp", path / "src/tensor_quant_gpu_fp8.cu"],
+            sources=[path / "src/tensor_quant_gpu_fp8.cu"],
             cuda_version_specifiers=">=11.8",
             fail_msg=(
                 "CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated"

modelopt/torch/quantization/src/tensor_quant_fp8.cpp

Lines changed: 0 additions & 60 deletions
This file was deleted.

modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu

Lines changed: 16 additions & 11 deletions
@@ -18,8 +18,8 @@
 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
-#include <torch/extension.h>
 #include <optional>
+#include <torch/extension.h>
 
 #define BLOCK_SIZE 128
 
@@ -63,7 +63,9 @@ __global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, con
   }
 }
 
-at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous();
   auto outputs = torch::empty_like(inputs);
   size_t numel = inputs.numel();
   int axis_size = inputs.size(axis);
@@ -73,7 +75,7 @@ at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int ax
   auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
     fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         axis_size, outer_size, outputs.data_ptr<scalar_t>());
@@ -82,21 +84,24 @@ at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int ax
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax_opt) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
   size_t numel = inputs.numel();
-  at::Tensor scale;
-  if (amax_opt.has_value()) {
-    scale = 448.f / amax_opt.value();
-  } else {
-    scale = at::ones({1}, inputs.options().dtype(at::kFloat));
-  }
+  at::Tensor scale = 448.f / amax;
   auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
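For readers without a CUDA toolchain, the following is roughly the math these two kernels implement, sketched with plain torch ops. The names fake_e4m3fy_reference and fake_e4m3fy_with_axis_reference are hypothetical reference helpers, not part of the extension; the real kernels do the same work in a single fused elementwise pass on the GPU, and 448 is the largest normal E4M3 value.

import torch


def fake_e4m3fy_reference(inputs: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Per-tensor fake quantization: scale into E4M3 range, round-trip, scale back."""
    scale = 448.0 / amax.to(torch.float32)
    x = (inputs.to(torch.float32) * scale).to(torch.float8_e4m3fn)
    return (x.to(torch.float32) / scale).to(inputs.dtype)


def fake_e4m3fy_with_axis_reference(inputs: torch.Tensor, amax: torch.Tensor, axis: int) -> torch.Tensor:
    """Per-channel variant: broadcast one amax value per slice along `axis`."""
    shape = [1] * inputs.dim()
    shape[axis] = -1
    return fake_e4m3fy_reference(inputs, amax.reshape(shape))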

modelopt/torch/quantization/tensor_quant.py

Lines changed: 22 additions & 9 deletions
@@ -41,24 +41,32 @@
 
 DISABLE_TRITON_KERNEL = False
 
-def _fp8_eager(x, amax):
+
+def _fp8_eager(x, amax=None):
     dtype = x.dtype
-    x = x.to(torch.float32)
-    scale = 448.0 / (amax.to(torch.float32))
-    scale_inv = 1 / scale
-    x = (x*scale).to(torch.float8_e4m3fn).to(torch.float32)*scale_inv
+    if amax is not None:
+        scale = 448.0 / (amax.to(torch.float32))
+        scale_inv = 1 / scale
+        x = x.to(torch.float32) * scale
+    x = x.to(torch.float8_e4m3fn)
+    if amax is not None:
+        x = x.to(torch.float32) * scale_inv
     return x.to(dtype)
 
+
 @torch.compile(dynamic=True)
 def _fp8_triton(x, amax):
     return _fp8_eager(x, amax)
 
+
 def fp8_eager(x, amax):
+    """Eager mode implementation of FP8 quantization."""
     if triton_kernel.IS_AVAILABLE and not DISABLE_TRITON_KERNEL:
         return _fp8_triton(x, amax)
     else:
         return _fp8_eager(x, amax)
 
+
 def scaled_e4m3_impl(
     inputs: torch.Tensor,
     amax: torch.Tensor | None = None,
@@ -72,16 +80,21 @@ def scaled_e4m3_impl(
     Returns:
         Input tensors faked quantized to FP8.
     """
-    cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=True)
+    if inputs.is_cpu:
+        return fp8_eager(inputs, amax)
+
+    cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
+    if cuda_ext_fp8 is None or amax is None:
+        return fp8_eager(inputs, amax)
 
     with torch.cuda.device(
         None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
     ):
-        if amax is None or amax.numel() == 1:
+        if amax.numel() == 1:
             outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
         elif amax.squeeze().ndim == 1:
-            axis = amax.shape.index(amax.numel())
-            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
+            axis = amax.shape.index(amax.numel())
+            outputs = cuda_ext_fp8.fake_e4m3fy_with_axis(inputs, amax.squeeze(), axis)
         else:
             outputs = fp8_eager(inputs, amax)
     return outputs
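With this change, scaled_e4m3_impl no longer requires the CUDA extension: CPU tensors, a missing amax, or a failed extension build all route through fp8_eager. A rough usage sketch is shown below; the tensor shapes and amax values are made up for illustration, and it assumes the remaining keyword arguments of scaled_e4m3_impl keep their defaults.

import torch
from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl

x_cpu = torch.randn(4, 16)                         # CPU input -> fp8_eager path
y_cpu = scaled_e4m3_impl(x_cpu, amax=x_cpu.abs().amax())

y_noamax = scaled_e4m3_impl(x_cpu, amax=None)      # no amax -> plain cast to E4M3 and back

if torch.cuda.is_available():
    x_gpu = torch.randn(4, 16, device="cuda")
    amax = x_gpu.abs().amax(dim=0, keepdim=True)   # shape (1, 16) -> per-axis CUDA kernel
    y_gpu = scaled_e4m3_impl(x_gpu, amax=amax)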
