 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <cuda_fp8.h>
-#include <torch/extension.h>
 #include <optional>
+#include <torch/extension.h>

 #define BLOCK_SIZE 128

@@ -63,7 +63,9 @@ __global__ void fake_e4m3fy_with_axis_cuda_kernel(const T *inputs, size_t n, con
   }
 }

-at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) {
+  inputs = inputs.contiguous();
+  amax = amax.contiguous();
   auto outputs = torch::empty_like(inputs);
   size_t numel = inputs.numel();
   int axis_size = inputs.size(axis);
@@ -73,7 +75,7 @@ at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int ax
   auto inv_scale = 1.f / scale;

   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_with_axis", [&] {
     fake_e4m3fy_with_axis_cuda_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         axis_size, outer_size, outputs.data_ptr<scalar_t>());
@@ -82,21 +84,24 @@ at::Tensor fake_e4m3fy_cuda_with_axis(at::Tensor inputs, at::Tensor amax, int ax
   return outputs;
 }

-at::Tensor fake_e4m3fy_cuda(at::Tensor inputs, std::optional<at::Tensor> amax_opt) {
+at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) {
+  inputs = inputs.contiguous();
   size_t numel = inputs.numel();
-  at::Tensor scale;
-  if (amax_opt.has_value()) {
-    scale = 448.f / amax_opt.value();
-  } else {
-    scale = at::ones({1}, inputs.options().dtype(at::kFloat));
-  }
+  at::Tensor scale = 448.f / amax;
   auto inv_scale = 1.f / scale;
   auto outputs = torch::empty_like(inputs);
   auto stream = c10::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy_cuda", [&] {
+  AT_DISPATCH_FLOATING_TYPES(inputs.type().scalarType(), "fake_e4m3fy", [&] {
     fake_e4m3fy_kernel<<<numel / (BLOCK_SIZE * 4) + 1, BLOCK_SIZE, 0, stream>>>(
         inputs.data_ptr<scalar_t>(), numel, scale.data_ptr<float>(), inv_scale.data_ptr<float>(),
         outputs.data_ptr<scalar_t>());
   });
   return outputs;
 }
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("fake_e4m3fy", &fake_e4m3fy, "Reduce precision to E4M3", py::arg("inputs"),
+        py::arg("amax"));
+  m.def("fake_e4m3fy_with_axis", &fake_e4m3fy_with_axis, "Reduce precision to E4M3 (fused)",
+        py::arg("inputs"), py::arg("amax"), py::arg("axis"));
+}
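For context, the elementwise kernel launched by fake_e4m3fy is defined earlier in this file and does not appear in the diff. The sketch below is only an illustration of the scale / FP8 round-trip / unscale pattern, assuming the __nv_cvt_float_to_fp8 and __nv_cvt_fp8_to_halfraw conversions from cuda_fp8.h and the four-elements-per-thread launch implied by numel / (BLOCK_SIZE * 4) + 1; it is not the actual implementation.

// Sketch only: hypothetical body for the kernel referenced by the launch above.
template <typename T>
__global__ void fake_e4m3fy_kernel(const T *inputs, size_t n, const float *scale,
                                   const float *inv_scale, T *outputs) {
  // The grid is sized numel / (BLOCK_SIZE * 4) + 1, so each thread covers 4 elements.
  size_t idx = 4 * (blockIdx.x * (size_t)blockDim.x + threadIdx.x);
  for (int i = 0; i < 4 && idx + i < n; ++i) {
    // Scale into E4M3 range, round-trip through FP8 storage, then undo the scaling.
    float x = static_cast<float>(inputs[idx + i]) * scale[0];
    __nv_fp8_storage_t q = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
    float y = __half2float(__nv_cvt_fp8_to_halfraw(q, __NV_E4M3));
    outputs[idx + i] = static_cast<T>(y * inv_scale[0]);
  }
}

The 448.f constant in the scale computation is the largest finite E4M3 value, so scale = 448.f / amax maps the observed absolute maximum onto the top of the representable FP8 range before the round-trip, and inv_scale restores the original magnitude afterwards.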