
Commit 50104ab

[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)

* add convert_fp8 op for fp8 test in the future
* rerun ci

1 parent bfad393 · commit 50104ab

File tree: 5 files changed, +197 −10 lines
extensions/csrc/kernel/cuda/convert_fp8_kernel.cu

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"


using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
{
  int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = blockDim.x * gridDim.x;
  if (idx > numel + tail) {
    return;
  }

  for (int64_t i = idx; i < numel; i += grid_size) {
    copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
  }
  // Tail process
  if (threadIdx.x == 0)
  {
    for (int i = 0; i < tail; ++i)
    {
      outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
    }
  }
}

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  const int kVecSize = get_vec_size<InT>(input);
  const int kNumel = torch::numel(input);

  const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
  const int kTail = kNumel & (kVecSize - 1);
  int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(grid_size);
  dim3 block(256);

#define _(VEC_SIZE)                                          \
  convert_fp8_kernel<InT, OutT, VEC_SIZE>                    \
      <<<grid, block, 0, stream>>>                           \
      (reinterpret_cast<const InT*>(input.data_ptr()),       \
       reinterpret_cast<OutT*>(output.data_ptr()),           \
       kVecNumel,                                            \
       kTail)

  switch (kVecSize)
  {
    case 1:
      _(1);
      break;
    case 2:
      _(2);
      break;
    case 4:
      _(4);
      break;
  }
#undef _
  AT_CUDA_CHECK(cudaGetLastError());
}

void convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte,
              "Data type of Input or Output should be torch.uint8 for convert_fp8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(),
              "Data type of input and output are the same!");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  input.scalar_type() == at::ScalarType::Float ||
                  input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of input!");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Float ||
                  output.scalar_type() == at::ScalarType::Half ||
                  output.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of output!");
  TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!");

#define _(InT, OutT) apply_convert_fp8<InT, OutT>(input, output)

  if (input.scalar_type() == at::ScalarType::Byte)
  {
    if (output.scalar_type() == at::ScalarType::Float)
    {
      _(uint8_t, float);
    }
    else if (output.scalar_type() == at::ScalarType::Half)
    {
      _(uint8_t, half);
    }
    else if (output.scalar_type() == at::ScalarType::BFloat16)
    {
      _(uint8_t, __nv_bfloat16);
    }
  }
  else
  {
    if (input.scalar_type() == at::ScalarType::Float)
    {
      _(float, uint8_t);
    }
    else if (input.scalar_type() == at::ScalarType::Half)
    {
      _(half, uint8_t);
    }
    else if (input.scalar_type() == at::ScalarType::BFloat16)
    {
      _(__nv_bfloat16, uint8_t);
    }
  }

#undef _
}
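The launcher treats the flat tensor as kVecNumel vectors of kVecSize elements plus a scalar tail of kTail leftover elements: each thread grid-strides over the vector portion via copy<InT, OutT, VecSize>, while threads with threadIdx.x == 0 cast the tail one element at a time with CastFunctor. A plain-Python sketch of that split (illustration only; split_work and vec_size are hypothetical names, and vec_size is one of 1, 2 or 4 as dispatched in the switch above):

# Sketch of the work split computed in apply_convert_fp8 (not part of the commit).
def split_work(numel: int, vec_size: int):
    vec_numel = numel >> (vec_size.bit_length() - 1)  # numel / vec_size (vec_size is a power of two)
    tail = numel & (vec_size - 1)                      # leftover elements handled one by one
    grid_size = (vec_numel + 255) // 256 if vec_numel else 1
    return vec_numel, tail, grid_size

# Example: 1_000_003 elements copied 4 at a time leave a 3-element tail.
print(split_work(1_000_003, 4))  # (250000, 3, 977)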

extensions/csrc/kernel/cuda/utils/vec_copy.h

Lines changed: 7 additions & 10 deletions
@@ -1,9 +1,6 @@
 
 #pragma once
 
-#include <cuda_fp16.h>
-#include <stdint.h>
-
 #include "common/vec_type_traits.h"
 #include "funcs/cast_functor.h"
 
@@ -12,9 +9,9 @@ namespace cuda {
 namespace utils {
 
 // Note(LiuYang): Depreciated
-template <typename T, int vec_size>
+template <typename T, int VecSize>
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
 
@@ -34,17 +31,17 @@ __device__ __inline__ void copy_zero_vector(T *dst) {
   *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
 }
 
-template <typename SrcT, typename DstT, int vec_size>
+template <typename SrcT, typename DstT, int VecSize>
 __device__ __inline__ void copy(const SrcT *src, DstT *dst) {
-  using SrcVT = typename common::VecTypeTrait<SrcT, vec_size>::Type;
-  using DstVT = typename common::VecTypeTrait<DstT, vec_size>::Type;
+  using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
+  using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
   *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
       *(reinterpret_cast<const SrcVT *>(src)));
 }
 
-template <typename T, int vec_size>
+template <typename T, int VecSize>
 __device__ __inline__ void copy(const T *src, T *dst) {
-  using VT = typename common::VecTypeTrait<T, vec_size>::Type;
+  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
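The new kernel builds on the copy<SrcT, DstT, VecSize> helper above (whose template parameter is renamed here from vec_size to VecSize): one call loads VecSize contiguous source elements as a vector, casts them through CastFunctor, and stores them at the same offset in the destination. A rough NumPy analogue of that per-call behaviour (illustration only; copy_vec is a hypothetical name, and float16 stands in for the fp8 target since NumPy has no fp8 dtype):

import numpy as np

# Rough analogue of copy<SrcT, DstT, VecSize>: vectorized load, element-wise cast, store.
def copy_vec(src: np.ndarray, dst: np.ndarray, offset: int, vec_size: int) -> None:
    dst[offset:offset + vec_size] = src[offset:offset + vec_size].astype(dst.dtype)

src = np.arange(16, dtype=np.float32)
dst = np.zeros(16, dtype=np.float16)
copy_vec(src, dst, 8, 4)  # casts elements 8..11 into the float16 buffer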

extensions/pybind/inference/inference.cpp

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,8 @@ void flash_decoding_attention(
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
     const c10::optional<torch::Tensor>& alibi_slopes, float scale);
 
+void convert_fp8(torch::Tensor& input, torch::Tensor& output);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");
@@ -102,4 +104,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("flash_decoding_attention", &flash_decoding_attention,
         "Compute the attention between an input query and the cached "
         "keys/values using PagedAttention.");
+
+  m.def("convert_fp8", &convert_fp8,
+        "Convert input to fp8 output or convert fp8 input to output.");
 }
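With the declaration and m.def registration in place, convert_fp8 is callable from Python through the inference extension: it writes into a preallocated output tensor, converting to fp8 when the output dtype is torch.uint8 and back from fp8 otherwise. A minimal usage sketch, assuming the extension is built with this commit and a CUDA device is available; the tensor shape is arbitrary and the round trip mirrors the new test further down:

import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

x = torch.empty(4, 8, 64, 16, dtype=torch.half, device="cuda").uniform_(-224.0, 224.0)

# half -> fp8: the uint8 tensor receives the raw fp8 bit patterns.
x_fp8 = torch.empty_like(x, dtype=torch.uint8)
inference_ops.convert_fp8(x, x_fp8)

# fp8 -> half: convert back into a preallocated output of the original dtype.
x_back = torch.empty_like(x)
inference_ops.convert_fp8(x_fp8, x_back)

# Round-trip error; the (currently skipped) test checks allclose with atol=1e-3, rtol=0.1.
print((x - x_back).abs().max())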

extensions/pybind/inference/inference_ops_cuda.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ def sources_files(self):
                 "kernel/cuda/rms_layernorm_kernel.cu",
                 "kernel/cuda/get_cos_and_sin_kernel.cu",
                 "kernel/cuda/flash_decoding_attention_kernel.cu",
+                "kernel/cuda/convert_fp8_kernel.cu",
             ]
         ] + [self.pybind_abs_path("inference/inference.cpp")]
         return ret
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, now we skip it's relative test!")
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
@pytest.mark.parametrize("block_size", [8, 16, 32])
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = get_current_device()

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    inference_ops.convert_fp8(cache, cache_fp8)

    converted_cache = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
    test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)
