
Commit 25aa2d9

rsmallblue and carryyu authored

cp dynamic Cfp8 (#4120)

* supports dynamic Cfp8
* add unittest
* fix dynamic Cfp8 computing error
* fix Cfp8 for RL load

Co-authored-by: carryyu <[email protected]>

1 parent: b6caf6e

21 files changed: 1,419 additions, 218 deletions
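
For context, "dynamic Cfp8" here is block-wise FP8 KV-cache quantization in which each token slot of a cache block carries its own scale computed at write time, instead of one static, offline-calibrated dequant scale. A minimal host-side sketch of the idea follows; the names are illustrative and not from this commit, and FP8 e4m3's maximum finite magnitude of 448 is the only fixed constant:

#include <algorithm>
#include <cmath>

constexpr float kFP8E4M3Max = 448.0f;  // max finite magnitude of FP8 e4m3

// Hypothetical sketch: per-token dynamic quantization of one K row of
// head_dim elements. The scale is derived from the row's absolute maximum
// and stored next to the quantized block; the attention kernel later
// multiplies by it to dequantize.
float quantize_token_fp8(const float* k_row, int head_dim, float* out_q) {
  float amax = 1e-8f;  // floor avoids divide-by-zero on all-zero rows
  for (int i = 0; i < head_dim; ++i) amax = std::max(amax, std::fabs(k_row[i]));
  const float scale = amax / kFP8E4M3Max;  // per-token dynamic scale
  for (int i = 0; i < head_dim; ++i) {
    // out_q[i] now lies in [-448, 448]; real code casts it to __nv_fp8_e4m3
    out_q[i] = k_row[i] / scale;
  }
  return scale;  // dequantization: x ≈ fp8(x / scale) * scale
}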

custom_ops/gpu_ops/append_attention.cu

Lines changed: 2 additions & 2 deletions
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
       key_cache,
       value_cache,
       attn_mask,
-      cache_k_dequant_scales,
-      cache_v_dequant_scales,
+      cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
+      cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
       cache_k_zp,
       cache_v_zp,
       out_linear_shifts,
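
A note on this hunk: in the new "block_wise_fp8" mode the scales are produced dynamically at cache-write time, so there is no precomputed dequant-scale tensor to pass; the quant-scale tensor is forwarded in its place, and the C8 kernels below consume it when dequantizing.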

custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh

Lines changed: 200 additions & 102 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/append_attn/append_attention_func.cuh

Lines changed: 171 additions & 28 deletions
@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
   }
 }
 
+template<uint32_t block_size,
+         uint32_t num_frags_z,
+         uint32_t NUM_WARP_Q,
+         typename T>
+__device__ __forceinline__ void produce_k_dynamic_scale(
+    T* k_smem_scale,
+    T* cache_k_reg,
+    const int* block_table_now,
+    const T* cache_k_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps shared block_size
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      k_smem_scale[tid] = cache_k_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
+    }
+  } else {
+    // 1 warp 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      k_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+    }
+  }
+}
+
+template<uint32_t block_size,
+         uint32_t num_frags_z,
+         uint32_t NUM_WARP_Q,
+         typename T>
+__device__ __forceinline__ void produce_v_dynamic_scale(
+    T* v_smem_scale,
+    T* cache_v_reg,
+    const int* block_table_now,
+    const T* cache_v_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps shared block_size
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      v_smem_scale[tid] = cache_v_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
+    }
+  } else {
+    // 1 warp 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      v_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
+    }
+  }
+}
+
 template <SharedMemFillMode fill_mode,
           uint32_t num_warps,
           uint32_t block_size,
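
To make the addressing in the two producers easier to follow, here is the shared indexing pattern pulled out into a stand-alone helper. This helper is a sketch, not part of the diff; it assumes the scale tensor layout [max_blocks, kv_num_heads, block_size] implied by the pointer arithmetic above:

// Hypothetical helper mirroring the producers' addressing: scales live in a
// [max_blocks, kv_num_heads, block_size] tensor, and the paged block table
// maps a logical block index to a physical block id.
template <typename T>
__device__ __forceinline__ T dynamic_scale_at(const T* cache_scale,
                                              const int* block_table_now,
                                              const uint32_t kv_idx,
                                              const uint32_t kv_num_heads,
                                              const uint32_t kv_head_idx,
                                              const uint32_t block_size) {
  int block_id = __ldg(&block_table_now[kv_idx / block_size]);
  if (block_id < 0) block_id = 0;  // unallocated slots fall back to block 0, as above
  return cache_scale[(block_id * kv_num_heads + kv_head_idx) * block_size +
                     kv_idx % block_size];
}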
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
           typename T,
           typename CacheT,
           bool is_scale_channel_wise = false,
-          bool IsFP8=false>
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
                                               uint32_t* q_smem_offset_r,
                                               smem_t* k_smem,
@@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
-        const int scale_col = (ky * 2 + fy) * 4;
-        b_frag_dq_T[0] *= cache_k_scale[scale_col];
-        b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
-        b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
-        b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
-        b_frag_dq_T[4] *= cache_k_scale[scale_col];
-        b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
-        b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
-        b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
+          const int scale_col = (ky * 2 + fy) * 4;
+          b_frag_dq_T[0] *= cache_k_scale[scale_col];
+          b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
+          b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
+          b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
+          b_frag_dq_T[4] *= cache_k_scale[scale_col];
+          b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
+          b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
+          b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+        } else {
+#pragma unroll
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_k_scale[0];
+          }
+        }
       } else {
 #pragma unroll
         for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_k_scale[0];
+          b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
         }
       }
 #pragma unroll
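
How the dynamic branch finds its scales: produce_k_dynamic_scale (added above) stages two per-token scales per 16-row K fragment, one for row r and one for row r + 8, and each thread's eight dequantized values split 4/4 across those two rows, so cache_k_scale[fz * 2 + b_i / 4] picks the matching scale. IsDynamicC8 defaults to false, so existing instantiations of compute_qk_c8 behave as before.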
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
           uint32_t block_size,
           typename T,
           typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_sfm_v_c8(
     smem_t* v_smem,
     uint32_t* v_smem_offset_r,
@@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
-        }
-      } else {
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          }
+        } else {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[0];
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[0];
+          }
+        }
       }
+      } else {
+        const int scale_col = (kz * 2 + fz) * 4;
+        b_frag_dq_T[0] *= cache_v_scale[scale_col];
+        b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+        b_frag_dq_T[4] *= cache_v_scale[scale_col];
+        b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
       }
 #pragma unroll
       for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16
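
The V side mirrors the K logic with the roles flipped: V is consumed transposed, so the per-token scales vary along the fragment columns, and the dynamic branch indexes them channel-wise with scale_col = (kz * 2 + fz) * 4, walking the four scales that produce_v_dynamic_scale staged per fragment.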
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
           uint32_t block_size,
           typename T,
           typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
 __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
     smem_t* v_smem,
     uint32_t* v_smem_offset_r,
@@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
       convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
       convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
       // scale zp
-      if constexpr (is_scale_channel_wise) {
+      if constexpr (!IsDynamicC8) {
+        if constexpr (is_scale_channel_wise) {
 #pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+          }
+        } else {
+#pragma unroll
+          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+            b_frag_dq_T[b_i] *= cache_v_scale[0];
+          }
         }
       } else {
-#pragma unroll
-        for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-          b_frag_dq_T[b_i] *= cache_v_scale[0];
-        }
+        const int scale_col = (kz * 2 + fz) * 4;
+        b_frag_dq_T[0] *= cache_v_scale[scale_col];
+        b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+        b_frag_dq_T[4] *= cache_v_scale[scale_col];
+        b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+        b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+        b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
       }
 #pragma unroll
       for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16

custom_ops/gpu_ops/append_attn/append_attention_kernel.h

Lines changed: 4 additions & 1 deletion
@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
     const bool causal,
     const bool is_decoder,
     const bool enable_prefill,
+    const std::string& cache_quant_type_str,
     cudaStream_t& stream,
    paddle::Tensor* out);

@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
         causal,
         is_decoder,
         enable_prefill,
+        cache_quant_type_str,
         stream,
         out);
-  } else if (cache_quant_type_str == "cache_fp8") {
+  } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
     CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
                                                   qkv,
                                                   cache_k,
@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
         causal,
         is_decoder,
         enable_prefill,
+        cache_quant_type_str,
         stream,
         out);
   } else if (cache_quant_type_str == "cache_int4_zp") {
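
Two small observations on the dispatch: `or` is the standard C++ alternative token for `||`, so the new condition compiles as written (normalizing it to `||` would match the surrounding style); and because "cache_fp8" and "block_wise_fp8" share the C8 path with IsFP8 = true, the newly threaded cache_quant_type_str is what lets the kernel choose static versus dynamic scaling internally.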
