Skip to content

Commit 7b91c1a

Browse files
committed
Linting
Additional updates from ruff
1 parent 69db800 commit 7b91c1a

File tree

3 files changed

+14
-17
lines changed

3 files changed

+14
-17
lines changed

bayesflow/wrappers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .mamba import MambaSSM
1+
from .mamba import MambaSSM
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .mamba import MambaSSM
1+
from .mamba import MambaSSM

bayesflow/wrappers/mamba/mamba.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import keras
2+
23
# from keras.saving import register_keras_serializable as serializable
34
try:
45
from mamba_ssm import Mamba
@@ -7,6 +8,7 @@
78

89
from ...networks.summary_network import SummaryNetwork
910

11+
1012
# @serializable(package="bayesflow.wrappers")
1113
class MambaSSM(SummaryNetwork):
1214
def __init__(
@@ -22,13 +24,13 @@ def __init__(
2224
pooling: bool = True,
2325
dropout: int | float | None = 0.5,
2426
device: str = "cuda",
25-
**kwargs
27+
**kwargs,
2628
):
2729
"""
2830
A time-series summarization network using Mamba-based State Space Models (SSM).
29-
This model processes sequential input data using the Mamba SSM layer, followed by
31+
This model processes sequential input data using the Mamba SSM layer, followed by
3032
optional pooling, dropout, and a dense layer for extracting summary statistics.
31-
33+
3234
Mamba2 support currently unabailble due to stability issues
3335
3436
Parameters
@@ -58,31 +60,26 @@ def __init__(
5860
**kwargs : dict
5961
Additional keyword arguments passed to the `SummaryNetwork` parent class.
6062
"""
61-
63+
6264
super().__init__(**kwargs)
6365
if device != "cuda":
6466
raise NotImplementedError("MambaSSM currently only supports cuda")
65-
67+
6668
self.mamba_blocks = [
6769
Mamba(
68-
d_model=feature_dim,
69-
d_state=state_dim,
70-
d_conv=conv_dim,
71-
expand=expand,
72-
dt_min=dt_min,
73-
dt_max=dt_max
70+
d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
7471
).to(device)
7572
for _ in range(mamba_blocks)
7673
]
77-
74+
7875
self.layernorm = keras.layers.LayerNormalization(axis=-1)
79-
76+
8077
self.pooling = pooling
8178
if pooling:
8279
self.pooling = keras.layers.GlobalAveragePooling1D()
8380
self.dropout = keras.layers.Dropout(dropout)
8481
self.summary_stats = keras.layers.Dense(summary_dim)
85-
82+
8683
def call(self, time_series, **kwargs):
8784
summary = time_series
8885
for mamba_block in self.mamba_blocks:
@@ -94,4 +91,4 @@ def call(self, time_series, **kwargs):
9491
summary = self.pooling(summary)
9592
summary = self.dropout(summary, **kwargs)
9693
summary = self.summary_stats(summary)
97-
return summary
94+
return summary

0 commit comments

Comments
 (0)