from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.networks.summary_network import SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import keras_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

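# MambaBlock (defined in mamba_block.py) wraps the CUDA Mamba layer from the mamba-ssm package.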
from .mamba_block import MambaBlock


@serializable("bayesflow.wrappers")
class Mamba(SummaryNetwork):
    """
    Wraps a sequence of Mamba modules using the simple Mamba module from:
    https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py

    Copyright (c) 2023, Tri Dao, Albert Gu.

    Example usage in a BayesFlow workflow as a summary network:

    `summary_net = bayesflow.wrappers.Mamba(summary_dim=32)`
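
    A minimal, illustrative sketch of a complete workflow (the components around
    `Mamba` are standard BayesFlow pieces chosen for illustration; `simulator` is
    assumed to be defined elsewhere):

        summary_net = bayesflow.wrappers.Mamba(summary_dim=32)
        workflow = bayesflow.BasicWorkflow(
            simulator=simulator,
            inference_network=bayesflow.networks.CouplingFlow(),
            summary_network=summary_net,
        )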
    """

    def __init__(
        self,
        summary_dim: int = 16,
        feature_dims: Sequence[int] = (64, 64),
        state_dims: Sequence[int] = (64, 64),
        conv_dims: Sequence[int] = (64, 64),
        expand_dims: Sequence[int] = (2, 2),
        bidirectional: bool = True,
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dropout: float = 0.05,
        device: str = "cuda",
        **kwargs,
    ):
        """
        A time-series summarization network using Mamba-based State Space Models (SSMs). This model
        processes sequential input data with a stack of Mamba SSM layers (one per entry in the
        hyperparameter tuples), followed by global average pooling, dropout, and a dense layer that
        extracts the summary statistics.

        Parameters
        ----------
        summary_dim : int, optional
            The output dimensionality of the summary statistics layer (default is 16).
        feature_dims : Sequence[int], optional
            The feature dimension of each Mamba block, default is (64, 64).
        state_dims : Sequence[int], optional
            The dimensionality of the internal state in each Mamba block, default is (64, 64).
        conv_dims : Sequence[int], optional
            The dimensionality of the convolutional layer in each Mamba block, default is (64, 64).
        expand_dims : Sequence[int], optional
            The expansion factors for the hidden state in each Mamba block, default is (2, 2).
        bidirectional : bool, optional
            Whether each Mamba block processes the sequence in both directions (default is True).
        dt_min : float, optional
            Minimum dynamic state evolution over time (default is 0.001).
        dt_max : float, optional
            Maximum dynamic state evolution over time (default is 0.1).
        dropout : float, optional
            Dropout probability applied to the pooled summary vector (default is 0.05).
        device : str, optional
            The computing device. Currently, only "cuda" is supported (default is "cuda").
        **kwargs
            Additional keyword arguments passed to the `SummaryNetwork` parent class.
        """

        super().__init__(**keras_kwargs(kwargs))

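        # The underlying mamba-ssm kernels are CUDA-only, so fail early on other devices.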
        if device != "cuda":
            raise NotImplementedError("Mamba only supports 'cuda' as `device`.")
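        # Create one Mamba block per entry; the hyperparameter tuples are zipped,
        # so they should all have the same length.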
        self.mamba_blocks = []
        for feature_dim, state_dim, conv_dim, expand in zip(feature_dims, state_dims, conv_dims, expand_dims):
            mamba = MambaBlock(feature_dim, state_dim, conv_dim, expand, bidirectional, dt_min, dt_max, device)
            self.mamba_blocks.append(mamba)

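        # Pooling collapses the time dimension; dropout regularizes the pooled vector
        # before the final linear projection to `summary_dim`.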
        self.pooling_layer = keras.layers.GlobalAveragePooling1D()
        self.dropout = keras.layers.Dropout(dropout)
        self.summary_stats = keras.layers.Dense(summary_dim)

    def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics.

        Parameters
        ----------
        time_series : Tensor
            Input tensor representing the time series data, typically of shape
            (batch_size, sequence_length, feature_dim).
        training : bool, optional
            Whether the model is in training mode (default is False). Affects the behavior
            of layers like dropout.
        **kwargs : dict
            Additional keyword arguments (not used in this method).

        Returns
        -------
        Tensor
            Output tensor after applying Mamba blocks, pooling, dropout, and summary statistics.
        """

        summary = time_series
        for mamba_block in self.mamba_blocks:
            summary = mamba_block(summary, training=training)

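        # Collapse the time dimension and project to the final summary dimension.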
        summary = self.pooling_layer(summary)
        summary = self.dropout(summary, training=training)
        summary = self.summary_stats(summary)

        return summary

    @sanitize_input_shape
    def build(self, input_shape):
        super().build(input_shape)
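        # Run a forward pass on zeros so that all sublayers create their weights eagerly.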
        self.call(keras.ops.zeros(input_shape))