@@ -18,6 +18,7 @@
 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
+#include <optional>
 #include <torch/extension.h>
 
 #define BLOCK_SIZE 128
@@ -31,92 +32,77 @@
 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                                               \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }
 
 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);
 
-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;
 
     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];
 
     outputs[idx] = output;
   }
 }
 
-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous().to(at::kFloat);
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);
 
-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  //    inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  //    inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  //    inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
 
   auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }
 
-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
+  amax = amax.view(-1).to(at::kFloat);
   size_t numel = inputs.numel();
+  at::Tensor scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
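
Notes on the change:

The constant 448.f is the largest finite magnitude representable in FP8 E4M3, so scale = 448.f / amax stretches the calibrated range onto the representable range; the cast through __nv_fp8_e4m3 performs the actual rounding, and the multiply by inv_scale maps the result back to the original range, which is what makes the quantization "fake" (simulated). A minimal host-side sketch of the same round trip, assuming the float conversions from cuda_fp8.h are available in host code (the amax value here is made up):

    #include <cstdio>
    #include <cuda_fp8.h>

    int main() {
      float amax = 0.75f;          // hypothetical calibration amax
      float scale = 448.f / amax;  // 448 = largest finite |x| in E4M3
      float inv_scale = 1.f / scale;
      float x = 0.1f;
      // The same per-element round trip the kernels perform.
      float y = static_cast<float>(static_cast<__nv_fp8_e4m3>(x * scale)) * inv_scale;
      std::printf("x = %f -> fake-quantized = %f\n", x, y);
      return 0;
    }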
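
On the per-axis indexing: fake_e4m3fy_with_axis makes inputs contiguous first, so inputs.stride(axis) is the number of elements spanned by one step along axis, and (idx / outer_size) % axis_size recovers the coordinate of flat index idx along that axis, i.e. which scale/inv_scale entry applies. A small host-side check of the formula (the shape is an arbitrary example):

    #include <cassert>
    #include <cstdio>

    int main() {
      // Hypothetical contiguous tensor of shape [2, 3, 4], quantized along axis 1.
      const int shape[3] = {2, 3, 4};
      const int axis_size = shape[1];  // inputs.size(1) == 3
      const int outer_size = shape[2]; // inputs.stride(1) == 4 for a contiguous tensor

      for (int i = 0; i < shape[0]; ++i)
        for (int j = 0; j < shape[1]; ++j)
          for (int k = 0; k < shape[2]; ++k) {
            int idx = (i * shape[1] + j) * shape[2] + k;   // flat index of (i, j, k)
            int axis_id = (idx / outer_size) % axis_size;  // formula from the kernel
            assert(axis_id == j);                          // recovers the axis coordinate
          }
      std::puts("axis_id matches the axis-1 coordinate for every element");
      return 0;
    }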
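
Each thread handles four consecutive elements, which is why both launches use numel / (BLOCK_SIZE * 4) + 1 blocks of BLOCK_SIZE threads, with the idx < n guard covering the tail. For completeness, a hedged sketch of driving the two entry points from C++ (the declarations are assumed to be linked against this translation unit; the shapes and amax choices are illustrative, and the per-axis amax needs one value per slice along axis):

    #include <torch/torch.h>
    #include <iostream>

    // Assumed declarations, matching the definitions above.
    at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax);
    at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis);

    int main() {
      auto x = torch::randn({4, 8}, torch::kCUDA);

      // Per-tensor: a single amax for the whole tensor.
      auto y = fake_e4m3fy(x, x.abs().max());

      // Per-axis: one amax per slice along axis 1 (8 values here).
      auto amax = std::get<0>(x.abs().max(/*dim=*/0));
      auto z = fake_e4m3fy_with_axis(x, amax, /*axis=*/1);

      std::cout << "max |y - x| = " << (y - x).abs().max().item<float>() << ", "
                << "max |z - x| = " << (z - x).abs().max().item<float>() << "\n";
      return 0;
    }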