@@ -647,6 +647,8 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
647647 const uint32_t warp_idx = cutlass::canonical_warp_idx ();
648648 const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx () % 4 ;
649649
650+ PROFILER_INIT (params, smem_storage, variant, warp_group_idx, 2 , (threadIdx .x % 128 == 0 ));
651+
650652 using MainloopPipeline = typename KTraits::MainloopPipeline;
651653 using PipelineParams = typename MainloopPipeline::Params;
652654 using PipelineState = typename MainloopPipeline::PipelineState;
@@ -710,9 +712,11 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
710712 block_size, kv_indices, ckv_offset, kpe_offset);
711713 if (has_kv) {
712714 pipeline_kv.producer_acquire (smem_pipe_write_kv);
715+ PROFILER_EVENT_START (variant, ProfileEventType::kIssueLoadKV );
713716 load_kv<true , KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
714717 block_iter_base + kv_tile_idx * CTA_TILE_KV,
715718 smem_pipe_write_kv.index (), ckv_offset, kpe_offset);
719+ PROFILER_EVENT_END (variant, ProfileEventType::kIssueLoadKV );
716720 pipeline_kv.producer_commit (smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
717721 kv_tile_idx -= 1 ;
718722 ++smem_pipe_write_kv;
@@ -722,19 +726,23 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
722726 }
723727
724728 pipeline_q.producer_acquire (smem_pipe_write_q);
729+ PROFILER_EVENT_START (variant, ProfileEventType::kIssueLoadQ );
725730 load_q<KTraits>(&smem_storage, q_nope + q_indptr * q_nope_stride_n,
726731 q_pe + q_indptr * q_pe_stride_n, q_nope_stride_n, q_nope_stride_h,
727732 q_pe_stride_n, q_pe_stride_h, qo_upperbound, qo_packed_idx_base,
728733 params.num_heads );
734+ PROFILER_EVENT_END (variant, ProfileEventType::kIssueLoadQ );
729735 pipeline_q.producer_commit (smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
730736 ++smem_pipe_write_q;
731737
732738#pragma unroll 1
733739 for (; kv_tile_idx >= 0 ; --kv_tile_idx) {
734740 pipeline_kv.producer_acquire (smem_pipe_write_kv);
741+ PROFILER_EVENT_START (variant, ProfileEventType::kIssueLoadKV );
735742 load_kv<false , KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
736743 block_iter_base + kv_tile_idx * CTA_TILE_KV,
737744 smem_pipe_write_kv.index (), ckv_offset, kpe_offset);
745+ PROFILER_EVENT_END (variant, ProfileEventType::kIssueLoadKV );
738746 if (kv_tile_idx > 0 ) {
739747 prefetch_offset<KTraits>(block_iter_base + (kv_tile_idx - 1 ) * CTA_TILE_KV,
740748 packed_kv_bound, ckv_stride_page, ckv_stride_n, kpe_stride_page,
@@ -745,23 +753,31 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
745753
746754 barrier_sync (KTraits::NUM_THREADS, NamedBarriers::kOScaleReady );
747755 load_o_scale_smem<KTraits>(&smem_storage, o_scale);
756+ PROFILER_EVENT_START (variant, ProfileEventType::kRescaleO );
748757 rescale_o_<KTraits>(o_scale, o_frag);
758+ PROFILER_EVENT_END (variant, ProfileEventType::kRescaleO );
749759 consumer_wait (pipeline_kv, smem_pipe_read_kv);
750760 __syncthreads ();
761+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmPV );
751762 compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index (), o_frag);
752763 warpgroup_wait<0 >();
764+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmPV );
753765 pipeline_kv.consumer_release (smem_pipe_read_kv);
754766 ++smem_pipe_read_kv;
755767 }
756768
757769 if (has_kv) {
758770 barrier_sync (KTraits::NUM_THREADS, NamedBarriers::kOScaleReady );
759771 load_o_scale_smem<KTraits>(&smem_storage, o_scale);
772+ PROFILER_EVENT_START (variant, ProfileEventType::kRescaleO );
760773 rescale_o_<KTraits>(o_scale, o_frag);
774+ PROFILER_EVENT_END (variant, ProfileEventType::kRescaleO );
761775 consumer_wait (pipeline_kv, smem_pipe_read_kv);
762776 __syncthreads ();
777+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmPV );
763778 compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index (), o_frag);
764779 warpgroup_wait<0 >();
780+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmPV );
765781 pipeline_kv.consumer_release (smem_pipe_read_kv);
766782 ++smem_pipe_read_kv;
767783 }
@@ -774,12 +790,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
774790 }
775791 normalize_d_<KTraits>(&smem_storage, o_frag, m, d);
776792 finalize_m_<KTraits>(variant, m);
793+ PROFILER_EVENT_START (variant, ProfileEventType::kWriteO );
777794 write_o<false , KTraits>(
778795 &smem_storage, smem_pipe_write_kv.index (), final_o + q_indptr * o_stride_n,
779796 final_lse ? final_lse + q_indptr * num_heads : nullptr ,
780797 (partial_indptr == -1 ) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
781798 (partial_indptr == -1 ) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
782799 o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
800+ PROFILER_EVENT_END (variant, ProfileEventType::kWriteO );
783801 __syncthreads ();
784802 }
785803 } else {
@@ -816,57 +834,82 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
816834#pragma unroll 1
817835 for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0 ; --kv_tile_idx) {
818836 consumer_wait (pipeline_kv, smem_pipe_read_kv);
837+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmQK );
819838 compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index (), s_frag);
820839 warpgroup_wait<0 >();
840+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmQK );
821841 logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len,
822842 kv_len, kv_end, num_heads, s_frag);
843+ PROFILER_EVENT_START (variant, ProfileEventType::kSoftmaxUpdate );
823844 update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
845+ PROFILER_EVENT_END (variant, ProfileEventType::kSoftmaxUpdate );
824846 write_o_scale_smem<KTraits>(&smem_storage, o_scale);
847+
825848 convert_s_to_p<KTraits>(s_frag, p_frag);
826849 write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index (), p_frag);
827850 barrier_arrive (KTraits::NUM_THREADS, NamedBarriers::kOScaleReady );
851+ PROFILER_EVENT_START (variant, ProfileEventType::kRescaleO );
828852 rescale_o_<KTraits>(o_scale, o_frag);
853+ PROFILER_EVENT_END (variant, ProfileEventType::kRescaleO );
829854 __syncthreads ();
855+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmPV );
830856 compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index (), o_frag);
831857 warpgroup_wait<0 >();
858+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmPV );
832859 pipeline_kv.consumer_release (smem_pipe_read_kv);
833860 ++smem_pipe_read_kv;
834861 }
835862
836863#pragma unroll 1
837864 for (; kv_tile_idx + 1 > NUM_STAGES; --kv_tile_idx) {
838865 consumer_wait (pipeline_kv, smem_pipe_read_kv);
866+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmQK );
839867 compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index (), s_frag);
840868 warpgroup_wait<0 >();
869+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmQK );
870+ PROFILER_EVENT_START (variant, ProfileEventType::kSoftmaxUpdate );
841871 update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
872+ PROFILER_EVENT_END (variant, ProfileEventType::kSoftmaxUpdate );
842873 write_o_scale_smem<KTraits>(&smem_storage, o_scale);
843874 convert_s_to_p<KTraits>(s_frag, p_frag);
844875 write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index (), p_frag);
845876 barrier_arrive (KTraits::NUM_THREADS, NamedBarriers::kOScaleReady );
877+ PROFILER_EVENT_START (variant, ProfileEventType::kRescaleO );
846878 rescale_o_<KTraits>(o_scale, o_frag);
879+ PROFILER_EVENT_END (variant, ProfileEventType::kRescaleO );
847880 __syncthreads ();
881+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmPV );
848882 compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index (), o_frag);
849883 warpgroup_wait<0 >();
884+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmPV );
850885 pipeline_kv.consumer_release (smem_pipe_read_kv);
851886 ++smem_pipe_read_kv;
852887 }
853888
854889#pragma unroll 1
855890 for (; kv_tile_idx >= 0 ; --kv_tile_idx) {
856891 consumer_wait (pipeline_kv, smem_pipe_read_kv);
892+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmQK );
857893 compute_mla_qk<KTraits>(&smem_storage, smem_pipe_read_kv.index (), s_frag);
858894 warpgroup_wait<0 >();
895+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmQK );
859896 logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len,
860897 kv_len, kv_end, num_heads, s_frag);
898+ PROFILER_EVENT_START (variant, ProfileEventType::kSoftmaxUpdate );
861899 update_md_<KTraits>(&smem_storage, variant, s_frag, m, d, o_scale);
900+ PROFILER_EVENT_END (variant, ProfileEventType::kSoftmaxUpdate );
862901 write_o_scale_smem<KTraits>(&smem_storage, o_scale);
863902 convert_s_to_p<KTraits>(s_frag, p_frag);
864903 write_p_rmem_smem<KTraits>(&smem_storage, smem_pipe_read_kv.index (), p_frag);
865904 barrier_arrive (KTraits::NUM_THREADS, NamedBarriers::kOScaleReady );
905+ PROFILER_EVENT_START (variant, ProfileEventType::kRescaleO );
866906 rescale_o_<KTraits>(o_scale, o_frag);
907+ PROFILER_EVENT_END (variant, ProfileEventType::kRescaleO );
867908 __syncthreads ();
909+ PROFILER_EVENT_START (variant, ProfileEventType::kGemmPV );
868910 compute_mla_pv<KTraits>(&smem_storage, smem_pipe_read_kv.index (), o_frag);
869911 warpgroup_wait<0 >();
912+ PROFILER_EVENT_END (variant, ProfileEventType::kGemmPV );
870913 pipeline_kv.consumer_release (smem_pipe_read_kv);
871914 ++smem_pipe_read_kv;
872915 }
@@ -886,12 +929,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
886929 normalize_d_<KTraits>(&smem_storage, o_frag, m, d);
887930 finalize_m_<KTraits>(variant, m);
888931 barrier_arrive (KTraits::NUM_THREADS, NamedBarriers::kMDReady );
932+ PROFILER_EVENT_START (variant, ProfileEventType::kWriteO );
889933 write_o<true , KTraits>(
890934 &smem_storage, smem_pipe_read_kv.index (), final_o + q_indptr * o_stride_n,
891935 final_lse ? final_lse + q_indptr * num_heads : nullptr ,
892936 (partial_indptr == -1 ) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
893937 (partial_indptr == -1 ) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
894938 o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
939+ PROFILER_EVENT_END (variant, ProfileEventType::kWriteO );
895940 __syncthreads ();
896941 }
897942 }
0 commit comments