This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit b2a8cd8

Refactor mamba to use the MambaCacheManager

1 parent b9723fe

File tree

3 files changed: +79, -222 lines
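The manager that both models now delegate to is imported from vllm.model_executor.models.mamba_cache (see the new import in mamba.py below); its own file is the third change in this commit and is not shown in this view. For orientation only, here is a minimal sketch of the interface the call sites in these diffs rely on. The constructor arguments and method names are taken from the diffs; the bodies and the cache_indices bookkeeping attribute are illustrative assumptions, not the actual implementation.

# Sketch only: mirrors the call sites below, not the real mamba_cache.py.
from typing import Dict, List, Tuple

import torch


class MambaCacheManager:

    def __init__(self, dtype: torch.dtype, num_mamba_layers: int,
                 max_batch_size: int, conv_state_shape: Tuple[int, int],
                 temporal_state_shape: Tuple[int, int]):
        # One (conv_state, temporal_state) buffer pair for all Mamba layers,
        # indexed as [layer, batch_slot, *state_shape].
        self.mamba_cache = (
            torch.empty((num_mamba_layers, max_batch_size) + conv_state_shape,
                        dtype=dtype, device="cuda"),
            torch.empty((num_mamba_layers, max_batch_size) +
                        temporal_state_shape, dtype=dtype, device="cuda"),
        )
        # request_id -> {seq_id -> batch-slot index} (assumed bookkeeping)
        self.cache_indices: Dict[str, Dict[int, int]] = {}

    def release_finished_requests(self,
                                  finished_request_ids: List[str]) -> None:
        # Free the batch slots of requests that have completed.
        for req_id in finished_request_ids:
            self.cache_indices.pop(req_id, None)

    def prepare_current_run_state(
            self, request_ids_to_seq_ids: Dict[str, List[int]],
            batch_size: int, finished_requests_ids: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Gather (and pad) the batch slots for this step and return views
        # over the conv/temporal buffers in batch order.
        ...

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        # Copy the current state into the captured CUDA graph input buffers.
        ...

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        # Buffers handed to the CUDA graph capture runs.
        ...

With an interface along these lines, jamba.py and mamba.py only decide how many Mamba layers they have and what the conv/temporal state shapes are; slot allocation, batch padding, and CUDA graph buffer handling live in one place.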


vllm/model_executor/models/jamba.py

Lines changed: 25 additions & 4 deletions
@@ -594,7 +594,7 @@ def __init__(
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
         # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache = MambaCacheManager(config)
+        self.mamba_cache: Optional[MambaCacheManager] = None

         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -607,12 +607,19 @@ def forward(self,
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-        if not self.mamba_cache.initialized:
+        if self.mamba_cache is None:
             max_batch_size = (_get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
-            self.mamba_cache.initialize_tensors(self.lm_head.weight.dtype,
-                                                max_batch_size)
+
+            layers_type = self.config.layers_block_type
+            num_mamba_layers = sum(
+                [layer_type == "mamba" for layer_type in layers_type])
+
+            self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype,
+                                                 num_mamba_layers,
+                                                 max_batch_size,
+                                                 *self._get_mamba_cache_shape())

         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -623,6 +630,7 @@ def forward(self,
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         finished_requests_ids = kwargs["finished_requests_ids"]
         self.mamba_cache.release_finished_requests(finished_requests_ids)
+
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
@@ -637,6 +645,19 @@ def forward(self,
                                       mamba_cache_tensors[1])
         return hidden_states

+    def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+        hidden_size = self.config.hidden_size
+        conv_state_shape = (
+            self.config.mamba_expand * hidden_size // world_size,
+            self.config.mamba_d_conv,
+        )
+        temporal_state_shape = (
+            self.config.mamba_expand * hidden_size // world_size,
+            self.config.mamba_d_state,
+        )
+        return conv_state_shape, temporal_state_shape
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         return self.mamba_cache.copy_inputs_before_cuda_graphs(
             input_buffers, **kwargs)

vllm/model_executor/models/mamba.py

Lines changed: 28 additions & 183 deletions
@@ -1,7 +1,7 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -28,6 +28,7 @@
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import HasInnerState
+from vllm.model_executor.models.mamba_cache import MambaCacheManager
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -420,15 +421,10 @@ def __init__(
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

         self.lm_head = self.backbone.embeddings
-        # Current step used indices
-        self.current_indices: List[int] = []
+
         # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Maps between the request id and a dict that maps between the seq_id
-        # and its index inside the self.mamba_cache
-        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -440,8 +436,14 @@ def forward(self,
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-        if not self.mamba_cache:
-            self._prepare_mamba_cache()
+        if self.mamba_cache is None:
+            max_batch_size = (_get_graph_batch_size(
+                self.scheduler_config.max_num_seqs) if self.scheduler_config
+                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+            self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype,
+                                                 self.config.num_hidden_layers,
+                                                 max_batch_size,
+                                                 *self._get_mamba_cache_shape())

         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -451,169 +453,25 @@ def forward(self,

         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         finished_requests_ids = kwargs["finished_requests_ids"]
+        self.mamba_cache.release_finished_requests(finished_requests_ids)
+
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
-            finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
+            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
+
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
+            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]

         hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata,
-                                      current_seqlen_agnostic_cache[0],
-                                      current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
+                                      attn_metadata, mamba_cache_tensors[0],
+                                      mamba_cache_tensors[1])

         return hidden_states

-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
-
-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
-        assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
-                                       non_blocking=True)
-
-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
-
-    def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-            finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
-            if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
-                # and finish right after
-                continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
-
-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
-
-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
-
-        return (conv_state, temporal_state), indices_for_current_run
-
-    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache into the CUDA graph input buffer
-        that was provided during the capture runs
-        (MambaForCausalLM.mamba_gc_cache_buffer).
-        """
-        assert all(
-            key in kwargs
-            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the MambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
-
-    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-        """
-        Provide the CUDA graph capture runs with a buffer in adjusted size.
-        The buffer is used to maintain the Mamba Cache during the CUDA graph
-        replay runs.
-        """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
-
-    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
-        for req_id in finished_seq_groups_req_ids:
-            if req_id in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping.pop(req_id)
-
-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
-            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
-
-    def _get_mamba_cache_shape(
-            self
-    ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
+    def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         world_size = get_tensor_model_parallel_world_size()
         conv_state_shape = (
             self.config.intermediate_size // world_size,
@@ -625,25 +483,12 @@ def _get_mamba_cache_shape(
         )
         return conv_state_shape, temporal_state_shape

-    def _prepare_mamba_cache(self):
-        dtype = self.lm_head.weight.dtype
-        num_mamba_layers = self.config.num_hidden_layers
-        max_batch_size = (_get_graph_batch_size(
-            self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
-        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
-        assert conv_state_shape is not None and temporal_state_shape is not None
-
-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(num_mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
