@@ -1,7 +1,7 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -28,6 +28,7 @@
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import HasInnerState
+from vllm.model_executor.models.mamba_cache import MambaCacheManager
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -420,15 +421,10 @@ def __init__(
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
 
         self.lm_head = self.backbone.embeddings
-        # Current step used indices
-        self.current_indices: List[int] = []
+
         # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Maps between the request id and a dict that maps between the seq_id
-        # and its index inside the self.mamba_cache
-        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -440,8 +436,14 @@ def forward(self,
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-        if not self.mamba_cache:
-            self._prepare_mamba_cache()
+        if self.mamba_cache is None:
+            max_batch_size = (_get_graph_batch_size(
+                self.scheduler_config.max_num_seqs) if self.scheduler_config
+                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+            self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype,
+                                                 self.config.num_hidden_layers,
+                                                 max_batch_size,
+                                                 *self._get_mamba_cache_shape())
 
         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -451,169 +453,25 @@ def forward(self,
 
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
             finished_requests_ids = kwargs["finished_requests_ids"]
+            self.mamba_cache.release_finished_requests(finished_requests_ids)
+
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
-            finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
+            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
+
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
+            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
 
         hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata,
-                                      current_seqlen_agnostic_cache[0],
-                                      current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
+                                      attn_metadata, mamba_cache_tensors[0],
+                                      mamba_cache_tensors[1])
 
         return hidden_states
 
-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
-
-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
-        assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
-                                       non_blocking=True)
-
-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
-
-    def _prepare_current_run_mamba_cache(
-        self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-        finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
-            if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
-                # and finish right after
-                continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
-
-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
-
-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
-
-        return (conv_state, temporal_state), indices_for_current_run
-
-    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache into the CUDA graph input buffer
-        that was provided during the capture runs
-        (MambaForCausalLM.mamba_gc_cache_buffer).
-        """
-        assert all(
-            key in kwargs
-            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the MambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
-
-    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-        """
-        Provide the CUDA graph capture runs with a buffer in adjusted size.
-        The buffer is used to maintain the Mamba Cache during the CUDA graph
-        replay runs.
-        """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
-
-    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
-        for req_id in finished_seq_groups_req_ids:
-            if req_id in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping.pop(req_id)
-
-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
-            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
-
-    def _get_mamba_cache_shape(
-        self
-    ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
+    def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         world_size = get_tensor_model_parallel_world_size()
         conv_state_shape = (
             self.config.intermediate_size // world_size,
@@ -625,25 +483,12 @@ def _get_mamba_cache_shape(
         )
         return conv_state_shape, temporal_state_shape
 
-    def _prepare_mamba_cache(self):
-        dtype = self.lm_head.weight.dtype
-        num_mamba_layers = self.config.num_hidden_layers
-        max_batch_size = (_get_graph_batch_size(
-            self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
-        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
-        assert conv_state_shape is not None and temporal_state_shape is not None
-
-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(num_mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
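
Note: mamba_cache.py itself is not part of this diff, so the snippet below is a minimal, hypothetical sketch of the MambaCacheManager surface that the new code calls into, inferred only from the call sites above (the constructor arguments, release_finished_requests, prepare_current_run_state, get_seqlen_agnostic_capture_inputs) and from the removed inline bookkeeping. It is not the actual vLLM implementation; among other differences, the real manager allocates its buffers on the GPU and also provides the copy_inputs_before_cuda_graphs hook delegated to above.

# Hypothetical sketch only -- names and shapes are taken from the call sites
# in the diff, not copied from vllm/model_executor/models/mamba_cache.py.
from typing import Dict, List, Tuple

import torch


class MambaCacheManagerSketch:

    def __init__(self, dtype: torch.dtype, num_layers: int,
                 max_batch_size: int, conv_state_shape: Tuple[int, int],
                 temporal_state_shape: Tuple[int, int]):
        # One cache slot per concurrently running sequence, for every layer.
        self.conv_state = torch.zeros(num_layers, max_batch_size,
                                      *conv_state_shape, dtype=dtype)
        self.temporal_state = torch.zeros(num_layers, max_batch_size,
                                          *temporal_state_shape, dtype=dtype)
        # request_id -> {seq_id -> slot index into the buffers above}
        self._slots: Dict[str, Dict[int, int]] = {}

    def _first_free_slot(self) -> int:
        used = {i for seqs in self._slots.values() for i in seqs.values()}
        return next(i for i in range(self.conv_state.shape[1])
                    if i not in used)

    def release_finished_requests(
            self, finished_requests_ids: List[str]) -> None:
        # Finished requests hand their slots back to the pool.
        for req_id in finished_requests_ids:
            self._slots.pop(req_id, None)

    def prepare_current_run_state(
            self, request_ids_to_seq_ids: Dict[str, List[int]],
            batch_size: int, finished_requests_ids: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Assign (or reuse) a slot per live sequence, then gather the conv and
        # SSM states for this batch; padding reuses a spare slot so the batch
        # can match a captured CUDA graph size.
        indices: List[int] = []
        for req_id, seq_ids in request_ids_to_seq_ids.items():
            if req_id in finished_requests_ids:
                continue
            slots = self._slots.setdefault(req_id, {})
            for seq_id in seq_ids:
                if seq_id not in slots:
                    slots[seq_id] = self._first_free_slot()
                indices.append(slots[seq_id])
        indices += [self._first_free_slot()] * (batch_size - len(indices))
        return self.conv_state[:, indices], self.temporal_state[:, indices]

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        # Buffers handed to the CUDA graph capture runs (views, not copies).
        return (self.conv_state[:, :batch_size],
                self.temporal_state[:, :batch_size])

The net effect of the refactor is that slot bookkeeping and buffer ownership move out of MambaForCausalLM: forward() only asks for the batch-ordered (conv_state, temporal_state) pair, the CUDA graph hooks become one-line delegations, and the old copy-back step (copy_outputs_after_cuda_graphs) disappears, presumably because the per-sequence states are now updated in place inside the manager's buffers.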