From 492479182e12abb32c7180682950f39ff5853763 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Wed, 9 Apr 2025 20:59:36 -0400 Subject: [PATCH 1/4] Add first wrapper: Mamba --- bayesflow/wrappers/__init__.py | 1 + bayesflow/wrappers/mamba.py | 186 +++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 bayesflow/wrappers/__init__.py create mode 100644 bayesflow/wrappers/mamba.py diff --git a/bayesflow/wrappers/__init__.py b/bayesflow/wrappers/__init__.py new file mode 100644 index 000000000..e518cf6ae --- /dev/null +++ b/bayesflow/wrappers/__init__.py @@ -0,0 +1 @@ +from .mamba import MambaSSM diff --git a/bayesflow/wrappers/mamba.py b/bayesflow/wrappers/mamba.py new file mode 100644 index 000000000..12b2fb84f --- /dev/null +++ b/bayesflow/wrappers/mamba.py @@ -0,0 +1,186 @@ +from collections.abc import Sequence + +import keras +from keras.saving import register_keras_serializable as serializable + +from bayesflow.networks.summary_network import SummaryNetwork +from bayesflow.types import Tensor +from bayesflow.utils import logging +from bayesflow.utils.decorators import sanitize_input_shape + +try: + from mamba_ssm import Mamba +except ImportError: + logging.error("Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`.") + + +@serializable("bayesflow.wrappers") +class MambaBlock(keras.Layer): + def __init__( + self, + state_dim: int, + conv_dim: int, + feature_dim: int = 16, + expand: int = 1, + dt_min=0.001, + dt_max=0.1, + device: str = "cuda", + **kwargs, + ): + """ + A Keras layer implementing a Mamba-based sequence processing block. + + This layer applies a Mamba model for sequence modeling, preceded by a + convolutional projection and followed by layer normalization. + + Parameters + ---------- + state_dim : int + The dimension of the state space in the Mamba model. + conv_dim : int + The dimension of the convolutional layer used in Mamba. 
+ feature_dim : int, optional + The feature dimension for input projection and Mamba processing (default is 16). + expand : int, optional + Expansion factor for Mamba's internal dimension (default is 1). + dt_min : float, optional + Minimum delta time for Mamba (default is 0.001). + dt_max : float, optional + Maximum delta time for Mamba (default is 0.1). + device : str, optional + The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda"). + **kwargs : dict + Additional keyword arguments passed to the `keras.layers.Layer` initializer. + """ + + super().__init__(**kwargs) + + if keras.backend.backend() != "torch": + raise EnvironmentError("Mamba is only available using torch backend.") + + self.mamba = Mamba( + d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max + ).to(device) + + self.input_projector = keras.layers.Conv1D( + feature_dim, + kernel_size=1, + strides=1, + ) + self.layer_norm = keras.layers.LayerNormalization() + + def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: + x = self.input_projector(x) + h = self.mamba(x) + out = self.layer_norm(h + x, training=training, **kwargs) + return out + + @sanitize_input_shape + def build(self, input_shape): + super().build(input_shape) + self.call(keras.ops.zeros(input_shape)) + + +@serializable("bayesflow.wrappers") +class MambaSSM(SummaryNetwork): + """ + Wraps the original Mamba module from: + https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py + + Copyright (c) 2023, Tri Dao, Albert Gu. 
+ """ + + def __init__( + self, + summary_dim: int = 16, + feature_dims: Sequence[int] = (64, 64), + state_dims: Sequence[int] = (128, 128), + conv_dims: Sequence[int] = (64, 64), + expand_dims: Sequence[int] = (2, 2), + dt_min: float = 0.001, + dt_max: float = 0.1, + dropout: float = 0.05, + device: str = "cuda", + **kwargs, + ): + """ + A time-series summarization network using Mamba-based State Space Models (SSM). This model processes + sequential input data using a sequence of Mamba SSM layers (determined by the length of the tuples), + followed by optional pooling, dropout, and a dense layer for extracting summary statistics. + + Parameters + ---------- + summary_dim : int, optional + The output dimensionality of the summary statistics layer (default is 16). + feature_dims : Sequence[int], optional + The feature dimension for each mamba block, default is (64, 64), + state_dims : Sequence[int], optional + The dimensionality of the internal state in each Mamba block, default is (64, 64) + conv_dims : Sequence[int], optional + The dimensionality of the convolutional layer in each Mamba block, default is (64, 64) + expand_dims : Sequence[int], optional + The expansion factors for the hidden state in each Mamba block, default is (2, 2) + dt_min : float, optional + Minimum dynamic state evolution over time (default is 0.001). + dt_max : float, optional + Maximum dynamic state evolution over time (default is 0.1). + pooling : bool, optional + Whether to apply global average pooling (default is True). + dropout : int, float, or None, optional + Dropout rate applied before the summary layer (default is 0.05). + dropout: float, optional + Dropout probability; dropout is applied to the pooled summary vector. + device : str, optional + The computing device. Currently, only "cuda" is supported (default is "cuda"). + **kwargs : dict + Additional keyword arguments passed to the `SummaryNetwork` parent class. 
+ """ + + super().__init__(**kwargs) + + if device != "cuda": + raise NotImplementedError("MambaSSM only supports cuda as `device`.") + + self.mamba_blocks = [] + for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims): + self.mamba_blocks.append(MambaBlock(feature_dim, state_dim, conv_dim, expand, dt_min, dt_max, device)) + + self.pooling_layer = keras.layers.GlobalAveragePooling1D() + self.dropout = keras.layers.Dropout(dropout) + self.summary_stats = keras.layers.Dense(summary_dim) + + def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor: + """ + Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics. + + Parameters + ---------- + time_series : Tensor + Input tensor representing the time series data, typically of shape + (batch_size, sequence_length, feature_dim). + training : bool, optional + Whether the model is in training mode (default is True). Affects behavior of + layers like dropout. + **kwargs : dict + Additional keyword arguments (not used in this method). + + Returns + ------- + Tensor + Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics. 
+ """ + + summary = time_series + for mamba_block in self.mamba_blocks: + summary = mamba_block(summary, training=training) + + summary = self.pooling_layer(summary) + summary = self.dropout(summary, training=training) + summary = self.summary_stats(summary) + + return summary + + @sanitize_input_shape + def build(self, input_shape): + super().build(input_shape) + self.call(keras.ops.zeros(input_shape)) From c1f6bb94df641b2075bfbc69432b84abf6d3c1a6 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 11 Apr 2025 11:00:39 -0400 Subject: [PATCH 2/4] Finalize wrapper --- bayesflow/wrappers/mamba.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/bayesflow/wrappers/mamba.py b/bayesflow/wrappers/mamba.py index 12b2fb84f..35fcbfdf0 100644 --- a/bayesflow/wrappers/mamba.py +++ b/bayesflow/wrappers/mamba.py @@ -16,14 +16,22 @@ @serializable("bayesflow.wrappers") class MambaBlock(keras.Layer): + """ + Wraps the original Mamba module from, with added functionality for bidirectional processing: + https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py + + Copyright (c) 2023, Tri Dao, Albert Gu. 
+ """ + def __init__( self, state_dim: int, conv_dim: int, feature_dim: int = 16, - expand: int = 1, - dt_min=0.001, - dt_max=0.1, + expand: int = 2, + bidirectional: bool = True, + dt_min: float = 0.001, + dt_max: float = 0.1, device: str = "cuda", **kwargs, ): @@ -58,6 +66,8 @@ def __init__( if keras.backend.backend() != "torch": raise EnvironmentError("Mamba is only available using torch backend.") + self.bidirectional = bidirectional + self.mamba = Mamba( d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max ).to(device) @@ -70,6 +80,13 @@ def __init__( self.layer_norm = keras.layers.LayerNormalization() def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: + out_forward = self._call(x, training=training, **kwargs) + if self.bidirectional: + out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs) + return keras.ops.concatenate((out_forward, out_backward), axis=-1) + return out_forward + + def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: x = self.input_projector(x) h = self.mamba(x) out = self.layer_norm(h + x, training=training, **kwargs) @@ -84,7 +101,7 @@ def build(self, input_shape): @serializable("bayesflow.wrappers") class MambaSSM(SummaryNetwork): """ - Wraps the original Mamba module from: + Wraps a sequence of Mamba modules using the simple Mamba module from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py Copyright (c) 2023, Tri Dao, Albert Gu. 
@@ -94,9 +111,10 @@ def __init__( self, summary_dim: int = 16, feature_dims: Sequence[int] = (64, 64), - state_dims: Sequence[int] = (128, 128), + state_dims: Sequence[int] = (64, 64), conv_dims: Sequence[int] = (64, 64), expand_dims: Sequence[int] = (2, 2), + bidirectional: bool = True, dt_min: float = 0.001, dt_max: float = 0.1, dropout: float = 0.05, @@ -143,7 +161,8 @@ def __init__( self.mamba_blocks = [] for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims): - self.mamba_blocks.append(MambaBlock(feature_dim, state_dim, conv_dim, expand, dt_min, dt_max, device)) + mamba = MambaBlock(feature_dim, state_dim, conv_dim, expand, bidirectional, dt_min, dt_max, device) + self.mamba_blocks.append(mamba) self.pooling_layer = keras.layers.GlobalAveragePooling1D() self.dropout = keras.layers.Dropout(dropout) From 43f561a28598662d6fe3397882d82f5f3301523c Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 11 Apr 2025 11:20:46 -0400 Subject: [PATCH 3/4] Refactor mamba wrapper --- bayesflow/wrappers/__init__.py | 2 +- bayesflow/wrappers/mamba/__init__.py | 1 + bayesflow/wrappers/{ => mamba}/mamba.py | 99 ++------------------ bayesflow/wrappers/mamba/mamba_block.py | 114 ++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 92 deletions(-) create mode 100644 bayesflow/wrappers/mamba/__init__.py rename bayesflow/wrappers/{ => mamba}/mamba.py (58%) create mode 100644 bayesflow/wrappers/mamba/mamba_block.py diff --git a/bayesflow/wrappers/__init__.py b/bayesflow/wrappers/__init__.py index e518cf6ae..efd3a2a3e 100644 --- a/bayesflow/wrappers/__init__.py +++ b/bayesflow/wrappers/__init__.py @@ -1 +1 @@ -from .mamba import MambaSSM +from .mamba import Mamba diff --git a/bayesflow/wrappers/mamba/__init__.py b/bayesflow/wrappers/mamba/__init__.py new file mode 100644 index 000000000..b1e9d403a --- /dev/null +++ b/bayesflow/wrappers/mamba/__init__.py @@ -0,0 +1 @@ +from mamba import Mamba diff --git 
a/bayesflow/wrappers/mamba.py b/bayesflow/wrappers/mamba/mamba.py similarity index 58% rename from bayesflow/wrappers/mamba.py rename to bayesflow/wrappers/mamba/mamba.py index 35fcbfdf0..0dba70d43 100644 --- a/bayesflow/wrappers/mamba.py +++ b/bayesflow/wrappers/mamba/mamba.py @@ -5,106 +5,23 @@ from bayesflow.networks.summary_network import SummaryNetwork from bayesflow.types import Tensor -from bayesflow.utils import logging +from bayesflow.utils import keras_kwargs from bayesflow.utils.decorators import sanitize_input_shape -try: - from mamba_ssm import Mamba -except ImportError: - logging.error("Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`.") +from .mamba_block import MambaBlock @serializable("bayesflow.wrappers") -class MambaBlock(keras.Layer): +class Mamba(SummaryNetwork): """ - Wraps the original Mamba module from, with added functionality for bidirectional processing: + Wraps a sequence of Mamba modules using the simple Mamba module from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py Copyright (c) 2023, Tri Dao, Albert Gu. - """ - - def __init__( - self, - state_dim: int, - conv_dim: int, - feature_dim: int = 16, - expand: int = 2, - bidirectional: bool = True, - dt_min: float = 0.001, - dt_max: float = 0.1, - device: str = "cuda", - **kwargs, - ): - """ - A Keras layer implementing a Mamba-based sequence processing block. - - This layer applies a Mamba model for sequence modeling, preceded by a - convolutional projection and followed by layer normalization. - - Parameters - ---------- - state_dim : int - The dimension of the state space in the Mamba model. - conv_dim : int - The dimension of the convolutional layer used in Mamba. - feature_dim : int, optional - The feature dimension for input projection and Mamba processing (default is 16). - expand : int, optional - Expansion factor for Mamba's internal dimension (default is 1). 
- dt_min : float, optional - Minimum delta time for Mamba (default is 0.001). - dt_max : float, optional - Maximum delta time for Mamba (default is 0.1). - device : str, optional - The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda"). - **kwargs : dict - Additional keyword arguments passed to the `keras.layers.Layer` initializer. - """ - - super().__init__(**kwargs) - - if keras.backend.backend() != "torch": - raise EnvironmentError("Mamba is only available using torch backend.") - - self.bidirectional = bidirectional - - self.mamba = Mamba( - d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max - ).to(device) - - self.input_projector = keras.layers.Conv1D( - feature_dim, - kernel_size=1, - strides=1, - ) - self.layer_norm = keras.layers.LayerNormalization() - - def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: - out_forward = self._call(x, training=training, **kwargs) - if self.bidirectional: - out_backward = self._call(keras.ops.flip(x, axis=1), training=training, **kwargs) - return keras.ops.concatenate((out_forward, out_backward), axis=-1) - return out_forward - - def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: - x = self.input_projector(x) - h = self.mamba(x) - out = self.layer_norm(h + x, training=training, **kwargs) - return out - - @sanitize_input_shape - def build(self, input_shape): - super().build(input_shape) - self.call(keras.ops.zeros(input_shape)) + Example usage in a BayesFlow workflow as a summary network: -@serializable("bayesflow.wrappers") -class MambaSSM(SummaryNetwork): - """ - Wraps a sequence of Mamba modules using the simple Mamba module from: - https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py - - Copyright (c) 2023, Tri Dao, Albert Gu. 
+ `summary_net = bayesflow.wrappers.Mamba(summary_dim=32)` """ def __init__( @@ -150,11 +67,11 @@ def __init__( Dropout probability; dropout is applied to the pooled summary vector. device : str, optional The computing device. Currently, only "cuda" is supported (default is "cuda"). - **kwargs : dict + **kwargs : Additional keyword arguments passed to the `SummaryNetwork` parent class. """ - super().__init__(**kwargs) + super().__init__(**keras_kwargs(kwargs)) if device != "cuda": raise NotImplementedError("MambaSSM only supports cuda as `device`.") diff --git a/bayesflow/wrappers/mamba/mamba_block.py b/bayesflow/wrappers/mamba/mamba_block.py new file mode 100644 index 000000000..fcc3d8dce --- /dev/null +++ b/bayesflow/wrappers/mamba/mamba_block.py @@ -0,0 +1,114 @@ +import keras +from keras.saving import register_keras_serializable as serializable + +from bayesflow.types import Tensor +from bayesflow.utils import logging, keras_kwargs +from bayesflow.utils.decorators import sanitize_input_shape + +try: + from mamba_ssm import Mamba +except ImportError: + logging.error("Mamba class is not available. Please, install the mamba-ssm library via `pip install mamba-ssm`.") + + +@serializable("bayesflow.wrappers") +class MambaBlock(keras.Layer): + """ + Wraps the original Mamba module from, with added functionality for bidirectional processing: + https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py + + Copyright (c) 2023, Tri Dao, Albert Gu. + """ + + def __init__( + self, + state_dim: int, + conv_dim: int, + feature_dim: int = 16, + expand: int = 2, + bidirectional: bool = True, + dt_min: float = 0.001, + dt_max: float = 0.1, + device: str = "cuda", + **kwargs, + ): + """ + A Keras layer implementing a Mamba-based sequence processing block. + + This layer applies a Mamba model for sequence modeling, preceded by a + convolutional projection and followed by layer normalization. 
+ + Parameters + ---------- + state_dim : int + The dimension of the state space in the Mamba model. + conv_dim : int + The dimension of the convolutional layer used in Mamba. + feature_dim : int, optional + The feature dimension for input projection and Mamba processing (default is 16). + expand : int, optional + Expansion factor for Mamba's internal dimension (default is 2). + dt_min : float, optional + Minimum delta time for Mamba (default is 0.001). + dt_max : float, optional + Maximum delta time for Mamba (default is 0.1). + device : str, optional + The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda"). + **kwargs : + Additional keyword arguments passed to the `keras.layers.Layer` initializer. + """ + + super().__init__(**keras_kwargs(kwargs)) + + if keras.backend.backend() != "torch": + raise RuntimeError("Mamba is only available using torch backend.") + + self.bidirectional = bidirectional + + self.mamba = Mamba( + d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max + ).to(device) + + self.input_projector = keras.layers.Conv1D( + feature_dim, + kernel_size=1, + strides=1, + ) + self.layer_norm = keras.layers.LayerNormalization() + + def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: + """ + Applies the Mamba layer to the input tensor `x`, optionally in a bidirectional manner. + + Parameters + ---------- + x : Tensor + Input tensor of shape `(batch_size, sequence_length, input_dim)`. + training : bool, optional + Whether the layer should behave in training mode (e.g., applying dropout). Default is False. + **kwargs : dict + Additional keyword arguments passed to the internal `_call` method. + + Returns + ------- + Tensor + Output tensor of shape `(batch_size, sequence_length, feature_dim)` if unidirectional, + or `(batch_size, sequence_length, 2 * feature_dim)` if bidirectional. 
+ """ + + out_forward = self._call(x, training=training, **kwargs) + if self.bidirectional: + out_backward = self._call(keras.ops.flip(x, axis=-2), training=training, **kwargs) + return keras.ops.concatenate((out_forward, out_backward), axis=-1) + return out_forward + + def _call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor: + x = self.input_projector(x) + h = self.mamba(x) + out = self.layer_norm(h + x, training=training, **kwargs) + return out + + @sanitize_input_shape + def build(self, input_shape): + super().build(input_shape) + self.call(keras.ops.zeros(input_shape)) From a77025ee0deb1616c2dce97650d5c03ce7ec4664 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Fri, 11 Apr 2025 11:39:03 -0400 Subject: [PATCH 4/4] Fix import --- bayesflow/wrappers/mamba/__init__.py | 2 +- tests/test_wrappers/__init__.py | 0 tests/test_wrappers/conftest.py | 22 ++++++++++++++++++++++ tests/test_wrappers/test_mamba.py | 25 +++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 tests/test_wrappers/__init__.py create mode 100644 tests/test_wrappers/conftest.py create mode 100644 tests/test_wrappers/test_mamba.py diff --git a/bayesflow/wrappers/mamba/__init__.py b/bayesflow/wrappers/mamba/__init__.py index b1e9d403a..efd3a2a3e 100644 --- a/bayesflow/wrappers/mamba/__init__.py +++ b/bayesflow/wrappers/mamba/__init__.py @@ -1 +1 @@ -from mamba import Mamba +from .mamba import Mamba diff --git a/tests/test_wrappers/__init__.py b/tests/test_wrappers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_wrappers/conftest.py b/tests/test_wrappers/conftest.py new file mode 100644 index 000000000..8a336a41b --- /dev/null +++ b/tests/test_wrappers/conftest.py @@ -0,0 +1,22 @@ +import pytest + + +@pytest.fixture() +def inference_network(): + from bayesflow.networks import CouplingFlow + + return CouplingFlow(depth=2) + + +@pytest.fixture() +def random_time_series(): + import keras + + return 
keras.random.normal(shape=(2, 80, 2)) + + +@pytest.fixture() +def mamba_summary_network(): + from bayesflow.wrappers.mamba import Mamba + + return Mamba(summary_dim=4, feature_dims=(2, 2), state_dims=(4, 4), conv_dims=(8, 8)) diff --git a/tests/test_wrappers/test_mamba.py b/tests/test_wrappers/test_mamba.py new file mode 100644 index 000000000..4d7dcfb6b --- /dev/null +++ b/tests/test_wrappers/test_mamba.py @@ -0,0 +1,25 @@ +import pytest + +import bayesflow as bf + + +@pytest.mark.torch +def test_mamba_summary(random_time_series, mamba_summary_network): + out = mamba_summary_network(random_time_series) + # Batch size 2, summary dim 4 + assert out.shape == (2, 4) + + +@pytest.mark.torch +def test_mamba_trains(random_time_series, inference_network, mamba_summary_network): + workflow = bf.BasicWorkflow( + inference_network=inference_network, + summary_network=mamba_summary_network, + inference_variables=["parameters"], + summary_variables=["observables"], + simulator=bf.simulators.SIR(), + ) + + history = workflow.fit_online(epochs=2, batch_size=8, num_batches_per_epoch=2) + assert "loss" in list(history.history.keys()) + assert len(history.history["loss"]) == 2