
Commit 85adf72

add mamba2 as a unizero backbone option
1 parent 42d4e34 commit 85adf72

File tree

4 files changed: +1377 −4 lines changed


lzero/model/unizero_model.py

Lines changed: 3 additions & 1 deletion
@@ -9,7 +9,9 @@
     VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
     HFLanguageRepresentationNetwork
 from .unizero_world_models.tokenizer import Tokenizer
-from .unizero_world_models.world_model import WorldModel
+# from .unizero_world_models.world_model import WorldModel
+from .unizero_world_models.world_model_mamba2 import WorldModel
+from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size
 
 
 # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
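
The hunk above comments out the Transformer WorldModel import and hard-wires the Mamba2 variant in its place. A config-driven switch would keep both backbones selectable at runtime; the sketch below is hypothetical and not part of this commit (the `backbone` field and `build_world_model` helper are assumed names, and the `(config, tokenizer)` constructor signature is assumed to match the existing WorldModel usage in unizero_model.py).

# Hypothetical sketch (not part of this commit): select the world-model backbone via a
# config field instead of swapping imports. `backbone` and `build_world_model` are
# illustrative names; the constructor signature is assumed to match existing usage.
from .unizero_world_models.world_model import WorldModel as TransformerWorldModel
from .unizero_world_models.world_model_mamba2 import WorldModel as Mamba2WorldModel


def build_world_model(world_model_cfg, tokenizer):
    backbone = getattr(world_model_cfg, 'backbone', 'transformer')
    if backbone == 'mamba2':
        return Mamba2WorldModel(config=world_model_cfg, tokenizer=tokenizer)
    return TransformerWorldModel(config=world_model_cfg, tokenizer=tokenizer)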
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Any

import torch
import torch.nn as nn
from torch.nn import functional as F
from ding.torch_utils.network import GRUGatingUnit  # Keep if GRU gating is used outside Block
from einops import rearrange
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams


class Mamba(nn.Module):
    """
    Mamba-based backbone for the UniZero architecture, replacing the Transformer backbone.

    Arguments:
        - config (:obj:`MambaConfig`): Configuration for the Mamba model.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.drop = nn.Dropout(config.embed_pdrop)
        self.blocks = nn.ModuleList()

        for i in range(config.num_layers):
            # Note: embed_dim * expand must be divisible by headdim (asserted inside Mamba2).
            mamba_block = Mamba2(
                d_model=config.embed_dim,
                d_state=128,
                d_conv=4,
                expand=2,
                headdim=64,
                ngroups=1,
                bias=False,
                conv_bias=True,
                chunk_size=256,
                use_mem_eff_path=True,
                layer_idx=i,
            )
            self.blocks.append(mamba_block)

        self.ln_f = nn.LayerNorm(config.embed_dim)

    def _get_device(self):
        return self.ln_f.weight.device

    def _get_dtype(self):
        return self.ln_f.weight.dtype

    def generate_empty_state(self,
                             batch_size: int,
                             max_seq_len: Optional[int] = None,
                             ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Allocate zero-initialized state tensors (conv_state, ssm_state) for every Mamba layer, used for inference.
        """
        _device = self._get_device()
        _dtype = self._get_dtype()
        _max_seq_len = max_seq_len if max_seq_len is not None else getattr(self.config, 'max_seq_len', 2048)

        all_layer_states = []
        for mamba_layer in self.blocks:
            conv_state, ssm_state = mamba_layer.allocate_inference_cache(
                batch_size=batch_size,
                max_seqlen=_max_seq_len,
                dtype=_dtype
            )
            all_layer_states.append((conv_state.to(_device), ssm_state.to(_device)))
        return all_layer_states

    def forward(self, sequences: torch.Tensor, past_mamba_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                seqlen_offset: Optional[int] = 0) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass for training or full-sequence processing.

        Arguments:
            - sequences (:obj:`torch.Tensor`): Input tensor of shape (B, L, D).
            - past_mamba_states (:obj:`Optional[List[Tuple[torch.Tensor, torch.Tensor]]]`): Per-layer
              (conv_state, ssm_state) tuples; if provided, the blocks run in incremental-inference mode.
            - seqlen_offset (:obj:`Optional[int]`): Number of tokens already processed and cached in
              `past_mamba_states`.

        Returns:
            - torch.Tensor: Output tensor of shape (B, L, D).
            - Optional[List[Tuple[torch.Tensor, torch.Tensor]]]: Updated per-layer states, or None if
              `past_mamba_states` was not given.
        """
        x = self.drop(sequences)

        # Wrap the cached per-layer states in InferenceParams so Mamba2 updates them in place.
        current_inference_params = None
        if past_mamba_states is not None:
            batch_size, cur_seq_len, _ = sequences.shape
            current_inference_params = InferenceParams(
                max_seqlen=cur_seq_len + seqlen_offset,
                max_batch_size=batch_size,
                seqlen_offset=seqlen_offset
            )
            for i in range(self.config.num_layers):
                current_inference_params.key_value_memory_dict[i] = past_mamba_states[i]

        for i, block in enumerate(self.blocks):
            x = block(x, inference_params=current_inference_params)

        x = self.ln_f(x)

        # Collect the (in-place updated) states so the caller can pass them to the next call.
        updated_layer_states_list: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
        if current_inference_params is not None:
            updated_layer_states_list = []
            for i in range(self.config.num_layers):
                updated_conv_state, updated_ssm_state = current_inference_params.key_value_memory_dict[i]
                updated_layer_states_list.append((updated_conv_state, updated_ssm_state))

        return x, updated_layer_states_list
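
A minimal usage sketch for the Mamba backbone defined above, assuming mamba_ssm is installed, a CUDA device is available, and the `Mamba` class is in scope; the SimpleNamespace config and tensor shapes are illustrative only.

# Minimal usage sketch (illustrative only): full-sequence pass, then incremental inference.
import torch
from types import SimpleNamespace

config = SimpleNamespace(embed_dim=256, embed_pdrop=0.1, num_layers=2, max_seq_len=2048)
model = Mamba(config).cuda().eval()

tokens = torch.randn(4, 16, 256, device='cuda')  # (B, L, D) latent tokens

# Full-sequence pass (training-style): no cached states, second return value is None.
out, _ = model(tokens)

# Incremental inference: prefill with zero-initialized states, then feed one token at a time.
states = model.generate_empty_state(batch_size=4)
out, states = model(tokens, past_mamba_states=states, seqlen_offset=0)
next_token = torch.randn(4, 1, 256, device='cuda')
out, states = model(next_token, past_mamba_states=states, seqlen_offset=tokens.shape[1])

The second return value threads the per-layer (conv_state, ssm_state) caches between calls, which is what stands in for the Transformer's KV cache in this backbone.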
