@@ -566,8 +566,6 @@ def mla_decode_stage1_asm_fwd(
566566 work_indptr : Optional [torch .Tensor ],
567567 work_info_set : Optional [torch .Tensor ],
568568 max_seqlen_q : int ,
569- page_size : int ,
570- nhead_kv : int ,
571569 softmax_scale : float ,
572570 # [batch_size, num_kv_splits, num_heads, v_head_dim]
573571 splitData : torch .Tensor ,
@@ -856,7 +854,6 @@ def get_mla_metadata_info_v1(
856854def get_mla_metadata_v1 (
857855 seqlens_qo_indptr : torch .Tensor ,
858856 seqlens_kv_indptr : torch .Tensor ,
859- kv_last_page_lens : torch .Tensor ,
860857 num_heads_per_head_k : int ,
861858 num_heads_k : int ,
862859 is_causal : bool ,
@@ -866,7 +863,6 @@ def get_mla_metadata_v1(
866863 reduce_indptr : torch .Tensor ,
867864 reduce_final_map : torch .Tensor ,
868865 reduce_partial_map : torch .Tensor ,
869- page_size : int = 1 ,
870866 kv_granularity : int = 16 ,
871867 max_seqlen_qo : int = - 1 ,
872868 uni_seqlen_qo : int = - 1 ,
@@ -880,14 +876,12 @@ def get_mla_metadata_v1(
880876 """
881877 Inputs:
882878 cumulated seqlens of q/o: (batch_size + 1), dtype torch.int32.
883- cumulated page indices of k/v: (batch_size + 1), dtype torch.int32.
884- Length of last page of k/v: (batch_size), dtype torch.int32.
879+ cumulated seqlens of k/v: (batch_size + 1), dtype torch.int32.
885880 num_heads_per_head_k: Equal to num_heads_q // num_heads_k.
886881 num_heads_k: num_heads_k.
887882 is_causal: Whether causal mask is enabled.
888883 Options: Detailed settings for splitting. All of them are optional.
889- page_size: default=1. The size of a page.
890- kv_granularity: default=16. The granularity on kv page nums when cutting batch.
884+ kv_granularity: default=16. The granularity on kv sequence length when cutting batch.
891885 max_seqlen_qo: default=-1. Used to check lds usage and save time. value less than 1 means unknown.
892886 uni_seqlen_qo: default=-1. Sequence length of qo is uniform across batches. value less than 1 means the
893887 length is not fixed.
@@ -905,11 +899,11 @@ def get_mla_metadata_v1(
905899 [2.2] q_start: (#work), The global index in seq where q/o starts. Using a global index here can
906900 reduce memory access count in kernel.
907901 [2.3] q_end: (#work), The global index in seq where q/o ends (not included).
908- [2.4] kv_start: (#work), The global index in page where k/v starts.
909- [2.5] kv_end: (#work), The global index in page where k/v ends (not included). Note that
902+ [2.4] kv_start: (#work), The global index in seq where k/v starts.
903+ [2.5] kv_end: (#work), The global index in seq where k/v ends (not included). Note that
910904 this value indicates the end of last qo sequence if there are
911905 multiple qo sequences included in the current work and causal mask
912- is enabled when page_size is 1 .
906+ is enabled.
913907 [2.6] kv_offset: (#work), Remaining length in seq from kv_end to the end of current batch.
914908 [2.7] pad (#work, 1), Pad to 8 DWs.
915909 [3] reduce_indptr: (sum(qo_seqlen_blk_count) + 1),
0 commit comments