@@ -16,14 +16,22 @@
 
 @serializable("bayesflow.wrappers")
 class MambaBlock(keras.Layer):
+    """
+    Wraps the original Mamba module, with added functionality for bidirectional processing, from:
+    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
+
+    Copyright (c) 2023, Tri Dao, Albert Gu.
+    """
+
     def __init__(
         self,
         state_dim: int,
         conv_dim: int,
         feature_dim: int = 16,
-        expand: int = 1,
-        dt_min=0.001,
-        dt_max=0.1,
+        expand: int = 2,
+        bidirectional: bool = True,
+        dt_min: float = 0.001,
+        dt_max: float = 0.1,
         device: str = "cuda",
         **kwargs,
     ):
@@ -58,6 +66,8 @@ def __init__(
         if keras.backend.backend() != "torch":
             raise EnvironmentError("Mamba is only available using torch backend.")
 
+        self.bidirectional = bidirectional
+
         self.mamba = Mamba(
             d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
         ).to(device)
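For reference, these constructor arguments map one-to-one onto the upstream mamba_ssm API. A minimal sketch of the underlying module in isolation, assuming mamba_ssm is installed and a CUDA device is available (shapes are illustrative):

import torch
from mamba_ssm import Mamba

# Mamba is sequence-to-sequence and shape-preserving:
# (batch, length, d_model) in, (batch, length, d_model) out
m = Mamba(d_model=16, d_state=16, d_conv=4, expand=2).to("cuda")
y = m(torch.randn(8, 128, 16, device="cuda"))
assert y.shape == (8, 128, 16)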
@@ -70,6 +80,13 @@ def __init__(
         self.layer_norm = keras.layers.LayerNormalization()
 
     def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
+        out_forward = self._call(x, training=training, **kwargs)
+        if self.bidirectional:
+            out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs)
+            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
+        return out_forward
+
+    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
         x = self.input_projector(x)
         h = self.mamba(x)
         out = self.layer_norm(h + x, training=training, **kwargs)
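A minimal sketch of what the bidirectional flag does to the output shape (this assumes the torch backend and a CUDA device, since MambaBlock defaults to device="cuda"; the sizes are illustrative):

import keras

x = keras.random.normal((8, 128, 16))  # (batch, time, features)
block = MambaBlock(state_dim=16, conv_dim=4, feature_dim=16, bidirectional=True)
out = block(x)
# The forward pass and the pass over the time-reversed input are
# concatenated on the last axis, so the feature dimension doubles:
# out.shape == (8, 128, 32)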
@@ -84,7 +101,7 @@ def build(self, input_shape):
 @serializable("bayesflow.wrappers")
 class MambaSSM(SummaryNetwork):
     """
-    Wraps the original Mamba module from:
+    Wraps a sequence of Mamba modules using the simple Mamba module from:
     https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
 
     Copyright (c) 2023, Tri Dao, Albert Gu.
@@ -94,9 +111,10 @@ def __init__(
         self,
         summary_dim: int = 16,
         feature_dims: Sequence[int] = (64, 64),
-        state_dims: Sequence[int] = (128, 128),
+        state_dims: Sequence[int] = (64, 64),
         conv_dims: Sequence[int] = (64, 64),
         expand_dims: Sequence[int] = (2, 2),
+        bidirectional: bool = True,
         dt_min: float = 0.001,
         dt_max: float = 0.1,
         dropout: float = 0.05,
@@ -143,7 +161,8 @@ def __init__(
 
         self.mamba_blocks = []
         for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
-            self.mamba_blocks.append(MambaBlock(feature_dim, state_dim, conv_dim, expand, dt_min, dt_max, device))
+            mamba = MambaBlock(state_dim, conv_dim, feature_dim, expand, bidirectional, dt_min, dt_max, device)
+            self.mamba_blocks.append(mamba)
 
         self.pooling_layer = keras.layers.GlobalAveragePooling1D()
         self.dropout = keras.layers.Dropout(dropout)
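End to end, the summary network is used like any other BayesFlow summary network. A hedged usage sketch; the import path follows the @serializable("bayesflow.wrappers") registration above but is an assumption, as are the input shapes:

import keras
from bayesflow.wrappers import MambaSSM  # assumed import path

x = keras.random.normal((32, 200, 3))  # 32 simulated series, 200 steps, 3 observables
summary_net = MambaSSM(summary_dim=16, bidirectional=True)
summaries = summary_net(x)  # pooled, dropout-regularized summaries of shape (32, 16)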