Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bayesflow/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mamba import MambaSSM

Check warning on line 1 in bayesflow/wrappers/__init__.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/__init__.py#L1

Added line #L1 was not covered by tests
205 changes: 205 additions & 0 deletions bayesflow/wrappers/mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from collections.abc import Sequence

Check warning on line 1 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L1

Added line #L1 was not covered by tests

import keras
from keras.saving import register_keras_serializable as serializable

Check warning on line 4 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L3-L4

Added lines #L3 - L4 were not covered by tests

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import logging
from bayesflow.utils.decorators import sanitize_input_shape

Check warning on line 9 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L6-L9

Added lines #L6 - L9 were not covered by tests

# `mamba-ssm` is an optional dependency. A failed import is only logged here and
# is not re-raised, so this module still imports; in that case the name `Mamba`
# is left unbound at module level.
try:
    from mamba_ssm import Mamba
except ImportError:
    logging.error("Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`.")

Check warning on line 14 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L11-L14

Added lines #L11 - L14 were not covered by tests


@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
    """
    Wraps the original Mamba module, with added functionality for bidirectional processing:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        state_dim: int,
        conv_dim: int,
        feature_dim: int = 16,
        expand: int = 2,
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A Keras layer implementing a Mamba-based sequence processing block.

        The input is projected to `feature_dim` channels with a kernel-size-1
        convolution, passed through a Mamba module, added back to the projected
        input (residual connection), and layer-normalized.

        Parameters
        ----------
        state_dim : int
            The dimension of the state space in the Mamba model (passed as `d_state`).
        conv_dim : int
            The dimension of the convolutional layer used in Mamba (passed as `d_conv`).
        feature_dim : int, optional
            The feature dimension for input projection and Mamba processing (default is 16).
        expand : int, optional
            Expansion factor for Mamba's internal dimension (default is 2).
        bidirectional : bool, optional
            If True, the block is additionally applied to the input flipped along axis 1 and
            both outputs are concatenated along the last axis (default is True).
        dt_min : float, optional
            Minimum delta time for Mamba (default is 0.001).
        dt_max : float, optional
            Maximum delta time for Mamba (default is 0.1).
        device : str, optional
            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `keras.layers.Layer` initializer.

        Raises
        ------
        EnvironmentError
            If the active Keras backend is not "torch".
        ImportError
            If the `mamba-ssm` package cannot be imported.
        """

        super().__init__(**kwargs)

        if keras.backend.backend() != "torch":
            raise EnvironmentError("Mamba is only available using torch backend.")

        # Resolve the optional dependency at construction time. The module-level import
        # only logs on failure and leaves `Mamba` unbound, which would otherwise surface
        # here as an uninformative NameError.
        try:
            from mamba_ssm import Mamba
        except ImportError as error:
            raise ImportError(
                "Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`."
            ) from error

        self.bidirectional = bidirectional

        self.mamba = Mamba(
            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
        ).to(device)

        # Pointwise (kernel size 1) projection of the input channels to `feature_dim`.
        self.input_projector = keras.layers.Conv1D(
            feature_dim,
            kernel_size=1,
            strides=1,
        )
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Apply the block to `x`, optionally in both directions along axis 1.

        Parameters
        ----------
        x : Tensor
            Input tensor; assumed to be (batch_size, sequence_length, channels) -- TODO confirm.
        training : bool, optional
            Forwarded to the layer normalization (default is False).
        **kwargs : dict
            Additional keyword arguments forwarded to the layer normalization call.

        Returns
        -------
        Tensor
            The forward output, or the forward and backward outputs concatenated along the
            last axis when `bidirectional` is True.
        """
        out_forward = self._call(x, training=training, **kwargs)
        if self.bidirectional:
            # NOTE(review): the backward output is concatenated as computed, i.e. it is not
            # flipped back along axis 1 to line up with `out_forward` -- confirm this is intended.
            out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs)
            return keras.ops.concatenate((out_forward, out_backward), axis=-1)
        return out_forward

    def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """Run one directional pass: project, apply Mamba, add the residual, and normalize."""
        x = self.input_projector(x)
        h = self.mamba(x)
        # Residual connection around the Mamba module, taken after the input projection.
        out = self.layer_norm(h + x, training=training, **kwargs)
        return out

    @sanitize_input_shape
    def build(self, input_shape):
        """Build the layer by running `call` once on an all-zeros tensor of shape `input_shape`."""
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

Check warning on line 98 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L95-L98

Added lines #L95 - L98 were not covered by tests


@serializable("bayesflow.wrappers")
class MambaSSM(SummaryNetwork):
    """
    Wraps a sequence of Mamba modules using the simple Mamba module from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        feature_dims: Sequence[int] = (64, 64),
        state_dims: Sequence[int] = (64, 64),
        conv_dims: Sequence[int] = (64, 64),
        expand_dims: Sequence[int] = (2, 2),
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dropout: float = 0.05,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A time-series summarization network using Mamba-based State Space Models (SSM). This model processes
        sequential input data using a sequence of Mamba SSM layers (determined by the length of the tuples),
        followed by global average pooling, dropout, and a dense layer for extracting summary statistics.

        One Mamba block is created per position of the zipped `feature_dims`, `state_dims`, `conv_dims`
        and `expand_dims`; if their lengths differ, the shortest one determines the number of blocks.

        Parameters
        ----------
        summary_dim : int, optional
            The output dimensionality of the summary statistics layer (default is 16).
        feature_dims : Sequence[int], optional
            The feature dimension for each mamba block, default is (64, 64)
        state_dims : Sequence[int], optional
            The dimensionality of the internal state in each Mamba block, default is (64, 64)
        conv_dims : Sequence[int], optional
            The dimensionality of the convolutional layer in each Mamba block, default is (64, 64)
        expand_dims : Sequence[int], optional
            The expansion factors for the hidden state in each Mamba block, default is (2, 2)
        bidirectional : bool, optional
            Whether each Mamba block processes the sequence in both directions (default is True).
        dt_min : float, optional
            Minimum dynamic state evolution over time (default is 0.001).
        dt_max : float, optional
            Maximum dynamic state evolution over time (default is 0.1).
        dropout : float, optional
            Dropout probability; dropout is applied to the pooled summary vector (default is 0.05).
        device : str, optional
            The computing device. Currently, only "cuda" is supported (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `SummaryNetwork` parent class.

        Raises
        ------
        NotImplementedError
            If `device` is anything other than "cuda".
        """

        super().__init__(**kwargs)

        if device != "cuda":
            raise NotImplementedError("MambaSSM only supports cuda as `device`.")

        self.mamba_blocks = []
        for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
            # Pass sizes by keyword: MambaBlock's positional order is
            # (state_dim, conv_dim, feature_dim, ...), which differs from the order used here.
            mamba = MambaBlock(
                state_dim=state_dim,
                conv_dim=conv_dim,
                feature_dim=feature_dim,
                expand=expand,
                bidirectional=bidirectional,
                dt_min=dt_min,
                dt_max=dt_max,
                device=device,
            )
            self.mamba_blocks.append(mamba)

        self.pooling_layer = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
        """
        Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics.

        Parameters
        ----------
        time_series : Tensor
            Input tensor representing the time series data, typically of shape
            (batch_size, sequence_length, feature_dim).
        training : bool, optional
            Whether the model is in training mode (default is True). Forwarded to the
            Mamba blocks and the dropout layer.
        **kwargs : dict
            Additional keyword arguments (not used in this method).

        Returns
        -------
        Tensor
            Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
        """

        summary = time_series
        for mamba_block in self.mamba_blocks:
            summary = mamba_block(summary, training=training)

        summary = self.pooling_layer(summary)
        summary = self.dropout(summary, training=training)
        summary = self.summary_stats(summary)

        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        """Build the network by running `call` once on an all-zeros tensor of shape `input_shape`."""
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

Check warning on line 205 in bayesflow/wrappers/mamba.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/wrappers/mamba.py#L202-L205

Added lines #L202 - L205 were not covered by tests