Skip to content

Commit ee0a366

Browse files
committed
resolve comments
Signed-off-by: ganyi <[email protected]>
1 parent f406805 commit ee0a366

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def __init__(
157157
parallel_config = vllm_config.parallel_config
158158
self.device = device
159159
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
160-
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
161160

162161
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
163162
self.mla_dims = get_mla_dims(self.model_config)
@@ -179,10 +178,10 @@ def __init__(
179178
device=device,
180179
)
181180
self.qo_indptr = torch.arange(
182-
0, max_num_seqs + 1, dtype=torch.int32, device=device
181+
0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
183182
)
184183
self.paged_kv_last_page_len = torch.ones(
185-
max_num_seqs, dtype=torch.int32, device=device
184+
max_num_batched_tokens, dtype=torch.int32, device=device
186185
)
187186

188187
# These two need to be calculated at runtime,
@@ -193,7 +192,7 @@ def __init__(
193192
device=device,
194193
)
195194
self.paged_kv_indptr = torch.zeros(
196-
[max_num_seqs + 1], dtype=torch.int32, device=device
195+
[max_num_batched_tokens + 1], dtype=torch.int32, device=device
197196
)
198197

199198
def build(
@@ -203,7 +202,6 @@ def build(
203202
fast_build: bool = False,
204203
) -> ROCMAiterMLASparseMetadata:
205204
num_tokens = common_attn_metadata.num_actual_tokens
206-
num_reqs = common_attn_metadata.num_reqs
207205
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
208206
seg_lengths = np.diff(starts)
209207
req_id_per_token = np.repeat(
@@ -218,11 +216,11 @@ def build(
218216
self.paged_kv_indptr.fill_(0)
219217

220218
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
221-
qo_indptr = self.qo_indptr[: num_reqs + 1]
222-
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
219+
qo_indptr = self.qo_indptr[: num_tokens + 1]
220+
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
223221
paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
224-
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
225-
paged_kv_indptr_rest = self.paged_kv_indptr[num_reqs + 1 :]
222+
paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
223+
paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]
226224

227225
metadata = ROCMAiterMLASparseMetadata(
228226
num_reqs=common_attn_metadata.num_reqs,

0 commit comments

Comments
 (0)