 #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

-template <typename T> __global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, T *outputs) {
+template <typename T>
+__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
+                                   const float *inv_scale, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     outputs[idx] = static_cast<T>(
-        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]))));
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(static_cast<float>(inputs[idx]) * scale[0])) *
+        inv_scale[0]);
   }
 }

 template <typename T>
-__global__ void fused_fake_e4m3fy_kernel(const T *inputs, size_t n, float *amax,
-                                         bool per_block_scaling_factor, size_t blocksize,
-                                         float zero_threshold, T *outputs) {
+__global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, const float *scale,
+                                                  const float *inv_scale, int axis_size,
+                                                  int outer_size, T *outputs) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;

   for (int idx = 4 * tid; idx < 4 * (tid + 1) && idx < n; ++idx) {
     float x = static_cast<float>(inputs[idx]);

-    // generate mask for zeroing tiny values
-    float x_abs = fabsf(x);
-    bool zero_mask = x_abs < zero_threshold;
-
-    // grab the global scaling factor
-    size_t amax_idx = (per_block_scaling_factor) ? (idx / blocksize) : 0;
-
-    // compute scale and inverse-scales
-    float scale = 448.f / (amax[amax_idx]);
-    float inv_scale = 1.f / scale;
+    int axis_id = (idx / outer_size) % axis_size;

     // compute the output
-    float output = static_cast<float>(static_cast<__nv_fp8_e4m3>(scale * x)) * inv_scale;
-
-    // zero out small values
-    if (zero_mask) {
-      output = 0.f;
-    }
+    float output =
+        static_cast<float>(static_cast<__nv_fp8_e4m3>(scale[axis_id] * x)) * inv_scale[axis_id];

     outputs[idx] = output;
   }
 }

-at::Tensor fused_fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax, const float zero_threshold) {
-  size_t numel = inputs.numel();
+at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
   auto outputs = torch::empty_like(inputs);
+  size_t numel = inputs.numel();
+  int axis_size = inputs.size(axis);
+  int outer_size = inputs.stride(axis);

-  bool per_block_scaling_factor = false;
-  size_t blocksize = numel;
-
-  int amax_ndim = amax.dim();
-  int input_ndim = inputs.dim();
-
-  // 3 options:
-  // 1.
-  // inputs[numel], amax[1] -> per-tensor scaling
-  // 2.
-  // inputs[numel], amax[numel/num_cols] -> per-row / per-channel scaling
-  // 3.
-  // inputs[numel/bs, bs], amax[numel/bs, 1] -> blockwise scaling
-  if (amax.numel() == 1) {
-    // case 1.
-    per_block_scaling_factor = false;
-  } else if (amax.numel() > 1 && (amax_ndim > 1 && (amax.size(-1) == amax.numel()))) {
-    // case 2.
-    per_block_scaling_factor = true;
-    blocksize = numel / amax.numel();
-  } else {
-    throw std::runtime_error("invalid combination of inputs and amax shapes/sizes");
-  }
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;

-  auto stream = c10::cuda::getCurrentCUDAStream();
-
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fused_fake_e4m3fy_cuda", [&] {
-    fused_fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, amax.data_ptr<float>(), per_block_scaling_factor,
-        blocksize, zero_threshold, outputs.data_ptr<scalar_t>());
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+    fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE>>>(
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        axis_size, outer_size, outputs.data_ptr<scalar_t>());
   });
+
   return outputs;
 }

-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs) {
+at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax) {
   size_t numel = inputs.numel();
+  auto scale = 448.f / amax;
+  auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
-        inputs.data_ptr<scalar_t>(), numel, outputs.data_ptr<scalar_t>());
+        inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
+        outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
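
Note on the per-axis indexing: for a contiguous input, fake_e4m3fy_with_axis_cuda_kernel recovers the coordinate along the quantization axis from the flat element index as (idx / outer_size) % axis_size, where outer_size is the stride of axis, and uses it to pick that channel's scale/inv_scale entry (the 448.f in the host code is the largest normal value representable in FP8 E4M3). The standalone host-side sketch below is not part of the commit; the shape and axis are made up purely to check that mapping.

#include <cassert>

int main() {
  // Hypothetical contiguous shape [2, 3, 4] with quantization axis = 1 (made-up values).
  const int dim0 = 2, axis_size = 3, dim2 = 4;
  const int outer_size = dim2; // stride of axis 1 in a contiguous (row-major) layout

  for (int i = 0; i < dim0; ++i)
    for (int j = 0; j < axis_size; ++j)
      for (int k = 0; k < dim2; ++k) {
        int idx = (i * axis_size + j) * dim2 + k;     // flat index of element (i, j, k)
        int axis_id = (idx / outer_size) % axis_size; // same formula as the kernel
        assert(axis_id == j);                         // recovers the coordinate along the axis
      }
  return 0;
}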
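
This hunk only touches the CUDA entry points; the Python-facing bindings are outside the diff. Below is a minimal sketch of how they might be exposed through a standard torch extension module; the module and Python-level names are assumptions, not taken from this commit.

#include <torch/extension.h>

// Declarations matching the entry points defined in the .cu file above.
at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, at::Tensor amax);
at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Per-tensor fake quantization: amax holds a single float element.
  m.def("fake_e4m3fy", &fake_e4m3fy_cuda, "Fake-quantize to FP8 E4M3 (per-tensor amax)");
  // Per-axis fake quantization: amax holds one float element per slice along the given axis.
  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_cuda_with_axis,
        "Fake-quantize to FP8 E4M3 (per-axis amax)");
}

On the Python side, amax would typically be the maximum of abs(inputs) in float32 (per tensor or per axis), so that scale = 448 / amax maps the largest observed magnitude onto the FP8 E4M3 maximum before the round trip through __nv_fp8_e4m3.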