Skip to content

Commit 55576c6

Browse files
authored
misc: fix instrument code for mla profiler (flashinfer-ai#1014)
Follow up of flashinfer-ai#952 , this PR adds the instrument code base to profile mla hopper implementation (fix flashinfer-ai#995 )
1 parent f579ca2 commit 55576c6

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

include/flashinfer/attention/mla_hopper.cuh

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
647647
const uint32_t warp_idx = cutlass::canonical_warp_idx();
648648
const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;
649649

650+
PROFILER_INIT(params, smem_storage, variant, warp_group_idx, 2, (threadIdx.x % 128 == 0));
651+
650652
using MainloopPipeline = typename KTraits::MainloopPipeline;
651653
using PipelineParams = typename MainloopPipeline::Params;
652654
using PipelineState = typename MainloopPipeline::PipelineState;
@@ -710,9 +712,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
710712
block_size, kv_indices, ckv_offset, kpe_offset);
711713
if (has_kv) {
712714
pipeline_kv.producer_acquire(smem_pipe_write_kv);
715+
PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
713716
load_kv<true, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
714717
block_iter_base + kv_tile_idx * CTA_TILE_KV,
715718
smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
719+
PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
716720
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
717721
kv_tile_idx -= 1;
718722
++smem_pipe_write_kv;
@@ -722,19 +726,23 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
722726
}
723727

724728
pipeline_q.producer_acquire(smem_pipe_write_q);
729+
PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadQ);
725730
load_q<KTraits>(&smem_storage, q_nope + q_indptr * q_nope_stride_n,
726731
q_pe + q_indptr * q_pe_stride_n, q_nope_stride_n, q_nope_stride_h,
727732
q_pe_stride_n, q_pe_stride_h, qo_upperbound, qo_packed_idx_base,
728733
params.num_heads);
734+
PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadQ);
729735
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
730736
++smem_pipe_write_q;
731737

732738
#pragma unroll 1
733739
for (; kv_tile_idx >= 0; --kv_tile_idx) {
734740
pipeline_kv.producer_acquire(smem_pipe_write_kv);
741+
PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
735742
load_kv<false, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
736743
block_iter_base + kv_tile_idx * CTA_TILE_KV,
737744
smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
745+
PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
738746
if (kv_tile_idx > 0) {
739747
prefetch_offset<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
740748
packed_kv_bound, ckv_stride_page, ckv_stride_n, kpe_stride_page,
@@ -745,23 +753,31 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
745753

746754
barrier_sync(KTraits::NUM_THREADS, NamedBarriers::kOScaleReady);
747755
load_o_scale_smem<KTraits>(&smem_storage, o_scale);
756+
PROFILER_EVENT_START(variant, ProfileEventType::kRescaleO);
748757
rescale_o_<KTraits>(o_scale, o_frag);
758+
PROFILER_EVENT_END(variant, ProfileEventType::kRescaleO);
749759
consumer_wait(pipeline_kv, smem_pipe_read_kv);
750760
__syncthreads();
761+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmPV);
751762
compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index(), o_frag);
752763
warpgroup_wait<0>();
764+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmPV);
753765
pipeline_kv.consumer_release(smem_pipe_read_kv);
754766
++smem_pipe_read_kv;
755767
}
756768

757769
if (has_kv) {
758770
barrier_sync(KTraits::NUM_THREADS, NamedBarriers::kOScaleReady);
759771
load_o_scale_smem<KTraits>(&smem_storage, o_scale);
772+
PROFILER_EVENT_START(variant, ProfileEventType::kRescaleO);
760773
rescale_o_<KTraits>(o_scale, o_frag);
774+
PROFILER_EVENT_END(variant, ProfileEventType::kRescaleO);
761775
consumer_wait(pipeline_kv, smem_pipe_read_kv);
762776
__syncthreads();
777+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmPV);
763778
compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index(), o_frag);
764779
warpgroup_wait<0>();
780+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmPV);
765781
pipeline_kv.consumer_release(smem_pipe_read_kv);
766782
++smem_pipe_read_kv;
767783
}
@@ -774,12 +790,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
774790
}
775791
normalize_d_<KTraits>(&smem_storage, o_frag, m, d);
776792
finalize_m_<KTraits>(variant, m);
793+
PROFILER_EVENT_START(variant, ProfileEventType::kWriteO);
777794
write_o<false, KTraits>(
778795
&smem_storage, smem_pipe_write_kv.index(), final_o + q_indptr * o_stride_n,
779796
final_lse ? final_lse + q_indptr * num_heads : nullptr,
780797
(partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
781798
(partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
782799
o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
800+
PROFILER_EVENT_END(variant, ProfileEventType::kWriteO);
783801
__syncthreads();
784802
}
785803
} else {
@@ -816,57 +834,82 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
816834
#pragma unroll 1
817835
for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0; --kv_tile_idx) {
818836
consumer_wait(pipeline_kv, smem_pipe_read_kv);
837+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmQK);
819838
compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index(), s_frag);
820839
warpgroup_wait<0>();
840+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmQK);
821841
logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len,
822842
kv_len, kv_end, num_heads, s_frag);
843+
PROFILER_EVENT_START(variant, ProfileEventType::kSoftmaxUpdate);
823844
update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
845+
PROFILER_EVENT_END(variant, ProfileEventType::kSoftmaxUpdate);
824846
write_o_scale_smem<KTraits>(&smem_storage, o_scale);
847+
825848
convert_s_to_p<KTraits>(s_frag, p_frag);
826849
write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index(), p_frag);
827850
barrier_arrive(KTraits::NUM_THREADS, NamedBarriers::kOScaleReady);
851+
PROFILER_EVENT_START(variant, ProfileEventType::kRescaleO);
828852
rescale_o_<KTraits>(o_scale, o_frag);
853+
PROFILER_EVENT_END(variant, ProfileEventType::kRescaleO);
829854
__syncthreads();
855+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmPV);
830856
compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index(), o_frag);
831857
warpgroup_wait<0>();
858+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmPV);
832859
pipeline_kv.consumer_release(smem_pipe_read_kv);
833860
++smem_pipe_read_kv;
834861
}
835862

836863
#pragma unroll 1
837864
for (; kv_tile_idx + 1 > NUM_STAGES; --kv_tile_idx) {
838865
consumer_wait(pipeline_kv, smem_pipe_read_kv);
866+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmQK);
839867
compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index(), s_frag);
840868
warpgroup_wait<0>();
869+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmQK);
870+
PROFILER_EVENT_START(variant, ProfileEventType::kSoftmaxUpdate);
841871
update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
872+
PROFILER_EVENT_END(variant, ProfileEventType::kSoftmaxUpdate);
842873
write_o_scale_smem<KTraits>(&smem_storage, o_scale);
843874
convert_s_to_p<KTraits>(s_frag, p_frag);
844875
write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index(), p_frag);
845876
barrier_arrive(KTraits::NUM_THREADS, NamedBarriers::kOScaleReady);
877+
PROFILER_EVENT_START(variant, ProfileEventType::kRescaleO);
846878
rescale_o_<KTraits>(o_scale, o_frag);
879+
PROFILER_EVENT_END(variant, ProfileEventType::kRescaleO);
847880
__syncthreads();
881+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmPV);
848882
compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index(), o_frag);
849883
warpgroup_wait<0>();
884+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmPV);
850885
pipeline_kv.consumer_release(smem_pipe_read_kv);
851886
++smem_pipe_read_kv;
852887
}
853888

854889
#pragma unroll 1
855890
for (; kv_tile_idx >= 0; --kv_tile_idx) {
856891
consumer_wait(pipeline_kv, smem_pipe_read_kv);
892+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmQK);
857893
compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index(), s_frag);
858894
warpgroup_wait<0>();
895+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmQK);
859896
logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len,
860897
kv_len, kv_end, num_heads, s_frag);
898+
PROFILER_EVENT_START(variant, ProfileEventType::kSoftmaxUpdate);
861899
update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
900+
PROFILER_EVENT_END(variant, ProfileEventType::kSoftmaxUpdate);
862901
write_o_scale_smem<KTraits>(&smem_storage, o_scale);
863902
convert_s_to_p<KTraits>(s_frag, p_frag);
864903
write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index(), p_frag);
865904
barrier_arrive(KTraits::NUM_THREADS, NamedBarriers::kOScaleReady);
905+
PROFILER_EVENT_START(variant, ProfileEventType::kRescaleO);
866906
rescale_o_<KTraits>(o_scale, o_frag);
907+
PROFILER_EVENT_END(variant, ProfileEventType::kRescaleO);
867908
__syncthreads();
909+
PROFILER_EVENT_START(variant, ProfileEventType::kGemmPV);
868910
compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index(), o_frag);
869911
warpgroup_wait<0>();
912+
PROFILER_EVENT_END(variant, ProfileEventType::kGemmPV);
870913
pipeline_kv.consumer_release(smem_pipe_read_kv);
871914
++smem_pipe_read_kv;
872915
}
@@ -886,12 +929,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
886929
normalize_d_<KTraits>(&smem_storage, o_frag, m, d);
887930
finalize_m_<KTraits>(variant, m);
888931
barrier_arrive(KTraits::NUM_THREADS, NamedBarriers::kMDReady);
932+
PROFILER_EVENT_START(variant, ProfileEventType::kWriteO);
889933
write_o<true, KTraits>(
890934
&smem_storage, smem_pipe_read_kv.index(), final_o + q_indptr * o_stride_n,
891935
final_lse ? final_lse + q_indptr * num_heads : nullptr,
892936
(partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
893937
(partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
894938
o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
939+
PROFILER_EVENT_END(variant, ProfileEventType::kWriteO);
895940
__syncthreads();
896941
}
897942
}

0 commit comments

Comments
 (0)