Merged
Changes from 1 commit
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/extensions.py
@@ -41,7 +41,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False):
     if not hasattr(get_cuda_ext_fp8, "extension"):
         get_cuda_ext_fp8.extension = load_cpp_extension(  # type:ignore[attr-defined]
             name="modelopt_cuda_ext_fp8",
-            sources=[path / "src/tensor_quant_fp8.cpp", path / "src/tensor_quant_gpu_fp8.cu"],
+            sources=[path / "src/tensor_quant_gpu_fp8.cu"],
             cuda_version_specifiers=">=11.8",
             fail_msg=(
                 "CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated"
1 change: 0 additions & 1 deletion modelopt/torch/quantization/src/tensor_quant.h
@@ -22,7 +22,6 @@ void fake_tensor_quant_cuda_inplace(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_cuda(at::Tensor, at::Tensor, int, bool, bool);
 at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor, at::Tensor, int, int, bool, bool);
 float bits_to_bound(int, int);
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs);
 
 // Dequantizes data using NF4 quantization scheme and per-block scaling factors.
 //
48 changes: 0 additions & 48 deletions modelopt/torch/quantization/src/tensor_quant_fp8.cpp

This file was deleted.

11 changes: 8 additions & 3 deletions modelopt/torch/quantization/src/tensor_quant_gpu.cu
@@ -16,6 +16,7 @@
  */
 
 #include <ATen/ATen.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
@@ -74,8 +75,9 @@ __global__ void fake_tensor_quant_kernel(const T *inputs, size_t n, T *outputs,
 void fake_tensor_quant_cuda_inplace(at::Tensor inputs, at::Tensor amax, int num_bits = 8,
                                     bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_inplace", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, inputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -85,8 +87,9 @@ at::Tensor fake_tensor_quant_cuda(at::Tensor inputs, at::Tensor amax, int num_bi
                                   bool is_unsigned = false, bool narrow_range = true) {
   size_t numel = inputs.numel();
   auto outputs = torch::empty_like(inputs);
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda", [&] {
-    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_kernel<<<numel / BLOCK_SIZE + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), num_bits, is_unsigned, narrow_range);
   });
@@ -125,8 +128,10 @@ at::Tensor fake_tensor_quant_with_axis_cuda(at::Tensor inputs, at::Tensor amax,
 
   int outer_size = inputs.stride(axis);
 
+  auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_tensor_quant_cuda_with_axis", [&] {
-    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+    fake_tensor_quant_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0,
+                                              stream>>>(
         inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(),
         amax.to(at::ScalarType::Float).data_ptr<float>(), axis_size, outer_size, num_bits,
         is_unsigned, narrow_range);
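Note on the stream change in this file: the old <<<grid, block>>> launches went to CUDA's legacy default stream, so these kernels were not ordered against the work PyTorch enqueues on its current stream. The new launches pass c10::cuda::getCurrentCUDAStream() as the fourth launch parameter. Below is a minimal, self-contained sketch of that pattern; the kernel and function names are illustrative, not from this repository.

// stream_launch_sketch.cu -- illustrative only; build with nvcc as part of a torch C++/CUDA extension.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#define BLOCK_SIZE 128

// Hypothetical elementwise kernel standing in for fake_tensor_quant_kernel.
template <typename T> __global__ void scale_kernel(const T *in, size_t n, T *out, float s) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<T>(static_cast<float>(in[i]) * s);
}

at::Tensor scale_on_current_stream(at::Tensor inputs, float s) {
  size_t numel = inputs.numel();
  auto outputs = torch::empty_like(inputs);
  // Launch on the stream PyTorch considers current, so the kernel is ordered
  // with the surrounding ATen ops rather than the legacy default stream.
  auto stream = c10::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(inputs.scalar_type(), "scale_on_current_stream", [&] {
    scale_kernel<<<(numel + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>(
        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>(), s);
  });
  return outputs;
}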
94 changes: 40 additions & 54 deletions modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu
@@ -18,6 +18,7 @@
 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
+#include <optional>
 #include <torch/extension.h>
 
 #define BLOCK_SIZE 128
@@ -31,92 +32,77 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
 
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
-    // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }

-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous().to(at::kFloat);
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //   inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //   inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //   inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
 
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
 
   return outputs;
 }

-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
+  amax = amax.view(-1).to(at::kFloat);
   size_t numel = inputs.numel();
+  at::Tensor scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
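For reference, the per-element math the rewritten kernels apply is a scaled round trip through FP8: an input x is multiplied by scale = 448 / amax (448 being the largest finite E4M3 magnitude), rounded by the cast to __nv_fp8_e4m3, and mapped back with inv_scale. A small host-side sketch of that round trip follows; the values are made up for illustration, and it assumes CUDA 11.8+ (so cuda_fp8.h is available) and nvcc compilation.

// fp8_roundtrip_sketch.cu -- illustrative only, not part of this PR.
#include <cstdio>
#include <cuda_fp8.h>

int main() {
  float amax = 0.5f;           // example per-tensor absolute maximum
  float scale = 448.f / amax;  // map [-amax, amax] onto the finite E4M3 range
  float inv_scale = 1.f / scale;

  float x = 0.003f;  // example input element
  // Quantize into E4M3, then dequantize back to float ("fake" quantization).
  float q = static_cast<float>(static_cast<__nv_fp8_e4m3>(x * scale)) * inv_scale;
  std::printf("x = %f  ->  fake-quantized = %f\n", x, q);
  return 0;
}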