@@ -157,7 +157,6 @@ def __init__(
157157 parallel_config = vllm_config .parallel_config
158158 self .device = device
159159 max_num_batched_tokens = vllm_config .scheduler_config .max_num_batched_tokens
160- max_num_seqs = vllm_config .scheduler_config .max_num_seqs
161160
162161 self .num_heads = self .model_config .get_num_attention_heads (parallel_config )
163162 self .mla_dims = get_mla_dims (self .model_config )
@@ -179,10 +178,10 @@ def __init__(
179178 device = device ,
180179 )
181180 self .qo_indptr = torch .arange (
182- 0 , max_num_seqs + 1 , dtype = torch .int32 , device = device
181+ 0 , max_num_batched_tokens + 1 , dtype = torch .int32 , device = device
183182 )
184183 self .paged_kv_last_page_len = torch .ones (
185- max_num_seqs , dtype = torch .int32 , device = device
184+ max_num_batched_tokens , dtype = torch .int32 , device = device
186185 )
187186
188187 # These two needs to be calculated in runtime,
@@ -193,7 +192,7 @@ def __init__(
193192 device = device ,
194193 )
195194 self .paged_kv_indptr = torch .zeros (
196- [max_num_seqs + 1 ], dtype = torch .int32 , device = device
195+ [max_num_batched_tokens + 1 ], dtype = torch .int32 , device = device
197196 )
198197
199198 def build (
@@ -203,7 +202,6 @@ def build(
203202 fast_build : bool = False ,
204203 ) -> ROCMAiterMLASparseMetadata :
205204 num_tokens = common_attn_metadata .num_actual_tokens
206- num_reqs = common_attn_metadata .num_reqs
207205 starts = np .asarray (common_attn_metadata .query_start_loc_cpu , dtype = np .int32 )
208206 seg_lengths = np .diff (starts )
209207 req_id_per_token = np .repeat (
@@ -218,11 +216,11 @@ def build(
218216 self .paged_kv_indptr .fill_ (0 )
219217
220218 req_id_per_token = self .req_id_per_token_buffer [:num_tokens ]
221- qo_indptr = self .qo_indptr [: num_reqs + 1 ]
222- paged_kv_last_page_len = self .paged_kv_last_page_len [:num_reqs ]
219+ qo_indptr = self .qo_indptr [: num_tokens + 1 ]
220+ paged_kv_last_page_len = self .paged_kv_last_page_len [:num_tokens ]
223221 paged_kv_indices = self .paged_kv_indices [: num_tokens * self .topk_tokens ]
224- paged_kv_indptr = self .paged_kv_indptr [: num_reqs + 1 ]
225- paged_kv_indptr_rest = self .paged_kv_indptr [num_reqs + 1 :]
222+ paged_kv_indptr = self .paged_kv_indptr [: num_tokens + 1 ]
223+ paged_kv_indptr_rest = self .paged_kv_indptr [num_tokens + 1 :]
226224
227225 metadata = ROCMAiterMLASparseMetadata (
228226 num_reqs = common_attn_metadata .num_reqs ,
0 commit comments