
Commit 176244a

Merge pull request #642 from eliotwang/ck_tile_gemm

support ck-tile blockquant gemm in vllm

2 parents: caea443 + 25f843d

File tree

21 files changed: +491 -171 lines changed

benchmarks/P3L.py

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 Patch-Perplexity (P3L)
 

benchmarks/P3L_mling.py

Lines changed: 5 additions & 4 deletions

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 *MULTILINGUAL* Patch-Perplexity (P3L)
 

@@ -91,19 +92,19 @@ def get_wikitext2_text(tokenizer):
     return test_enc, test_text
 
 
-def get_flores_plus_text(tokenizer, lng_scrpt):
+def get_flores_plus_text(tokenizer, lng_script):
     hf_hub_download(
         repo_id="alexei-v-ivanov-amd/flores_plus",
         repo_type="dataset",
-        filename=lng_scrpt + ".parquet",
+        filename=lng_script + ".parquet",
         local_dir="./",
     )
 
-    df = pandas.read_parquet("./" + lng_scrpt + ".parquet")
+    df = pandas.read_parquet("./" + lng_script + ".parquet")
     test_text = "\n\n".join(line.strip() for line in df["text"])
     test_enc = tokenizer(test_text)
 
-    os.remove("./" + lng_scrpt + ".parquet")
+    os.remove("./" + lng_script + ".parquet")
 
     return test_enc, test_text
 

benchmarks/profiling/benchmark_latency.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Benchmark the latency of processing a single batch of requests."""
 
 import argparse

benchmarks/profiling/benchmark_throughput.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Benchmark offline inference throughput."""
 
 import argparse

csrc/layernorm_kernels.cu

Lines changed: 8 additions & 7 deletions

@@ -51,8 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ output,       // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ output,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,     // [..., hidden_size]
     const int64_t input_stride,
     scalar_t* __restrict__ residual_out,    // [..., hidden_size]
     const scalar_t* __restrict__ residual,  // [..., hidden_size]

@@ -114,8 +114,8 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ output,       // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ output,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,     // [..., hidden_size]
     const int64_t input_stride,
     scalar_t* __restrict__ residual_out,    // [..., hidden_size]
     const scalar_t* __restrict__ residual,  // [..., hidden_size]

@@ -221,9 +221,10 @@ void fused_add_rms_norm(torch::Tensor& out,  // [..., hidden_size]
   constexpr int req_alignment_bytes =
       vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
                          // falls back to non-vectorized version anyway)
-  bool ptrs_are_aligned = out_ptr % 16 == 0 && inp_ptr % req_alignment_bytes == 0 &&
-                          res_out_ptr % 16 == 0 && res_ptr % req_alignment_bytes == 0 &&
-                          wt_ptr % req_alignment_bytes == 0;
+  bool ptrs_are_aligned =
+      out_ptr % 16 == 0 && inp_ptr % req_alignment_bytes == 0 &&
+      res_out_ptr % 16 == 0 && res_ptr % req_alignment_bytes == 0 &&
+      wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
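
For context, a minimal standalone sketch (hypothetical helper, not code from this commit) of the gating pattern the hunk above reformats: the raw pointer addresses and the strides are checked against the vector width, and the vectorized kernel is taken only when everything lines up.

// Hypothetical illustration only; the real dispatch lives in
// csrc/layernorm_kernels.cu and takes more arguments than this.
#include <cstdint>

template <typename T>
bool can_use_vectorized_rms_norm(const T* out, const T* inp, const T* weight,
                                 int64_t hidden_size, int64_t input_stride,
                                 int vector_width) {
  // Same idea as req_alignment_bytes above: vector_width elements of a
  // 16-bit type must be loadable as a single aligned vector.
  const int req_alignment_bytes = vector_width * static_cast<int>(sizeof(T));
  auto addr = [](const void* p) { return reinterpret_cast<std::uintptr_t>(p); };
  const bool ptrs_are_aligned = addr(out) % 16 == 0 &&
                                addr(inp) % req_alignment_bytes == 0 &&
                                addr(weight) % req_alignment_bytes == 0;
  const bool offsets_ok = hidden_size % vector_width == 0 &&
                          input_stride % vector_width == 0;
  return ptrs_are_aligned && offsets_ok;
}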

csrc/layernorm_quant_kernels.cu

Lines changed: 6 additions & 6 deletions

@@ -64,8 +64,8 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,    // [..., hidden_size]
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
+    scalar_t* __restrict__ input,         // [..., hidden_size]
     const int input_stride,
     scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]

@@ -132,8 +132,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,    // [..., hidden_size]
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
+    scalar_t* __restrict__ input,         // [..., hidden_size]
     const int input_stride,
     scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]

@@ -210,8 +210,8 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
               width, fp8_t>                                             \
           <<<grid, block, 0, stream>>>(                                 \
               out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),        \
-              input_stride, residual_out.data_ptr<scalar_t>(),          \
-              residual.data_ptr<scalar_t>(),                            \
+              input_stride, residual_out.data_ptr<scalar_t>(),          \
+              residual.data_ptr<scalar_t>(),                            \
               weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),     \
               epsilon, num_tokens, hidden_size);                        \
       });                                                               \

csrc/rocm/fused_kernels.cu

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <stdexcept>
+#include <algorithm>
+
+constexpr int WARP_SIZE = 64;
+
+template <typename T>
+__device__ __forceinline__ T silu(const T& x) {
+  // x * sigmoid(x)
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
+}
+
+template <typename T>
+__device__ __forceinline__ T loadnt(T* addr) {
+  return __builtin_nontemporal_load(addr);
+}
+
+__device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
+  auto addr_alias = reinterpret_cast<const float*>(addr);
+  auto dat0 = loadnt(addr_alias);
+  auto dat1 = loadnt(addr_alias + 1);
+  auto dat2 = loadnt(addr_alias + 2);
+  auto dat3 = loadnt(addr_alias + 3);
+  // auto dat0 = *(addr_alias);
+  // auto dat1 = *(addr_alias+1);
+  // auto dat2 = *(addr_alias+2);
+  // auto dat3 = *(addr_alias+3);
+  return make_float4(dat0, dat1, dat2, dat3);
+}
+
+// TBlock fetches entire rows of A, and entire col of B (K dimension); assume
+// N=1 for time being grid is M/A_NUM_ROWS blocks
+template <int NUM_A_ROWS_PER_BLOCK>
+__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c,
+                                   const int d) {
+  __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
+  const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x;
+  const int row_addr_d = row_addr + d * blockDim.x;
+  // int row_addr_1 = row_addr + CUDA_NUM_THREADS;
+  // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS;
+  // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS;
+  const int threadid = threadIdx.x;
+  const int warp = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+  const int num_warps = blockDim.x / WARP_SIZE;
+  const int qwarpid = threadid / 16;
+  const int qthreadid = threadid % 16;
+  float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
+  // float4 colB_elem4;
+  __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
+  float acc[NUM_A_ROWS_PER_BLOCK];  //= 0.0;
+  __half2 acch2;
+
+  // rowA_elem4 = af4[row_addr + threadid];
+  //__syncthreads();
+  // rowA_elem4_1 = af4[row_addr_1 + threadid];
+  // rowA_elem4_2 = af4[row_addr_2 + threadid];
+  // rowA_elem4_3 = af4[row_addr_3 + threadid];
+#pragma unroll
+  for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) {
+    rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]);
+    rowA_elem4[2 * i + 1] =
+        load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]);
+    // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid];
+    //__syncthreads();
+  }
+  colB_elem4x = bf4[threadid * 4 + 0];
+  colB_elem4y = bf4[threadid * 4 + 1];
+  colB_elem4z = bf4[threadid * 4 + 2];
+  colB_elem4w = bf4[threadid * 4 + 3];
+
+  // __syncthreads();
+  __half2 Af2;
+  float2 S;
+  // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4);
+  // auto Bf2x = *Bh2ptr;
+  // auto Bf2y = *(Bh2ptr+1);
+  // auto Bf2z = *(Bh2ptr+2);
+  // auto Bf2w = *(Bh2ptr+3);
+  auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4);
+  __half2* ah2lptr;
+#pragma unroll
+  for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
+    ah2lptr = Ah2ptr + i * 4;
+    Af2 = *(ah2lptr);
+    acch2 = __hmul2(Af2, colB_elem4x);
+    Af2 = *(ah2lptr + 1);
+    acch2 = __hfma2(Af2, colB_elem4y, acch2);
+    Af2 = *(ah2lptr + 2);
+    acch2 = __hfma2(Af2, colB_elem4z, acch2);
+    Af2 = *(ah2lptr + 3);
+    acch2 = __hfma2(Af2, colB_elem4w, acch2);
+    S = __half22float2(acch2);
+    acc[i] = S.x + S.y;
+  }
+
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+#pragma unroll
+    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
+      acc[i] += __shfl_xor(acc[i], mask);
+    }
+  }
+
+  // Warp leaders store the data to shared memory.
+  // if (lane == 0) {
+  // #pragma unroll
+  //   for (int i=0; i<NUM_A_ROWS_PER_BLOCK; i++) {
+  //     red_smem[i][warp] = acc[i];
+  //   }
+  // }
+
+  if (lane < NUM_A_ROWS_PER_BLOCK) {
+    red_smem[lane][warp] = acc[lane];
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+  if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
+    // if (threadid<64) {
+    // #pragma unroll
+    //   for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
+    //     acc[i+2*qwarpid] = 0.0;
+    //   }
+    ////acc[qwarpid] = 0.0;
+
+    ////if (qthreadid<num_warps) {
+    // #pragma unroll
+    //   for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
+    //     acc[i+2*qwarpid] = red_smem[i+2*qwarpid][qthreadid];
+    //   }
+    ////acc[qwarpid] = red_smem[qwarpid][qthreadid];
+
+    ////}
+    acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
+    // if (threadid<32) {
+#pragma unroll
+    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
+      // #pragma unroll
+      //   for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
+      //     acc[i+2*qwarpid] += __shfl_xor(acc[i+2*qwarpid], mask);
+      //   }
+      acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
+    }
+    float oval2 = __shfl_xor(acc[qwarpid], 16);
+    // acc[1] = __shfl_xor(acc[1],16);
+    // acc[3] = __shfl_xor(acc[3],16);
+    //}
+    // __syncthreads();
+    // if (threadid < NUM_A_ROWS_PER_BLOCK/2) {
+    if (lane == 0 or lane == 32) {
+      // oval = __float22half2_rn(make_float2(acc[qwarpid],oval2));
+      // c[blockIdx.x*NUM_A_ROWS_PER_BLOCK/2+qwarpid/2] = oval;
+
+      c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] =
+          silu(acc[qwarpid]) * oval2;
+    }
+  }  // threadid<WARP_SIZE
+}
+// define the kernel calling code:
+// template <typename T>
+void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K,
+                 cudaStream_t stream, const int rows_per_block = 4) {
+  float4* af4 = reinterpret_cast<float4*>(in_a);
+  auto* bf4 = reinterpret_cast<__half2*>(in_b);
+  auto* c = reinterpret_cast<_Float16*>(out_c);
+  const int d = M / 2;
+  const int NUM_THREADS = K * 2 / 16;
+  int NUM_BLOCKS = M / rows_per_block;
+  if (rows_per_block == 2) {
+    LLGemm_Silu_kernel<2>
+        <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
+  } else if (rows_per_block == 4) {
+    LLGemm_Silu_kernel<4>
+        <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
+  } else if (rows_per_block == 8) {
+    LLGemm_Silu_kernel<8>
+        <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
+  } else if (rows_per_block == 16) {
+    LLGemm_Silu_kernel<16>
+        <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
+  } else {
+    NUM_BLOCKS = M / 4;
+    LLGemm_Silu_kernel<4>
+        <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
+  }
+
+  cudaError_t err = cudaGetLastError();
+  if (cudaSuccess != err)
+    throw std::runtime_error("CUDA kernel failed : " + std::to_string(err));
+}
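
For orientation, a hedged host-side sketch of how the new LLGemm_Silu entry point might be driven. The buffer layout assumed below (A holds M rows of K fp16 values, with the first M/2 rows paired against the last M/2 rows; B is a single K-long vector since the kernel assumes N = 1; the output has M/2 fp16 values) is inferred from the kernel's row_addr / row_addr_d indexing, not stated by the commit, and the function and variable names are illustrative only.

// Usage sketch under the assumptions above; not part of this commit.
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Declared in csrc/rocm/fused_kernels.cu (this commit).
void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K,
                 cudaStream_t stream, const int rows_per_block);

void run_llgemm_silu(const __half* h_a, const __half* h_b, __half* h_c,
                     int M, int K) {
  // The kernel launches K * 2 / 16 threads per block, so K is assumed to be a
  // multiple of 8 (one float4 = 8 halfs) and small enough for a single block.
  __half *d_a, *d_b, *d_c;
  cudaMalloc(reinterpret_cast<void**>(&d_a), sizeof(__half) * M * K);      // A
  cudaMalloc(reinterpret_cast<void**>(&d_b), sizeof(__half) * K);          // B (N = 1)
  cudaMalloc(reinterpret_cast<void**>(&d_c), sizeof(__half) * (M / 2));    // C
  cudaMemcpy(d_a, h_a, sizeof(__half) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b, sizeof(__half) * K, cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Output i appears to be roughly silu(A[i] . B) * (A[i + M/2] . B)
  // (inferred from the kernel, unverified).
  LLGemm_Silu(d_a, d_b, d_c, M, K, stream, /*rows_per_block=*/4);
  cudaMemcpyAsync(h_c, d_c, sizeof(__half) * (M / 2), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
}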
