@@ -15,6 +15,7 @@
 from einops import rearrange
 
 from .kv_caching import KeysValues
+from mamba_ssm import Mamba2
 
 
 @dataclass
@@ -239,7 +240,8 @@ def __init__(self, config: TransformerConfig) -> None:
 
         self.ln1 = nn.LayerNorm(config.embed_dim)
         self.ln2 = nn.LayerNorm(config.embed_dim)
-        self.attn = SelfAttention(config)
+        # self.attn = SelfAttention(config)
+        self.attn = Mamba2(d_model=config.embed_dim, d_state=64, d_conv=4, expand=2)
         self.mlp = nn.Sequential(
             nn.Linear(config.embed_dim, 4 * config.embed_dim),
             nn.GELU(approximate='tanh'),
@@ -261,7 +263,8 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
         Returns:
             - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
         """
-        x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
+        # x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
+        x_attn = self.attn(self.ln1(x))
         if self.gru_gating:
             x = self.gate1(x, x_attn)
             x = self.gate2(x, self.mlp(self.ln2(x)))
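
As context for the swap above, here is a minimal shape-check sketch (not part of this repo) assuming the `mamba-ssm` package is installed and a CUDA device is available, since Mamba2's fused kernels target GPU; the tensor sizes are made up for illustration. It shows why the KV-cache, `valid_context_lengths`, and `freqs_cis` arguments are dropped in `forward`: Mamba2 maps `(batch, seq_len, embed_dim)` to the same shape and keeps its recurrent state internal, so it can stand in for the attention sub-layer with only the hidden states as input.

```python
import torch
from mamba_ssm import Mamba2

# Hypothetical sizes for illustration; embed_dim * expand should be divisible
# by Mamba2's head dimension (64 by default), which 256 * 2 = 512 satisfies.
batch_size, seq_len, embed_dim = 2, 16, 256

# Same constructor arguments as in the diff above.
mixer = Mamba2(d_model=embed_dim, d_state=64, d_conv=4, expand=2).to("cuda")

x = torch.randn(batch_size, seq_len, embed_dim, device="cuda")

# Unlike SelfAttention, the layer only takes the hidden states: its internal
# recurrent state plays the role of the KV cache, so no past_keys_values,
# valid_context_lengths, or rotary embeddings are passed in.
y = mixer(x)
assert y.shape == (batch_size, seq_len, embed_dim)
```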