# -*- coding: utf-8 -*-
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Any

import torch
import torch.nn as nn
from torch.nn import functional as F
from ding.torch_utils.network import GRUGatingUnit  # Keep if GRU gating is used outside Block
from einops import rearrange
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams

class Mamba(nn.Module):
    """
    Mamba-based sequence model intended as a drop-in replacement for the
    Transformer backbone in the UniZero architecture.

    Arguments:
        - config (:obj:`MambaConfig`): Configuration for the Mamba model.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.drop = nn.Dropout(config.embed_pdrop)
        self.blocks = nn.ModuleList()

        for i in range(config.num_layers):
            # Each layer is a Mamba2 (SSD) block. `layer_idx` is required so the
            # layer can locate its (conv_state, ssm_state) entry in
            # InferenceParams.key_value_memory_dict during incremental inference.
            mamba_block = Mamba2(
                d_model=config.embed_dim,
                d_state=128,
                d_conv=4,
                expand=2,
                headdim=64,
                ngroups=1,
                bias=False,
                conv_bias=True,
                chunk_size=256,
                use_mem_eff_path=True,
                layer_idx=i,
            )
            self.blocks.append(mamba_block)

        self.ln_f = nn.LayerNorm(config.embed_dim)

    def _get_device(self):
        return self.ln_f.weight.device

    def _get_dtype(self):
        return self.ln_f.weight.dtype

    def generate_empty_state(self,
                             batch_size: int,
                             max_seq_len: Optional[int] = None,
                             ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Allocate zero-initialized inference state tensors (conv_state, ssm_state)
        for every Mamba layer.
        """
        _device = self._get_device()
        _dtype = self._get_dtype()
        _max_seq_len = max_seq_len if max_seq_len is not None else getattr(self.config, 'max_seq_len', 2048)

        all_layer_states = []
        for mamba_layer in self.blocks:
            conv_state, ssm_state = mamba_layer.allocate_inference_cache(
                batch_size=batch_size,
                max_seqlen=_max_seq_len,
                dtype=_dtype
            )
            all_layer_states.append((conv_state.to(_device), ssm_state.to(_device)))
        return all_layer_states
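
    # Note (editor sketch): each tuple returned above pairs the causal-conv rolling
    # buffer with the per-head SSM state produced by Mamba2.allocate_inference_cache.
    # Threading the list back in through `forward(..., past_mamba_states=...)` lets
    # each Mamba2 layer update both tensors in place between calls, which is what
    # enables token-by-token decoding without reprocessing the prefix.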

    def forward(self, sequences: torch.Tensor,
                past_mamba_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                seqlen_offset: int = 0
                ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass for training (full sequences) or incremental inference.

        Arguments:
            - sequences (:obj:`torch.Tensor`): Input tensor of shape (B, L, D).
            - past_mamba_states (:obj:`Optional[List[Tuple[torch.Tensor, torch.Tensor]]]`): Per-layer
              (conv_state, ssm_state) tuples, e.g. from :meth:`generate_empty_state`. If provided,
              the layers run in inference mode and update these states in place.
            - seqlen_offset (:obj:`int`): Number of tokens already absorbed into the cached states;
              a value > 0 makes each Mamba2 layer take its single-step decoding path (requires L == 1).

        Returns:
            - torch.Tensor: Output tensor of shape (B, L, D).
            - Optional[List[Tuple[torch.Tensor, torch.Tensor]]]: Updated per-layer states,
              or None if `past_mamba_states` was not given.
        """
        x = self.drop(sequences)
        current_inference_params = None
        if past_mamba_states is not None:
            batch_size, cur_seq_len, _ = sequences.shape
            # Wrap the caller-provided per-layer states in an InferenceParams object
            # so that each Mamba2 block can read and update its cache.
            current_inference_params = InferenceParams(
                max_seqlen=cur_seq_len + seqlen_offset,
                max_batch_size=batch_size,
                seqlen_offset=seqlen_offset
            )
            for i in range(self.config.num_layers):
                current_inference_params.key_value_memory_dict[i] = past_mamba_states[i]

        for block in self.blocks:
            x = block(x, inference_params=current_inference_params)

        x = self.ln_f(x)

        updated_layer_states_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
        if current_inference_params is not None:
            # The blocks update their caches in place; read them back out so the
            # caller can thread the states into the next call.
            updated_layer_states_list = []
            for i in range(self.config.num_layers):
                updated_conv_state, updated_ssm_state = current_inference_params.key_value_memory_dict[i]
                updated_layer_states_list.append((updated_conv_state, updated_ssm_state))

        return x, updated_layer_states_list

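
# --- Illustrative usage sketch (editor addition, not part of the original module) ---
# A minimal example of how this backbone might be driven, assuming a config object
# that exposes `embed_dim`, `embed_pdrop`, `num_layers` and (optionally) `max_seq_len`.
# `DemoConfig` below is a hypothetical stand-in for the real MambaConfig.
if __name__ == "__main__":

    @dataclass
    class DemoConfig:
        embed_dim: int = 256      # embed_dim * expand must stay divisible by headdim (64)
        embed_pdrop: float = 0.1
        num_layers: int = 2
        max_seq_len: int = 2048

    cfg = DemoConfig()
    model = Mamba(cfg).cuda()  # Mamba2 kernels expect CUDA tensors

    # Full-sequence pass (training style): no cached states are used or returned.
    seq = torch.randn(4, 16, cfg.embed_dim, device="cuda")
    out, states = model(seq)
    assert states is None and out.shape == seq.shape

    # Incremental decoding: prefill once with seqlen_offset=0, then feed one token
    # at a time while advancing seqlen_offset so each Mamba2 layer takes its step path.
    states = model.generate_empty_state(batch_size=4)
    out, states = model(seq, past_mamba_states=states, seqlen_offset=0)
    next_token = torch.randn(4, 1, cfg.embed_dim, device="cuda")
    out, states = model(next_token, past_mamba_states=states, seqlen_offset=seq.shape[1])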