11import keras
2+
23# from keras.saving import register_keras_serializable as serializable
34try :
45 from mamba_ssm import Mamba
78
89from ...networks .summary_network import SummaryNetwork
910
11+
1012# @serializable(package="bayesflow.wrappers")
1113class MambaSSM (SummaryNetwork ):
1214 def __init__ (
@@ -22,13 +24,13 @@ def __init__(
2224 pooling : bool = True ,
2325 dropout : int | float | None = 0.5 ,
2426 device : str = "cuda" ,
25- ** kwargs
27+ ** kwargs ,
2628 ):
2729 """
2830 A time-series summarization network using Mamba-based State Space Models (SSM).
29- This model processes sequential input data using the Mamba SSM layer, followed by
31+ This model processes sequential input data using the Mamba SSM layer, followed by
3032 optional pooling, dropout, and a dense layer for extracting summary statistics.
31-
33+
3234 Mamba2 support currently unabailble due to stability issues
3335
3436 Parameters
@@ -58,31 +60,26 @@ def __init__(
5860 **kwargs : dict
5961 Additional keyword arguments passed to the `SummaryNetwork` parent class.
6062 """
61-
63+
6264 super ().__init__ (** kwargs )
6365 if device != "cuda" :
6466 raise NotImplementedError ("MambaSSM currently only supports cuda" )
65-
67+
6668 self .mamba_blocks = [
6769 Mamba (
68- d_model = feature_dim ,
69- d_state = state_dim ,
70- d_conv = conv_dim ,
71- expand = expand ,
72- dt_min = dt_min ,
73- dt_max = dt_max
70+ d_model = feature_dim , d_state = state_dim , d_conv = conv_dim , expand = expand , dt_min = dt_min , dt_max = dt_max
7471 ).to (device )
7572 for _ in range (mamba_blocks )
7673 ]
77-
74+
7875 self .layernorm = keras .layers .LayerNormalization (axis = - 1 )
79-
76+
8077 self .pooling = pooling
8178 if pooling :
8279 self .pooling = keras .layers .GlobalAveragePooling1D ()
8380 self .dropout = keras .layers .Dropout (dropout )
8481 self .summary_stats = keras .layers .Dense (summary_dim )
85-
82+
8683 def call (self , time_series , ** kwargs ):
8784 summary = time_series
8885 for mamba_block in self .mamba_blocks :
@@ -94,4 +91,4 @@ def call(self, time_series, **kwargs):
9491 summary = self .pooling (summary )
9592 summary = self .dropout (summary , ** kwargs )
9693 summary = self .summary_stats (summary )
97- return summary
94+ return summary
0 commit comments