2 changes: 1 addition & 1 deletion modelopt/torch/quantization/extensions.py
@@ -41,7 +41,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False):
if not hasattr(get_cuda_ext_fp8, "extension"):
get_cuda_ext_fp8.extension = load_cpp_extension( # type:ignore[attr-defined]
name="modelopt_cuda_ext_fp8",
sources=[path / "src/tensor_quant_fp8.cpp", path / "src/tensor_quant_gpu_fp8.cu"],
sources=[path / "src/tensor_quant_gpu_fp8.cu"],
cuda_version_specifiers=">=11.8",
fail_msg=(
"CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated"
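The FP8 extension now builds from the single CUDA source, presumably because the PYBIND11_MODULE definition lives in tensor_quant_gpu_fp8.cu after this change (see the bindings added at the bottom of this diff). As a hedged sketch of the loader idiom this hunk belongs to — assuming load_cpp_extension is modelopt's wrapper around torch.utils.cpp_extension.load and that the elided tail of the function simply returns the cached module — the compiled extension is stored as an attribute on the loader function so the JIT build runs at most once per process:

def get_cuda_ext_fp8(raise_if_failed: bool = False):
    # Sketch only: the fail_msg / raise_if_failed handling from the real loader is elided.
    if not hasattr(get_cuda_ext_fp8, "extension"):
        get_cuda_ext_fp8.extension = load_cpp_extension(
            name="modelopt_cuda_ext_fp8",
            # The former tensor_quant_fp8.cpp is deleted in this PR.
            sources=[path / "src/tensor_quant_gpu_fp8.cu"],
            cuda_version_specifiers=">=11.8",
        )
    return get_cuda_ext_fp8.extension  # assumed return; not shown in the hunk
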
1 change: 0 additions & 1 deletion modelopt/torch/quantization/src/tensor_quant.h
@@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool);
at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool);
at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool);
float bits_to_bound(int, int);
- at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);

// Dequantizes data using NF4 quantization scheme and per-block scaling factors.
//
48 changes: 0 additions & 48 deletions modelopt/torch/quantization/src/tensor_quant_fp8.cpp

This file was deleted.

11 changes: 8 additions & 3 deletions modelopt/torch/quantization/src/tensor_quant_gpu.cu
@@ -16,6 +16,7 @@
*/

#include <ATen/ATen.h>
+ #include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
@@ -74,8 +75,9 @@ __global__ void fake_tensor_quant_kernel(const T *inputs, size_t n, T *outputs,
void fake_tensor_quant_cuda_inplace(at::Tensor inputs, at::Tensor amax, int num_bits = 8,
bool is_unsigned = false, bool narrow_range = true) {
size_t numel = inputs.numel();
+ auto stream = c10::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_inplace", [&] {
- fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+ fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
inputs.data_ptr<scalar_t>(), numel, inputs.data_ptr<scalar_t>(),
amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
});
@@ -85,8 +87,9 @@ at::Tensor fake_tensor_quant_cuda(at::Tensor inputs, at::Tensor amax, int num_bi
bool is_unsigned = false, bool narrow_range = true) {
size_t numel = inputs.numel();
auto outputs = torch::empty_like(inputs);
+ auto stream = c10::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda", [&] {
- fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+ fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
});
@@ -125,8 +128,10 @@ at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor inputs, at::Tensor amax,

int outer_size = inputs.stride(axis);

+ auto stream = c10::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_with_axis", [&] {
- fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+ fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0,
+ stream>>>(
inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
amax.to(at::ScalarType::Float).data_ptr<float>(), axis_size, outer_size, num_bits,
is_unsigned, narrow_range);
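With these changes the quantization kernels are enqueued on c10::cuda::getCurrentCUDAStream() instead of the legacy default stream, so they stay ordered with surrounding PyTorch ops when the caller runs on a side stream (or captures a CUDA graph). A rough usage sketch under stated assumptions: get_cuda_ext() is the loader for the non-FP8 extension (analogous to get_cuda_ext_fp8 above, not shown in this diff), and fake_tensor_quant is its Python binding for fake_tensor_quant_cuda with the (inputs, amax, num_bits, is_unsigned, narrow_range) signature from tensor_quant.h:

import torch
from modelopt.torch.quantization.extensions import get_cuda_ext  # assumed loader name

cuda_ext = get_cuda_ext()

x = torch.randn(1 << 20, device="cuda")
amax = x.abs().amax()

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    y = torch.relu(x)                                        # enqueued on `side`
    q = cuda_ext.fake_tensor_quant(y, amax, 8, False, True)  # with this change, also on `side`
torch.cuda.current_stream().wait_stream(side)                # order later work after `side`
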
94 changes: 40 additions & 54 deletions modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
@@ -18,6 +18,7 @@
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_fp8.h>
+ #include <optional>
#include <torch/extension.h>

#define BLOCK_SIZE 128
@@ -31,92 +32,77 @@
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

- template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+ template <typename T>
+ __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+ const float *inv_scale, T *outputs) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
outputs[idx] = static_cast<T>(
- static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+ static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+ inv_scale[0]);
}
}

template <typename T>
- __global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
- bool per_block_scaling_factor, size_t blocksize,
- float zero_threshold, T *outputs) {
+ __global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+ const float *inv_scale, int axis_size,
+ int outer_size, T *outputs) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
float x = static_cast<float>(inputs[idx]);

- // generate mask for zeroing tiny values
- float x_abs = fabsf(x);
- bool zero_mask = x_abs < zero_threshold;
-
- // grab the global scaling factor
- size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
- // compute scale and inverse-scales
- float scale = 448.f / (amax[amax_idx]);
- float inv_scale = 1.f / scale;
+ int axis_id = (idx / outer_size) % axis_size;

- // compute the output
- float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
- // zero out small values
- if (zero_mask) {
- output = 0.f;
- }
+ float output =
+ static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];

outputs[idx] = output;
}
}

- at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
- size_t numel = inputs.numel();
+ at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+ inputs = inputs.contiguous();
+ amax = amax.contiguous().to(at::kFloat);
auto outputs = torch::empty_like(inputs);
+ size_t numel = inputs.numel();
+ int axis_size = inputs.size(axis);
+ int outer_size = inputs.stride(axis);

- bool per_block_scaling_factor = false;
- size_t blocksize = numel;
-
- int amax_ndim = amax.dim();
- int input_ndim = inputs.dim();
-
- // 3 options:
- // 1.
- // inputs[numel], amax[1] -> per-tensor scaling
- // 2.
- // inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
- // 3.
- // inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
- if (amax.numel() == 1) {
- // case 1.
- per_block_scaling_factor = false;
- } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
- // case 2.
- per_block_scaling_factor = true;
- blocksize = numel / amax.numel();
- } else {
- throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
- }
+ auto scale = 448.f / amax;
+ auto inv_scale = 1.f / scale;

auto stream = c10::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
axis_size, outer_size, outputs.data_ptr<scalar_t>());
});

return outputs;
}

- at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+ at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+ inputs = inputs.contiguous();
+ amax = amax.view(-1).to(at::kFloat);
size_t numel = inputs.numel();
+ at::Tensor scale = 448.f / amax;
+ auto inv_scale = 1.f / scale;
auto outputs = torch::empty_like(inputs);
auto stream = c10::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
- inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+ inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+ outputs.data_ptr<scalar_t>());
});
return outputs;
}

+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+ py::arg("amax"));
+ m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+ py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+ }