Skip to content

Commit 1e86418

Browse files
rsmallblue and carryyu authored
optimize dy_cfp8's performance (#4145)
Co-authored-by: carryyu <[email protected]>
1 parent 5027ed7 commit 1e86418

File tree

2 files changed

+191
-103
lines changed

2 files changed

+191
-103
lines changed

custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh

Lines changed: 139 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -204,12 +204,16 @@ __global__ void multi_query_append_attention_c8_kernel(
204204
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
205205
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
206206
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
207-
T* k_smem_scale = nullptr;
208-
T* v_smem_scale = nullptr;
207+
T* k_smem_scale_ptr = nullptr;
208+
T* v_smem_scale_ptr = nullptr;
209+
smem_t k_scale_smem;
210+
smem_t v_scale_smem;
209211
if constexpr (IsDynamicC8) {
210-
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
212+
k_smem_scale_ptr = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
211213
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
212-
v_smem_scale = k_smem_scale + num_frags_z * 16;
214+
v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
215+
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
216+
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
213217
}
214218

215219

@@ -271,6 +275,20 @@ __global__ void multi_query_append_attention_c8_kernel(
271275
kv_idx_base,
272276
chunk_end,
273277
const_k_offset);
278+
if constexpr (IsDynamicC8) {
279+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
280+
BLOCK_SIZE,
281+
num_frags_z,
282+
NUM_WARP_Q>(
283+
k_scale_smem,
284+
block_table_now,
285+
cache_k_scale,
286+
kv_idx_base,
287+
kv_num_heads,
288+
kv_head_idx,
289+
chunk_end
290+
);
291+
}
274292
commit_group();
275293
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
276294
NUM_WARPS,
@@ -288,24 +306,32 @@ __global__ void multi_query_append_attention_c8_kernel(
288306
kv_idx_base,
289307
chunk_end,
290308
const_v_offset);
309+
if constexpr (IsDynamicC8) {
310+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
311+
BLOCK_SIZE,
312+
num_frags_z,
313+
NUM_WARP_Q>(
314+
v_scale_smem,
315+
block_table_now,
316+
cache_v_scale,
317+
kv_idx_base,
318+
kv_num_heads,
319+
kv_head_idx,
320+
chunk_end
321+
);
322+
}
291323
commit_group();
292324

293325
#pragma unroll 1
294326
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
327+
wait_group<1>();
328+
__syncthreads();
295329
if constexpr (IsDynamicC8) {
296-
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
297-
k_smem_scale,
298-
cache_k_scale_reg,
299-
block_table_now,
300-
cache_k_scale,
301-
kv_idx_base,
302-
kv_num_heads,
303-
kv_head_idx,
304-
chunk_end
330+
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
331+
k_smem_scale_ptr,
332+
cache_k_scale_reg
305333
);
306334
}
307-
wait_group<1>();
308-
__syncthreads();
309335
// s = qk
310336
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
311337
&qo_smem,
@@ -358,21 +384,29 @@ __global__ void multi_query_append_attention_c8_kernel(
358384
kv_idx_base,
359385
chunk_end,
360386
const_k_offset);
361-
commit_group();
362387
if constexpr (IsDynamicC8) {
363-
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
364-
v_smem_scale,
365-
cache_v_scale_reg,
388+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
389+
BLOCK_SIZE,
390+
num_frags_z,
391+
NUM_WARP_Q>(
392+
k_scale_smem,
366393
block_table_now,
367-
cache_v_scale,
368-
ori_kv_idx_base,
394+
cache_k_scale,
395+
kv_idx_base,
369396
kv_num_heads,
370397
kv_head_idx,
371398
chunk_end
372399
);
373400
}
401+
commit_group();
374402
wait_group<1>();
375403
__syncthreads();
404+
if constexpr (IsDynamicC8) {
405+
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
406+
v_smem_scale_ptr,
407+
cache_v_scale_reg
408+
);
409+
}
376410

377411
// compute sfm*v
378412
compute_sfm_v_c8<num_frags_x,
@@ -403,6 +437,20 @@ __global__ void multi_query_append_attention_c8_kernel(
403437
kv_idx_base,
404438
chunk_end,
405439
const_v_offset);
440+
if constexpr (IsDynamicC8) {
441+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
442+
BLOCK_SIZE,
443+
num_frags_z,
444+
NUM_WARP_Q>(
445+
v_scale_smem,
446+
block_table_now,
447+
cache_v_scale,
448+
kv_idx_base,
449+
kv_num_heads,
450+
kv_head_idx,
451+
chunk_end
452+
);
453+
}
406454
commit_group();
407455

408456
}
@@ -674,12 +722,16 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
674722
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
675723
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
676724
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
677-
T* k_smem_scale = nullptr;
678-
T* v_smem_scale = nullptr;
725+
T* k_smem_scale_ptr = nullptr;
726+
T* v_smem_scale_ptr = nullptr;
727+
smem_t k_scale_smem;
728+
smem_t v_scale_smem;
679729
if constexpr (IsDynamicC8) {
680-
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
681-
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
682-
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
730+
k_smem_scale_ptr = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
731+
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
732+
v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
733+
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
734+
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
683735
}
684736

685737
const uint32_t num_iterations = div_up(
@@ -743,6 +795,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
743795
kv_idx_base,
744796
chunk_end,
745797
const_k_offset);
798+
if constexpr (IsDynamicC8) {
799+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
800+
BLOCK_SIZE,
801+
num_frags_z,
802+
NUM_WARP_Q>(
803+
k_scale_smem,
804+
block_table_now,
805+
cache_k_scale,
806+
kv_idx_base,
807+
kv_num_heads,
808+
kv_head_idx,
809+
chunk_end
810+
);
811+
}
746812
commit_group();
747813
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
748814
NUM_WARPS,
@@ -760,23 +826,31 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
760826
kv_idx_base,
761827
chunk_end,
762828
const_v_offset);
829+
if constexpr (IsDynamicC8) {
830+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
831+
BLOCK_SIZE,
832+
num_frags_z,
833+
NUM_WARP_Q>(
834+
v_scale_smem,
835+
block_table_now,
836+
cache_v_scale,
837+
kv_idx_base,
838+
kv_num_heads,
839+
kv_head_idx,
840+
chunk_end
841+
);
842+
}
763843
commit_group();
764844
#pragma unroll 1
765845
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
846+
wait_group<1>();
847+
__syncthreads();
766848
if constexpr (IsDynamicC8) {
767-
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
768-
k_smem_scale,
769-
cache_k_scale_reg,
770-
block_table_now,
771-
cache_k_scale,
772-
kv_idx_base,
773-
kv_num_heads,
774-
kv_head_idx,
775-
chunk_end
849+
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
850+
k_smem_scale_ptr,
851+
cache_k_scale_reg
776852
);
777853
}
778-
wait_group<1>();
779-
__syncthreads();
780854

781855
// s = qk
782856
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
@@ -830,21 +904,29 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
830904
kv_idx_base,
831905
chunk_end,
832906
const_k_offset);
833-
commit_group();
834907
if constexpr (IsDynamicC8) {
835-
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
836-
v_smem_scale,
837-
cache_v_scale_reg,
908+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
909+
BLOCK_SIZE,
910+
num_frags_z,
911+
NUM_WARP_Q>(
912+
k_scale_smem,
838913
block_table_now,
839-
cache_v_scale,
840-
ori_kv_idx_base,
914+
cache_k_scale,
915+
kv_idx_base,
841916
kv_num_heads,
842917
kv_head_idx,
843918
chunk_end
844919
);
845920
}
921+
commit_group();
846922
wait_group<1>();
847923
__syncthreads();
924+
if constexpr (IsDynamicC8) {
925+
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
926+
v_smem_scale_ptr,
927+
cache_v_scale_reg
928+
);
929+
}
848930

849931
// compute sfm * v
850932
compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
@@ -875,6 +957,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
875957
kv_idx_base,
876958
chunk_end,
877959
const_v_offset);
960+
if constexpr (IsDynamicC8) {
961+
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
962+
BLOCK_SIZE,
963+
num_frags_z,
964+
NUM_WARP_Q>(
965+
v_scale_smem,
966+
block_table_now,
967+
cache_v_scale,
968+
kv_idx_base,
969+
kv_num_heads,
970+
kv_head_idx,
971+
chunk_end
972+
);
973+
}
878974
commit_group();
879975
}
880976
wait_group<0>();

0 commit comments

Comments
 (0)