
Commit c36f0aa

Fix test_mamba_ssm_ssd.py due to missing _query_start_loc_to_chunk_indices_offsets (vllm-project#25995)
Signed-off-by: Huamin Li <[email protected]>
1 parent 5234dc7 commit c36f0aa
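
In short: the test previously built its varlen metadata with the private helper `_query_start_loc_to_chunk_indices_offsets`, which no longer exists; it now uses the public `compute_varlen_chunk_metadata` helper added in this commit. A minimal sketch of the call-site migration, condensed from the diff below (variable names as in the test):

    # Before: the private helper produced per-chunk indices/offsets, and the
    # kernel also took a separate per-token seq_idx tensor.
    chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
        cu_seqlens, chunk_size, cu_seqlens[-1])

    # After: one public helper returns all chunk-aligned metadata at once.
    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))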

2 files changed: +118 −63 lines


tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 48 additions & 63 deletions
@@ -10,7 +10,7 @@
     mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.mamba2_attn import (
-    _query_start_loc_to_chunk_indices_offsets)
+    compute_varlen_chunk_metadata)

 # Added by the IBM Team, 2024

@@ -225,32 +225,30 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                   B, C, chunk_size)

-    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
-    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
-
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
-
+    cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
     # varlen has implicit batch=1
     X = X.squeeze(0)
     dt = dt.squeeze(0)
     A = A.squeeze(0)
     B = B.squeeze(0)
     C = C.squeeze(0)
     Y = torch.empty_like(X)
-    final_state = mamba_chunk_scan_combined_varlen(X,
-                                                   dt,
-                                                   A,
-                                                   B,
-                                                   C,
-                                                   chunk_size,
-                                                   D=None,
-                                                   cu_seqlens=cu_seqlens,
-                                                   seq_idx=seq_idx,
-                                                   chunk_indices=chunk_indices,
-                                                   chunk_offsets=chunk_offsets,
-                                                   out=Y)
+    final_state = mamba_chunk_scan_combined_varlen(
+        X,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y,
+        D=None,
+    )

     # just test the last in sequence
     torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
@@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

     states = None
-    for Y_min, cu_seqlens, seq_idx, (
+    for Y_min, cu_seqlens, _token_seq_idx, (
             A, dt, X, B, C) in generate_continuous_batched_examples(
                 cases, num_examples, seqlen, last_taken, exhausted, n_heads,
                 d_head, itype):

-        chunk_indices, chunk_offsets = \
-            _query_start_loc_to_chunk_indices_offsets(
-                cu_seqlens, chunk_size, cu_seqlens[-1])
+        cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+            compute_varlen_chunk_metadata(cu_seqlens, chunk_size))

         Y = torch.empty_like(X)
         new_states = mamba_chunk_scan_combined_varlen(
@@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             B,
             C,
             chunk_size,
+            cu_seqlens=cu_seqlens.to(torch.int32),
+            cu_chunk_seqlens=cu_chunk_seqlens,
+            last_chunk_indices=last_chunk_indices,
+            seq_idx=seq_idx_chunks,
+            out=Y,
             D=None,
-            cu_seqlens=cu_seqlens,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
             initial_states=states,
-            out=Y,
         )

         # just test the last in sequence
@@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     device = X.device

     ## full seqlen computation
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
     Y_ref = torch.empty_like(X)
     state_ref = mamba_chunk_scan_combined_varlen(
         X,
@@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B,
         C,
         chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_ref,
         D=None,
-        cu_seqlens=cu_seqlens,
-        seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=None,
-        out=Y_ref,
     )

     ## chunked seqlen computation
@@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         torch.cumsum(chunked_seqlens, dim=0)
     ],
         dim=0)
-    chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(chunked_seqlens), device=device),
-        chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
     chunked_input_seq_len = chunked_cu_seqlens[-1]
     X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
     dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
     # fmt: on

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
     Y_partial = torch.empty_like(X_chunked)
     partial_state = mamba_chunk_scan_combined_varlen(
         X_chunked,
@@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         B_chunked,
         C_chunked,
         chunk_size,
+        cu_seqlens=chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_partial,
         D=None,
-        cu_seqlens=chunked_cu_seqlens,
-        seq_idx=chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=None,
-        out=Y_partial,
     )

     # remaining chunk
@@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         torch.cumsum(remaining_chunked_seqlens, dim=0)
     ],
         dim=0)
-    remaining_chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(remaining_chunked_seqlens), device=device),
-        remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
     # fmt: off
     remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
@@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
     assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            remaining_chunked_cu_seqlens,
-            chunk_size,
-            remaining_chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
+                                      chunk_size))

     Y_chunked = torch.empty_like(remaining_X_chunked)
     state_chunked = mamba_chunk_scan_combined_varlen(
@@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         remaining_B_chunked,
         remaining_C_chunked,
         chunk_size,
+        cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_chunked,
         D=None,
-        cu_seqlens=remaining_chunked_cu_seqlens,
-        seq_idx=remaining_chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
         initial_states=partial_state,
-        out=Y_chunked,
     )
     Y = concat_batch_f(Y_partial, Y_chunked)

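Note the semantic change hiding in `seq_idx`: the old kernel API took one sequence index per token (hence the `torch.repeat_interleave` blocks deleted above), while the new API takes one index per logical chunk, as returned by `compute_varlen_chunk_metadata`. A hand-worked sketch of the difference, with an assumed toy batch that is not from the test:

    import torch

    # Assumed toy batch: two sequences of 5 and 7 tokens, chunk_size = 4.
    seqlens = torch.tensor([5, 7])

    # Old style: per-token sequence index, built like the deleted code above.
    token_seq_idx = torch.repeat_interleave(
        torch.arange(len(seqlens)), seqlens).to(torch.int32)
    # tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])  -- length 12, one per token

    # New style: one index per logical chunk. Logical chunks split at sequence
    # and physical-chunk boundaries: seq 0 -> [0:4), [4:5); seq 1 -> [5:8), [8:12).
    # compute_varlen_chunk_metadata would return tensor([0, 0, 1, 1]) here as
    # its third output.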
vllm/v1/attention/backends/mamba2_attn.py

Lines changed: 70 additions & 0 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import itertools
 from dataclasses import dataclass
 from typing import Optional

@@ -17,6 +18,75 @@
 from vllm.v1.kv_cache_interface import AttentionSpec


+def compute_varlen_chunk_metadata(
+        query_start_loc: torch.Tensor,
+        chunk_size: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.
+
+    Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
+    and a physical `chunk_size`, returns three tensors on the same device:
+    - cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
+      logical-chunk lengths (each logical chunk never crosses a sequence or
+      physical-chunk boundary).
+    - last_chunk_indices: (B,) int32 index of the last logical chunk
+      for each sequence (=-1 for empty sequences).
+    - seq_idx_chunks: (nchunks,) int32 sequence index for each logical
+      chunk in order.
+
+    This is intentionally lightweight and CPU-side; it mirrors the metadata
+    produced by the V1 Mamba2 meta-data builder and is exported so tests
+    (and other callers) can avoid duplicating the logic.
+    """
+    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
+    assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
+    device = query_start_loc.device
+
+    qsl64 = query_start_loc.to(torch.int64)
+    starts = qsl64[:-1].tolist()
+    ends = qsl64[1:].tolist()
+    total = int(qsl64[-1].item())
+
+    chunk_lens: list[int] = []
+    seq_idx_chunks: list[int] = []
+    last_chunk_indices: list[int] = [-1] * len(starts)
+
+    for b, (s, e) in enumerate(zip(starts, ends)):
+        if e <= s:
+            # empty sequence
+            continue
+        pos = s
+        while pos < e:
+            # split at both sequence boundaries and physical chunk boundaries
+            room = chunk_size - (pos % chunk_size)
+            take = min(room, e - pos)
+            chunk_lens.append(int(take))
+            seq_idx_chunks.append(b)
+            last_chunk_indices[b] = len(chunk_lens) - 1
+            pos += take
+
+    # Exclusive prefix sum over logical-chunk lengths
+    if chunk_lens:
+        cu_chunk_seqlens = torch.tensor([0] +
+                                        list(itertools.accumulate(chunk_lens)),
+                                        device=device,
+                                        dtype=torch.int32)
+        # Final boundary must equal total tokens
+        assert int(cu_chunk_seqlens[-1].item()) == total
+    else:
+        cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)
+
+    last_chunk_indices_t = (torch.tensor(
+        last_chunk_indices, device=device, dtype=torch.int32)
+                            if len(starts) > 0 else torch.empty(
+                                (0, ), device=device, dtype=torch.int32))
+    seq_idx_chunks_t = torch.tensor(seq_idx_chunks,
+                                    device=device,
+                                    dtype=torch.int32)
+    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
+
+
 class Mamba2AttentionBackend(AttentionBackend):

     @staticmethod

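A quick sanity check of the helper's contract, hand-computed from the loop logic above (the example batch is assumed, not part of the commit):

    import torch
    from vllm.v1.attention.backends.mamba2_attn import (
        compute_varlen_chunk_metadata)

    # Two sequences of 5 and 7 tokens, physical chunk_size = 4.
    qsl = torch.tensor([0, 5, 12], dtype=torch.int32)
    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
        compute_varlen_chunk_metadata(qsl, chunk_size=4))

    # Logical chunk lengths are [4, 1, 3, 4]: sequence 0 splits at the
    # physical boundary after token 4; sequence 1 starts mid-chunk at token 5,
    # so its first logical chunk has room for only 3 tokens before the
    # boundary at token 8.
    assert cu_chunk_seqlens.tolist() == [0, 4, 5, 8, 12]
    assert last_chunk_indices.tolist() == [1, 3]    # last logical chunk per seq
    assert seq_idx_chunks.tolist() == [0, 0, 1, 1]  # owning seq per chunk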