    mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
-    _query_start_loc_to_chunk_indices_offsets)
+    compute_varlen_chunk_metadata)

# Added by the IBM Team, 2024

@@ -225,32 +225,30 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                  B, C, chunk_size)

-    cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
-    seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
-
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
-
+    cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
    # varlen has implicit batch=1
    X = X.squeeze(0)
    dt = dt.squeeze(0)
    A = A.squeeze(0)
    B = B.squeeze(0)
    C = C.squeeze(0)
    Y = torch.empty_like(X)
-    final_state = mamba_chunk_scan_combined_varlen(X,
-                                                   dt,
-                                                   A,
-                                                   B,
-                                                   C,
-                                                   chunk_size,
-                                                   D=None,
-                                                   cu_seqlens=cu_seqlens,
-                                                   seq_idx=seq_idx,
-                                                   chunk_indices=chunk_indices,
-                                                   chunk_offsets=chunk_offsets,
-                                                   out=Y)
+    final_state = mamba_chunk_scan_combined_varlen(
+        X,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y,
+        D=None,
+    )

    # just test the last in sequence
    torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
@@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

    states = None
-    for Y_min, cu_seqlens, seq_idx, (
+    for Y_min, cu_seqlens, _token_seq_idx, (
            A, dt, X, B, C) in generate_continuous_batched_examples(
                cases, num_examples, seqlen, last_taken, exhausted, n_heads,
                d_head, itype):

-        chunk_indices, chunk_offsets = \
-            _query_start_loc_to_chunk_indices_offsets(
-                cu_seqlens, chunk_size, cu_seqlens[-1])
+        cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+            compute_varlen_chunk_metadata(cu_seqlens, chunk_size))

        Y = torch.empty_like(X)
        new_states = mamba_chunk_scan_combined_varlen(
@@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
            B,
            C,
            chunk_size,
+            cu_seqlens=cu_seqlens.to(torch.int32),
+            cu_chunk_seqlens=cu_chunk_seqlens,
+            last_chunk_indices=last_chunk_indices,
+            seq_idx=seq_idx_chunks,
+            out=Y,
            D=None,
-            cu_seqlens=cu_seqlens,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
            initial_states=states,
-            out=Y,
        )

        # just test the last in sequence
@@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    device = X.device

    ## full seqlen computation
-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            cu_seqlens, chunk_size, cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
    Y_ref = torch.empty_like(X)
    state_ref = mamba_chunk_scan_combined_varlen(
        X,
@@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        B,
        C,
        chunk_size,
+        cu_seqlens=cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_ref,
        D=None,
-        cu_seqlens=cu_seqlens,
-        seq_idx=seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
        initial_states=None,
-        out=Y_ref,
    )

    ## chunked seqlen computation
@@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        torch.cumsum(chunked_seqlens, dim=0)
    ],
                                   dim=0)
-    chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(chunked_seqlens), device=device),
-        chunked_seqlens,
-        output_size=chunked_cu_seqlens[-1]).to(torch.int32)
    chunked_input_seq_len = chunked_cu_seqlens[-1]
    X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
    dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i + 1], ...] = chunk_f(C, i)  # noqa: E501
    # fmt: on

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
    Y_partial = torch.empty_like(X_chunked)
    partial_state = mamba_chunk_scan_combined_varlen(
        X_chunked,
@@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        B_chunked,
        C_chunked,
        chunk_size,
+        cu_seqlens=chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_partial,
        D=None,
-        cu_seqlens=chunked_cu_seqlens,
-        seq_idx=chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
        initial_states=None,
-        out=Y_partial,
    )

    # remaining chunk
@@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        torch.cumsum(remaining_chunked_seqlens, dim=0)
    ],
                                             dim=0)
-    remaining_chunked_seq_idx = torch.repeat_interleave(
-        torch.arange(len(remaining_chunked_seqlens), device=device),
-        remaining_chunked_seqlens,
-        output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
    remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
    # fmt: off
    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
@@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

-    chunk_indices, chunk_offsets = \
-        _query_start_loc_to_chunk_indices_offsets(
-            remaining_chunked_cu_seqlens,
-            chunk_size,
-            remaining_chunked_cu_seqlens[-1])
+    cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
+        compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
+                                      chunk_size))

    Y_chunked = torch.empty_like(remaining_X_chunked)
    state_chunked = mamba_chunk_scan_combined_varlen(
@@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
        remaining_B_chunked,
        remaining_C_chunked,
        chunk_size,
+        cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
+        cu_chunk_seqlens=cu_chunk_seqlens,
+        last_chunk_indices=last_chunk_indices,
+        seq_idx=seq_idx_chunks,
+        out=Y_chunked,
        D=None,
-        cu_seqlens=remaining_chunked_cu_seqlens,
-        seq_idx=remaining_chunked_seq_idx,
-        chunk_indices=chunk_indices,
-        chunk_offsets=chunk_offsets,
        initial_states=partial_state,
-        out=Y_chunked,
    )
    Y = concat_batch_f(Y_partial, Y_chunked)
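For reference, a minimal sketch of the chunk metadata these tests now precompute before calling the kernel. It assumes, based only on the call sites above, that compute_varlen_chunk_metadata splits every sequence in cu_seqlens into pieces of at most chunk_size tokens (never crossing a sequence boundary) and returns cumulative chunk boundaries, each sequence's last chunk index, and a per-chunk sequence index. The helper name sketch_varlen_chunk_metadata is hypothetical and is not the vLLM implementation.

import torch


def sketch_varlen_chunk_metadata(cu_seqlens: torch.Tensor, chunk_size: int):
    # Hypothetical stand-in for compute_varlen_chunk_metadata (assumption,
    # inferred from how the tests consume the three returned tensors).
    cu_chunk_seqlens = [0]
    last_chunk_indices = []
    seq_idx_chunks = []
    for seq_idx in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[seq_idx]), int(cu_seqlens[seq_idx + 1])
        # Split this sequence into chunk_size-sized pieces; chunks never
        # cross a sequence boundary, so the last piece may be shorter.
        for pos in range(start, end, chunk_size):
            cu_chunk_seqlens.append(min(pos + chunk_size, end))
            seq_idx_chunks.append(seq_idx)
        last_chunk_indices.append(len(cu_chunk_seqlens) - 2)
    device = cu_seqlens.device
    return (torch.tensor(cu_chunk_seqlens, dtype=torch.int32, device=device),
            torch.tensor(last_chunk_indices, dtype=torch.int32, device=device),
            torch.tensor(seq_idx_chunks, dtype=torch.int32, device=device))


# Two sequences of 5 and 3 tokens with chunk_size=4: chunk boundaries
# [0, 4, 5, 8], last chunk per sequence [1, 2], per-chunk seq ids [0, 0, 1].
print(sketch_varlen_chunk_metadata(torch.tensor([0, 5, 8]), chunk_size=4))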