from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import logging
from bayesflow.utils.decorators import sanitize_input_shape

try:
    from mamba_ssm import Mamba
except ImportError:
    logging.error("Mamba class is not available. Please install the mamba-ssm library via `pip install mamba-ssm`.")

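# Hint: Mamba is only usable with the torch backend (see the check in MambaBlock below).
# A common way to select it, shown here as a usage note rather than something this
# module configures itself, is to set the environment variable before importing keras:
#
#     import os
#     os.environ["KERAS_BACKEND"] = "torch"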

@serializable("bayesflow.wrappers")
class MambaBlock(keras.Layer):
    def __init__(
        self,
        state_dim: int,
        conv_dim: int,
        feature_dim: int = 16,
        expand: int = 1,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A Keras layer implementing a Mamba-based sequence processing block.

        This layer applies a Mamba model for sequence modeling, preceded by a
        convolutional projection and followed by layer normalization.

        Parameters
        ----------
        state_dim : int
            The dimension of the state space in the Mamba model.
        conv_dim : int
            The dimension of the convolutional layer used in Mamba.
        feature_dim : int, optional
            The feature dimension for input projection and Mamba processing (default is 16).
        expand : int, optional
            Expansion factor for Mamba's internal dimension (default is 1).
        dt_min : float, optional
            Minimum delta time for Mamba (default is 0.001).
        dt_max : float, optional
            Maximum delta time for Mamba (default is 0.1).
        device : str, optional
            The device to which the Mamba model is moved, typically "cuda" or "cpu" (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `keras.layers.Layer` initializer.
        """

        super().__init__(**kwargs)

        if keras.backend.backend() != "torch":
            raise EnvironmentError("Mamba is only available with the torch backend.")

        # The wrapped Mamba module operates on inputs of shape (batch, time, feature_dim)
        # and is moved to the requested device.
        self.mamba = Mamba(
            d_model=feature_dim, d_state=state_dim, d_conv=conv_dim, expand=expand, dt_min=dt_min, dt_max=dt_max
        ).to(device)

        # Pointwise (kernel size 1) convolution that projects the inputs to feature_dim.
        self.input_projector = keras.layers.Conv1D(
            feature_dim,
            kernel_size=1,
            strides=1,
        )
        self.layer_norm = keras.layers.LayerNormalization()

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        # Project inputs to feature_dim, apply Mamba, and layer-normalize the residual sum.
        x = self.input_projector(x)
        h = self.mamba(x)
        out = self.layer_norm(h + x, training=training, **kwargs)
        return out

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

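# Minimal usage sketch for MambaBlock (illustrative comment only, so importing this
# module stays side-effect free; it assumes mamba-ssm is installed and that the input
# tensor lives on the same CUDA device as the wrapped Mamba module):
#
#     block = MambaBlock(state_dim=16, conv_dim=4, feature_dim=16)
#     x = keras.ops.zeros((8, 128, 3))  # (batch, time, features)
#     y = block(x)                      # -> (8, 128, 16): projected, Mamba-processed, residual + layer norm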

@serializable("bayesflow.wrappers")
class MambaSSM(SummaryNetwork):
    """
    Wraps the original Mamba module from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        feature_dims: Sequence[int] = (64, 64),
        state_dims: Sequence[int] = (128, 128),
        conv_dims: Sequence[int] = (64, 64),
        expand_dims: Sequence[int] = (2, 2),
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dropout: float = 0.05,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A time-series summarization network using Mamba-based State Space Models (SSM). This model processes
        sequential input data with a stack of Mamba SSM blocks (one per entry in the dimension tuples),
        followed by global average pooling, dropout, and a dense layer that produces the summary statistics.

        Parameters
        ----------
        summary_dim : int, optional
            The output dimensionality of the summary statistics layer (default is 16).
        feature_dims : Sequence[int], optional
            The feature dimension for each Mamba block (default is (64, 64)).
        state_dims : Sequence[int], optional
            The dimensionality of the internal state in each Mamba block (default is (128, 128)).
        conv_dims : Sequence[int], optional
            The dimensionality of the convolutional layer in each Mamba block (default is (64, 64)).
        expand_dims : Sequence[int], optional
            The expansion factors for the hidden state in each Mamba block (default is (2, 2)).
        dt_min : float, optional
            Minimum dynamic state evolution over time (default is 0.001).
        dt_max : float, optional
            Maximum dynamic state evolution over time (default is 0.1).
        dropout : float, optional
            Dropout probability applied to the pooled summary vector (default is 0.05).
        device : str, optional
            The computing device. Currently, only "cuda" is supported (default is "cuda").
        **kwargs : dict
            Additional keyword arguments passed to the `SummaryNetwork` parent class.
        """

        super().__init__(**kwargs)

        if device != "cuda":
            raise NotImplementedError("MambaSSM only supports cuda as `device`.")

        # Note that MambaBlock expects its dimensions in the order (state_dim, conv_dim, feature_dim, expand).
        self.mamba_blocks = []
        for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
            self.mamba_blocks.append(
                MambaBlock(state_dim, conv_dim, feature_dim, expand, dt_min=dt_min, dt_max=dt_max, device=device)
            )

        self.pooling_layer = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
        """
        Apply a sequence of Mamba blocks, followed by pooling, dropout, and the summary statistics layer.

        Parameters
        ----------
        time_series : Tensor
            Input tensor representing the time series data, typically of shape
            (batch_size, sequence_length, feature_dim).
        training : bool, optional
            Whether the model is in training mode (default is True). Affects behavior of
            layers like dropout.
        **kwargs : dict
            Additional keyword arguments (not used in this method).

        Returns
        -------
        Tensor
            Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
        """

        summary = time_series
        for mamba_block in self.mamba_blocks:
            summary = mamba_block(summary, training=training)

        summary = self.pooling_layer(summary)
        summary = self.dropout(summary, training=training)
        summary = self.summary_stats(summary)

        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))
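
# Minimal usage sketch for MambaSSM as a summary network (illustrative comment only;
# the shapes and dimensions below are placeholder choices, and a torch backend plus
# a CUDA device are assumed):
#
#     summary_net = MambaSSM(summary_dim=16, feature_dims=(64, 64), state_dims=(128, 128), conv_dims=(64, 64))
#     time_series = keras.ops.zeros((32, 100, 2))  # (batch, time, features)
#     summaries = summary_net(time_series)         # -> (32, 16) learned summary statistics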