
Commit 42d4e34

feature: add mamba2 to replace self attention
1 parent d9957ef

1 file changed (+5 −2 lines)

lzero/model/unizero_world_models/transformer.py

Lines changed: 5 additions & 2 deletions
@@ -15,6 +15,7 @@
 from einops import rearrange

 from .kv_caching import KeysValues
+from mamba_ssm import Mamba2


 @dataclass
@@ -239,7 +240,8 @@ def __init__(self, config: TransformerConfig) -> None:

         self.ln1 = nn.LayerNorm(config.embed_dim)
         self.ln2 = nn.LayerNorm(config.embed_dim)
-        self.attn = SelfAttention(config)
+        # self.attn = SelfAttention(config)
+        self.attn = Mamba2(d_model=config.embed_dim, d_state=64, d_conv=4, expand=2)
         self.mlp = nn.Sequential(
             nn.Linear(config.embed_dim, 4 * config.embed_dim),
             nn.GELU(approximate='tanh'),
@@ -261,7 +263,8 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
         Returns:
             - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
         """
-        x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
+        # x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
+        x_attn = self.attn(self.ln1(x))
         if self.gru_gating:
             x = self.gate1(x, x_attn)
             x = self.gate2(x, self.mlp(self.ln2(x)))
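
A minimal sketch (not part of the commit) of why the call site drops its extra arguments: Mamba2 from the mamba-ssm package maps a (batch, seq_len, d_model) tensor back to the same shape, so it slots in where SelfAttention produced x_attn, but it runs a state-space recurrence internally and takes no KV cache or rotary frequencies. The embed_dim, batch, and sequence-length values below are illustrative assumptions, not taken from the repository, and the fused kernels require a CUDA device.

# Shape-compatibility sketch with illustrative values (not from the commit).
import torch
from mamba_ssm import Mamba2

embed_dim = 256  # stands in for config.embed_dim; keep d_model * expand / headdim
                 # a multiple of 8, which Mamba2's Triton kernels expect
mixer = Mamba2(d_model=embed_dim, d_state=64, d_conv=4, expand=2).cuda()

x = torch.randn(4, 32, embed_dim, device="cuda")  # (batch, seq_len, embed_dim)
y = mixer(x)  # no past_keys_values / valid_context_lengths / freqs_cis
assert y.shape == x.shape  # same shape contract as the old SelfAttention output

One consequence of the swap: the KeysValues cache is no longer consulted on this path, so any incremental-decoding logic built on past_keys_values would need a Mamba-specific equivalent, such as carrying the SSM state between calls.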
