33import keras
44# from keras.saving import register_keras_serializable as serializable
55try :
6- from mamba_ssm import Mamba , Mamba2
6+ from mamba_ssm import Mamba
77except ImportError :
88 print ("Mamba Wrapper is not available" )
99
1414class MambaSSM (SummaryNetwork ):
1515 def __init__ (
1616 self ,
17- ssm_dim : int ,
17+ feature_dim : int ,
1818 summary_dim : int = 8 ,
1919 mamba_blocks : int = 2 ,
2020 state_dim : int = 16 ,
@@ -24,19 +24,19 @@ def __init__(
2424 dt_max : float = 0.1 ,
2525 pooling : bool = True ,
2626 dropout : int | float | None = 0.5 ,
27- mamba_version : int = 2 ,
2827 device : str = "cuda" ,
29- d_ssm : int = 1 ,
3028 ** kwargs
3129 ):
3230 """
3331 A time-series summarization network using Mamba-based State Space Models (SSM).
3432 This model processes sequential input data using the Mamba SSM layer, followed by
3533 optional pooling, dropout, and a dense layer for extracting summary statistics.
34+
34+ Mamba2 support is currently unavailable due to stability issues.
3636
3737 Parameters
3838 ----------
39- ssm_dim : int
39+ feature_dim : int
4040 The dimensionality of the Mamba SSM model.
4141 summary_dim : int, optional
4242 The output dimensionality of the summary statistics layer (default is 8).
@@ -56,8 +56,6 @@ def __init__(
5656 Whether to apply global average pooling (default is True).
5757 dropout : int, float, or None, optional
5858 Dropout rate applied before the summary layer (default is 0.5).
59- mamba_version : int, optional
60- The version of Mamba to apply (default is 2).
6159 device : str, optional
6260 The computing device. Currently, only "cuda" is supported (default is "cuda").
6361 **kwargs : dict
@@ -68,15 +66,7 @@ def __init__(
6866 if device != "cuda" :
6967 raise NotImplementedError ("MambaSSM currently only supports cuda" )
7068
71- if mamba_version == 1 :
72- self .mamba_blocks = [Mamba (d_model = ssm_dim , d_state = state_dim , d_conv = conv_dim , expand = expand , dt_min = dt_min , dt_max = dt_max ).to (device ) for _ in range (mamba_blocks )]
73- elif mamba_version == 2 :
74- self .mamba_blocks = [Mamba2 (d_model = ssm_dim , d_state = state_dim , d_conv = conv_dim , expand = expand , dt_min = dt_min , dt_max = dt_max , d_ssm = d_ssm ).to (device ) for _ in range (mamba_blocks )]
75- else :
76- raise NotImplementedError ("Mamba version must be 1 or 2" )
77-
78-
79-
69+ self .mamba_blocks = [Mamba (d_model = feature_dim , d_state = state_dim , d_conv = conv_dim , expand = expand , dt_min = dt_min , dt_max = dt_max ).to (device ) for _ in range (mamba_blocks )]
8070
8171 self .layernorm = keras .layers .LayerNormalization (axis = - 1 )
8272
0 commit comments