@@ -204,12 +204,16 @@ __global__ void multi_query_append_attention_c8_kernel(
   smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
       v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
              num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
-  T* k_smem_scale = nullptr;
-  T* v_smem_scale = nullptr;
+  T* k_smem_scale_ptr = nullptr;
+  T* v_smem_scale_ptr = nullptr;
+  smem_t k_scale_smem;
+  smem_t v_scale_smem;
   if constexpr (IsDynamicC8) {
-    k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
+    k_smem_scale_ptr = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
                                         num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
-    v_smem_scale = k_smem_scale + num_frags_z * 16;
+    v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
+    k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
+    v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
   }

@@ -271,6 +275,20 @@ __global__ void multi_query_append_attention_c8_kernel(
       kv_idx_base,
       chunk_end,
       const_k_offset);
+  if constexpr (IsDynamicC8) {
+    produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                             BLOCK_SIZE,
+                                             num_frags_z,
+                                             NUM_WARP_Q>(
+        k_scale_smem,
+        block_table_now,
+        cache_k_scale,
+        kv_idx_base,
+        kv_num_heads,
+        kv_head_idx,
+        chunk_end
+    );
+  }
   commit_group();
   produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                          NUM_WARPS,
@@ -288,24 +306,32 @@ __global__ void multi_query_append_attention_c8_kernel(
       kv_idx_base,
       chunk_end,
       const_v_offset);
+  if constexpr (IsDynamicC8) {
+    produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                             BLOCK_SIZE,
+                                             num_frags_z,
+                                             NUM_WARP_Q>(
+        v_scale_smem,
+        block_table_now,
+        cache_v_scale,
+        kv_idx_base,
+        kv_num_heads,
+        kv_head_idx,
+        chunk_end
+    );
+  }
   commit_group();

 #pragma unroll 1
   for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+    wait_group<1>();
+    __syncthreads();
     if constexpr (IsDynamicC8) {
-      produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
-          k_smem_scale,
-          cache_k_scale_reg,
-          block_table_now,
-          cache_k_scale,
-          kv_idx_base,
-          kv_num_heads,
-          kv_head_idx,
-          chunk_end
+      produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
+          k_smem_scale_ptr,
+          cache_k_scale_reg
       );
     }
-    wait_group<1>();
-    __syncthreads();
     // s = qk
     compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
         &qo_smem,
@@ -358,21 +384,29 @@ __global__ void multi_query_append_attention_c8_kernel(
         kv_idx_base,
         chunk_end,
         const_k_offset);
-    commit_group();
     if constexpr (IsDynamicC8) {
-      produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
-          v_smem_scale,
-          cache_v_scale_reg,
+      produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                               BLOCK_SIZE,
+                                               num_frags_z,
+                                               NUM_WARP_Q>(
+          k_scale_smem,
           block_table_now,
-          cache_v_scale,
-          ori_kv_idx_base,
+          cache_k_scale,
+          kv_idx_base,
           kv_num_heads,
           kv_head_idx,
           chunk_end
       );
     }
+    commit_group();
     wait_group<1>();
     __syncthreads();
+    if constexpr (IsDynamicC8) {
+      produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
+          v_smem_scale_ptr,
+          cache_v_scale_reg
+      );
+    }

     // compute sfm*v
     compute_sfm_v_c8<num_frags_x,
@@ -403,6 +437,20 @@ __global__ void multi_query_append_attention_c8_kernel(
         kv_idx_base,
         chunk_end,
         const_v_offset);
+    if constexpr (IsDynamicC8) {
+      produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                               BLOCK_SIZE,
+                                               num_frags_z,
+                                               NUM_WARP_Q>(
+          v_scale_smem,
+          block_table_now,
+          cache_v_scale,
+          kv_idx_base,
+          kv_num_heads,
+          kv_head_idx,
+          chunk_end
+      );
+    }
     commit_group();

   }
@@ -674,12 +722,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
   smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
       v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
              NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
-  T* k_smem_scale = nullptr;
-  T* v_smem_scale = nullptr;
+  T* k_smem_scale_ptr = nullptr;
+  T* v_smem_scale_ptr = nullptr;
+  smem_t k_scale_smem;
+  smem_t v_scale_smem;
   if constexpr (IsDynamicC8) {
-    k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
-                                        NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
-    v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
+    k_smem_scale_ptr = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
+                                            NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
+    v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
+    k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
+    v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
   }

   const uint32_t num_iterations = div_up(
@@ -743,6 +795,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
       kv_idx_base,
       chunk_end,
       const_k_offset);
+  if constexpr (IsDynamicC8) {
+    produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                             BLOCK_SIZE,
+                                             num_frags_z,
+                                             NUM_WARP_Q>(
+        k_scale_smem,
+        block_table_now,
+        cache_k_scale,
+        kv_idx_base,
+        kv_num_heads,
+        kv_head_idx,
+        chunk_end
+    );
+  }
   commit_group();
   produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
                          NUM_WARPS,
@@ -760,23 +826,31 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
       kv_idx_base,
       chunk_end,
       const_v_offset);
+  if constexpr (IsDynamicC8) {
+    produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                             BLOCK_SIZE,
+                                             num_frags_z,
+                                             NUM_WARP_Q>(
+        v_scale_smem,
+        block_table_now,
+        cache_v_scale,
+        kv_idx_base,
+        kv_num_heads,
+        kv_head_idx,
+        chunk_end
+    );
+  }
   commit_group();
 #pragma unroll 1
   for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+    wait_group<1>();
+    __syncthreads();
     if constexpr (IsDynamicC8) {
-      produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
-          k_smem_scale,
-          cache_k_scale_reg,
-          block_table_now,
-          cache_k_scale,
-          kv_idx_base,
-          kv_num_heads,
-          kv_head_idx,
-          chunk_end
+      produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
+          k_smem_scale_ptr,
+          cache_k_scale_reg
       );
     }
-    wait_group<1>();
-    __syncthreads();

     // s = qk
     compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
@@ -830,21 +904,29 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
         kv_idx_base,
         chunk_end,
         const_k_offset);
-    commit_group();
     if constexpr (IsDynamicC8) {
-      produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
-          v_smem_scale,
-          cache_v_scale_reg,
+      produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                               BLOCK_SIZE,
+                                               num_frags_z,
+                                               NUM_WARP_Q>(
+          k_scale_smem,
           block_table_now,
-          cache_v_scale,
-          ori_kv_idx_base,
+          cache_k_scale,
+          kv_idx_base,
           kv_num_heads,
           kv_head_idx,
           chunk_end
       );
     }
+    commit_group();
     wait_group<1>();
     __syncthreads();
+    if constexpr (IsDynamicC8) {
+      produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
+          v_smem_scale_ptr,
+          cache_v_scale_reg
+      );
+    }

     // compute sfm * v
     compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
@@ -875,6 +957,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
         kv_idx_base,
         chunk_end,
         const_v_offset);
+    if constexpr (IsDynamicC8) {
+      produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
+                                               BLOCK_SIZE,
+                                               num_frags_z,
+                                               NUM_WARP_Q>(
+          v_scale_smem,
+          block_table_now,
+          cache_v_scale,
+          kv_idx_base,
+          kv_num_heads,
+          kv_head_idx,
+          chunk_end
+      );
+    }
     commit_group();
   }
   wait_group<0>();
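
Note on the pattern (illustrative, not part of the diff): the change takes the dynamic C8 cache-scale loads off the critical path. Instead of fetching the per-block scales from global memory inside every iteration (produce_k/v_dynamic_scale), the scales are now prefetched into shared memory asynchronously alongside the K/V tiles (produce_kv_dynamic_scale_gmem2smem_async), committed in the same copy group, and only moved from shared memory into registers after wait_group<1>() / __syncthreads() (produce_k/v_dynamic_scale_smem2reg). The standalone sketch below shows the same double-buffered prefetch idea using CUDA's pipeline intrinsics rather than the kernel's own helpers; the names TILE and scale_prefetch_demo are hypothetical, and it assumes a 128-thread block, e.g. scale_prefetch_demo<<<1, 128>>>(d_scales, d_out, num_tiles).

// Minimal sketch of the async gmem->smem prefetch + smem->reg consume pattern.
#include <cuda_pipeline.h>

constexpr int TILE = 128;  // scales consumed per iteration (assumption)

__global__ void scale_prefetch_demo(const float* __restrict__ g_scales,
                                    float* __restrict__ out,
                                    int num_tiles) {
  // Two buffers, playing the role of k_scale_smem / v_scale_smem.
  __shared__ __align__(16) float s_scales[2][TILE];
  const int tid = threadIdx.x;

  // Before the loop: issue the async copy for tile 0 (the
  // produce_kv_dynamic_scale_gmem2smem_async analogue) and commit it.
  if (num_tiles > 0 && tid * 4 < TILE) {
    __pipeline_memcpy_async(&s_scales[0][tid * 4], &g_scales[tid * 4],
                            sizeof(float) * 4);
  }
  __pipeline_commit();  // commit_group() analogue

  float acc = 0.f;
  for (int it = 0; it < num_tiles; ++it) {
    const int cur = it & 1, nxt = cur ^ 1;

    // Prefetch the next tile's scales while the current tile is consumed.
    if (it + 1 < num_tiles && tid * 4 < TILE) {
      __pipeline_memcpy_async(&s_scales[nxt][tid * 4],
                              &g_scales[(it + 1) * TILE + tid * 4],
                              sizeof(float) * 4);
    }
    __pipeline_commit();

    // wait_group<1>() analogue: every copy group except the newest one has
    // completed, so the current buffer is valid and the smem->reg load is cheap.
    __pipeline_wait_prior(1);
    __syncthreads();
    const float scale = s_scales[cur][tid % TILE];  // produce_*_smem2reg analogue
    acc += scale;
    __syncthreads();  // keep the buffer stable until all threads have read it
  }
  out[blockIdx.x * blockDim.x + tid] = acc;
}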