|
| 1 | +#include <torch/extension.h> |
| 2 | +#include <ATen/cuda/Exceptions.h> |
| 3 | +#include <ATen/cuda/CUDAContext.h> |
| 4 | + |
| 5 | +#include <cmath> |
| 6 | + |
| 7 | +#include "common/micros.h" |
| 8 | +#include "utils/vec_copy.h" |
| 9 | +#include "funcs/cast_functor.h" |
| 10 | + |
| 11 | + |
| 12 | +using colossalAI::cuda::utils::copy; |
| 13 | +using colossalAI::cuda::utils::get_vec_size; |
| 14 | +using colossalAI::funcs::CastFunctor; |
| 15 | + |
| 16 | +template <typename InT, typename OutT, int VecSize> |
| 17 | +__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail) |
| 18 | +{ |
| 19 | + int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x); |
| 20 | + const int64_t grid_size = blockDim.x * gridDim.x; |
| 21 | + if(idx > numel + tail) { |
| 22 | + return; |
| 23 | + } |
| 24 | + |
| 25 | + for(int64_t i = idx; i < numel; i += grid_size) { |
| 26 | + copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize); |
| 27 | + } |
| 28 | + // Tail process |
| 29 | + if(threadIdx.x == 0) |
| 30 | + { |
| 31 | + for(int i = 0; i < tail; ++i) |
| 32 | + { |
| 33 | + outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]); |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +template <typename InT, typename OutT> |
| 39 | +void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output) |
| 40 | +{ |
| 41 | + const int kVecSize = get_vec_size<InT>(input); |
| 42 | + const int kNumel = torch::numel(input); |
| 43 | + |
| 44 | + const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize))); |
| 45 | + const int kTail = kNumel & (kVecSize - 1); |
| 46 | + int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1; |
| 47 | + |
| 48 | + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 49 | + |
| 50 | + dim3 grid(grid_size); |
| 51 | + dim3 block(256); |
| 52 | + |
| 53 | +#define _(VEC_SIZE) \ |
| 54 | + convert_fp8_kernel<InT, OutT, VEC_SIZE> \ |
| 55 | + <<<grid, block, 0, stream>>> \ |
| 56 | + (reinterpret_cast<const InT*>(input.data_ptr()), \ |
| 57 | + reinterpret_cast<OutT*>(output.data_ptr()), \ |
| 58 | + kVecNumel, \ |
| 59 | + kTail) |
| 60 | + |
| 61 | + switch (kVecSize) |
| 62 | + { |
| 63 | + case 1: |
| 64 | + _(1); |
| 65 | + break; |
| 66 | + case 2: |
| 67 | + _(2); |
| 68 | + break; |
| 69 | + case 4: |
| 70 | + _(4); |
| 71 | + break; |
| 72 | + } |
| 73 | +#undef _ |
| 74 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 75 | +} |
| 76 | + |
| 77 | +void convert_fp8(torch::Tensor& input, torch::Tensor& output) |
| 78 | +{ |
| 79 | + TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte, "Data type of Input or Output should be torch.uint8 for convert_fp8!"); |
| 80 | + TORCH_CHECK(input.scalar_type() != output.scalar_type(), "Data type of input and output are the same!"); |
| 81 | + TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || |
| 82 | + input.scalar_type() == at::ScalarType::Float || |
| 83 | + input.scalar_type() == at::ScalarType::Half || |
| 84 | + input.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of input!"); |
| 85 | + TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte || |
| 86 | + output.scalar_type() == at::ScalarType::Float || |
| 87 | + output.scalar_type() == at::ScalarType::Half || |
| 88 | + output.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of output!"); |
| 89 | + TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!"); |
| 90 | + |
| 91 | +#define _(InT, OutT) \ |
| 92 | + apply_convert_fp8<InT, OutT>(input, output) |
| 93 | + |
| 94 | + |
| 95 | + if(input.scalar_type() == at::ScalarType::Byte) |
| 96 | + { |
| 97 | + if(output.scalar_type() == at::ScalarType::Float) |
| 98 | + { |
| 99 | + _(uint8_t, float); |
| 100 | + } |
| 101 | + else if(output.scalar_type() == at::ScalarType::Half) |
| 102 | + { |
| 103 | + _(uint8_t, half); |
| 104 | + } |
| 105 | + else if(output.scalar_type() == at::ScalarType::BFloat16) |
| 106 | + { |
| 107 | + _(uint8_t, __nv_bfloat16); |
| 108 | + } |
| 109 | + } |
| 110 | + else |
| 111 | + { |
| 112 | + if(input.scalar_type() == at::ScalarType::Float) |
| 113 | + { |
| 114 | + _(float, uint8_t); |
| 115 | + } |
| 116 | + else if(input.scalar_type() == at::ScalarType::Half) |
| 117 | + { |
| 118 | + _(half, uint8_t); |
| 119 | + } |
| 120 | + else if(input.scalar_type() == at::ScalarType::BFloat16) |
| 121 | + { |
| 122 | + _(__nv_bfloat16, uint8_t); |
| 123 | + } |
| 124 | + } |
| 125 | + |
| 126 | +#undef _ |
| 127 | +} |
0 commit comments