Skip to content

Commit 6cce1bc

Browse files
committed
update serialization for mamba
1 parent 6e5e191 commit 6cce1bc

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

bayesflow/wrappers/mamba/mamba.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from collections.abc import Sequence
22

33
import keras
4-
from keras.saving import register_keras_serializable as serializable
54

65
from bayesflow.networks.summary_network import SummaryNetwork
76
from bayesflow.types import Tensor
8-
from bayesflow.utils import keras_kwargs
9-
from bayesflow.utils.decorators import sanitize_input_shape
7+
from bayesflow.utils.serialization import serializable
108

119
from .mamba_block import MambaBlock
1210

1311

14-
@serializable("bayesflow.wrappers")
12+
@serializable
1513
class Mamba(SummaryNetwork):
1614
"""
1715
Wraps a sequence of Mamba modules using the simple Mamba module from:
@@ -71,7 +69,7 @@ def __init__(
7169
Additional keyword arguments passed to the `SummaryNetwork` parent class.
7270
"""
7371

74-
super().__init__(**keras_kwargs(kwargs))
72+
super().__init__(**kwargs)
7573

7674
if device != "cuda":
7775
raise NotImplementedError("MambaSSM only supports cuda as `device`.")
@@ -115,8 +113,3 @@ def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
115113
summary = self.summary_stats(summary)
116114

117115
return summary
118-
119-
@sanitize_input_shape
120-
def build(self, input_shape):
121-
super().build(input_shape)
122-
self.call(keras.ops.zeros(input_shape))

bayesflow/wrappers/mamba/mamba_block.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import keras
2-
from keras.saving import register_keras_serializable as serializable
32

43
from bayesflow.types import Tensor
5-
from bayesflow.utils import keras_kwargs
4+
from bayesflow.utils import layer_kwargs
65
from bayesflow.utils.decorators import sanitize_input_shape
6+
from bayesflow.utils.serialization import serializable
77

88

9-
@serializable("bayesflow.wrappers")
9+
@serializable
1010
class MambaBlock(keras.Layer):
1111
"""
1212
Wraps the original Mamba module, with added functionality for bidirectional processing, from:
@@ -53,7 +53,7 @@ def __init__(
5353
Additional keyword arguments passed to the `keras.layers.Layer` initializer.
5454
"""
5555

56-
super().__init__(**keras_kwargs(kwargs))
56+
super().__init__(**layer_kwargs(kwargs))
5757

5858
if keras.backend.backend() != "torch":
5959
raise RuntimeError("Mamba is only available using torch backend.")

0 commit comments

Comments
 (0)