@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}

+template <uint32_t block_size,
+          uint32_t num_frags_z,
+          uint32_t NUM_WARP_Q,
+          typename T>
+__device__ __forceinline__ void produce_k_dynamic_scale(
+    T* k_smem_scale,
+    T* cache_k_reg,
+    const int* block_table_now,
+    const T* cache_k_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps share one block_size-token block
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      k_smem_scale[tid] = cache_k_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
+    }
+  } else {
+    // 1 warp handles 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      k_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+    }
+  }
+}
+
+template <uint32_t block_size,
+          uint32_t num_frags_z,
+          uint32_t NUM_WARP_Q,
+          typename T>
+__device__ __forceinline__ void produce_v_dynamic_scale(
+    T* v_smem_scale,
+    T* cache_v_reg,
+    const int* block_table_now,
+    const T* cache_v_scale,
+    const uint32_t kv_idx,
+    const uint32_t kv_num_heads,
+    const uint32_t kv_head_idx,
+    const uint32_t chunk_end
+) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps share one block_size-token block
+    const uint32_t tid = ty * 32 + tx;
+    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    if (tid < block_size) {
+      v_smem_scale[tid] = cache_v_scale_now[tid];
+    }
+    __syncthreads();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
+    }
+  } else {
+    // 1 warp handles 32 tokens
+    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
+    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+    if (block_id < 0) block_id = 0;
+    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
+    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
+    if (kv_idx_this_thread < chunk_end) {
+      v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
+    } else {
+      v_smem_scale[ty * 32 + tx] = 0;
+    }
+    __syncwarp();
+    const uint32_t row_id = tx % 4 * 2;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
+    }
+  }
+}
+
template <SharedMemFillMode fill_mode,
          uint32_t num_warps,
          uint32_t block_size,
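The two new helpers above stage per-token dequantization scales for the paged KV cache: the warps copy the scales of the tokens they own into shared memory, then repack them into the per-fragment registers (`cache_k_reg` / `cache_v_reg`) consumed by the compute functions below. As a minimal reference for the indexing they rely on, the sketch below recomputes one token's scale on the host; it assumes a `[num_blocks, kv_num_heads, block_size]` layout with one scale per cached token, and the names are illustrative rather than part of the patch.

```cuda
#include <cstdint>
#include <vector>

// Host-side reference for the scale lookup used by produce_k_dynamic_scale /
// produce_v_dynamic_scale (illustrative; assumes a [num_blocks, kv_num_heads,
// block_size] layout with one scale per cached token).
float lookup_token_scale(const std::vector<float>& cache_scale,
                         const std::vector<int>& block_table,
                         uint32_t kv_idx,        // logical token position
                         uint32_t kv_head_idx,
                         uint32_t kv_num_heads,
                         uint32_t block_size) {
  int block_id = block_table[kv_idx / block_size];
  if (block_id < 0) block_id = 0;  // same clamping as the device helpers
  const size_t base =
      static_cast<size_t>(block_id) * kv_num_heads * block_size +
      static_cast<size_t>(kv_head_idx) * block_size;
  return cache_scale[base + kv_idx % block_size];
}
```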
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
          typename T,
          typename CacheT,
          bool is_scale_channel_wise = false,
-          bool IsFP8=false>
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
                                              uint32_t* q_smem_offset_r,
                                              smem_t* k_smem,
@@ -860,20 +968,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
        convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
        convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
        // scale zp
-        if constexpr (is_scale_channel_wise) {
-          const int scale_col = (ky * 2 + fy) * 4;
-          b_frag_dq_T[0] *= cache_k_scale[scale_col];
-          b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
-          b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
-          b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
-          b_frag_dq_T[4] *= cache_k_scale[scale_col];
-          b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
-          b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
-          b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+        if constexpr (!IsDynamicC8) {
+          if constexpr (is_scale_channel_wise) {
+            const int scale_col = (ky * 2 + fy) * 4;
+            b_frag_dq_T[0] *= cache_k_scale[scale_col];
+            b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
+            b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
+            b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
+            b_frag_dq_T[4] *= cache_k_scale[scale_col];
+            b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
+            b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
+            b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
+          } else {
+#pragma unroll
+            for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+              b_frag_dq_T[b_i] *= cache_k_scale[0];
+            }
+          }
        } else {
#pragma unroll
          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-            b_frag_dq_T[b_i] *= cache_k_scale[0];
+            b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
          }
        }
#pragma unroll
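With `IsDynamicC8` set, `cache_k_scale` presumably points at the per-token registers filled by `produce_k_dynamic_scale` rather than a static per-channel or per-tensor table, so the new else branch multiplies values 0-3 and 4-7 of each dequantized fragment by the two row scales of fragment `fz`. A standalone sketch of that multiply (hypothetical helper name, register layout as packed above):

```cuda
#include <cstdint>

// Illustrative sketch, not part of the patch: cache_k_reg[fz * 2] holds the scale
// for the token row this lane reads in the top half of fragment fz, and
// cache_k_reg[fz * 2 + 1] the one for the row 8 tokens below it.
template <typename T>
__device__ __forceinline__ void scale_k_frag_dynamic(T* b_frag_dq_T,
                                                     const T* cache_k_reg,
                                                     uint32_t fz) {
#pragma unroll
  for (uint32_t b_i = 0; b_i < 8; ++b_i) {
    // b_i / 4 selects which of the two token rows this value belongs to
    b_frag_dq_T[b_i] *= cache_k_reg[fz * 2 + b_i / 4];
  }
}
```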
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
          uint32_t block_size,
          typename T,
          typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8(
    smem_t* v_smem,
    uint32_t* v_smem_offset_r,
@@ -1135,16 +1252,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
        convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
        convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
        // scale zp
-        if constexpr (is_scale_channel_wise) {
+        if constexpr (!IsDynamicC8) {
+          if constexpr (is_scale_channel_wise) {
#pragma unroll
-          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
-          }
-        } else {
+            for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+              b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+            }
+          } else {
#pragma unroll
-          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-            b_frag_dq_T[b_i] *= cache_v_scale[0];
+            for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+              b_frag_dq_T[b_i] *= cache_v_scale[0];
+            }
          }
+        } else {
+          const int scale_col = (kz * 2 + fz) * 4;
+          b_frag_dq_T[0] *= cache_v_scale[scale_col];
+          b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+          b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+          b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+          b_frag_dq_T[4] *= cache_v_scale[scale_col];
+          b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+          b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+          b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
        }
#pragma unroll
        for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16
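On the V side the dynamic scales vary along the token (n) dimension, so the new branch picks four consecutive scales starting at `(kz * 2 + fz) * 4` and reuses the same four for values 4-7, matching the four scales per fragment that `produce_v_dynamic_scale` packs into `cache_v_reg`. A compact loop-form equivalent of the eight unrolled multiplies, with a hypothetical helper name:

```cuda
#include <cstdint>

// Illustrative sketch, not part of the patch: b_i % 4 walks the four per-token
// V scales of this fragment, and the same pattern repeats for values 4-7,
// exactly as in the unrolled multiplies above.
template <typename T>
__device__ __forceinline__ void scale_v_frag_dynamic(T* b_frag_dq_T,
                                                     const T* cache_v_scale,
                                                     uint32_t kz,
                                                     uint32_t fz) {
  const int scale_col = (kz * 2 + fz) * 4;
#pragma unroll
  for (uint32_t b_i = 0; b_i < 8; ++b_i) {
    b_frag_dq_T[b_i] *= cache_v_scale[scale_col + (b_i % 4)];
  }
}
```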
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
          uint32_t block_size,
          typename T,
          typename CacheT,
-          bool is_scale_channel_wise = false, bool IsFP8=false>
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = false,
+          bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
    smem_t* v_smem,
    uint32_t* v_smem_offset_r,
@@ -1215,16 +1346,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
        convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
        convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
        // scale zp
-        if constexpr (is_scale_channel_wise) {
+        if constexpr (!IsDynamicC8) {
+          if constexpr (is_scale_channel_wise) {
#pragma unroll
-          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-            b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+            for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+              b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
+            }
+          } else {
+#pragma unroll
+            for (uint32_t b_i = 0; b_i < 8; ++b_i) {
+              b_frag_dq_T[b_i] *= cache_v_scale[0];
+            }
          }
        } else {
-#pragma unroll
-          for (uint32_t b_i = 0; b_i < 8; ++b_i) {
-            b_frag_dq_T[b_i] *= cache_v_scale[0];
-          }
+          const int scale_col = (kz * 2 + fz) * 4;
+          b_frag_dq_T[0] *= cache_v_scale[scale_col];
+          b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
+          b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
+          b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
+          b_frag_dq_T[4] *= cache_v_scale[scale_col];
+          b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
+          b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
+          b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
        }
#pragma unroll
        for (uint32_t fx = 0; fx < num_frags_x; ++fx) {  // m: num_frags_x * 16
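The same dynamic-scale branch is applied in this `_iter_sq_bvec` variant. Combined with the zeroed scales that the producers stage for tokens at or beyond `chunk_end`, the intended numerics reduce to a simple per-token dequantization; a host-side reference of that contract (illustrative names, int8 cache values assumed):

```cuda
#include <cstdint>
#include <vector>

// Host-side reference, not part of the patch: dequantize one cached token's
// values with its dynamic scale; tokens at or beyond chunk_end contribute
// nothing because their scale is staged as zero.
std::vector<float> dequant_token_ref(const std::vector<int8_t>& q_vals,
                                     float token_scale,
                                     uint32_t kv_idx,
                                     uint32_t chunk_end) {
  std::vector<float> out(q_vals.size(), 0.0f);
  const float scale = (kv_idx < chunk_end) ? token_scale : 0.0f;
  for (size_t i = 0; i < q_vals.size(); ++i) {
    out[i] = static_cast<float>(q_vals[i]) * scale;
  }
  return out;
}
```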