Skip to content

Commit 65dee2c

Browse files
committed
Patch mamba and update student t doc
1 parent 7557086 commit 65dee2c

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

bayesflow/distributions/diagonal_student_t.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import math
55
import numpy as np
66

7-
87
from bayesflow.types import Shape, Tensor
98
from bayesflow.utils import expand_tile
109
from bayesflow.utils.decorators import allow_batch_size

bayesflow/wrappers/mamba/mamba.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self.dropout = keras.layers.Dropout(dropout)
8484
self.summary_stats = keras.layers.Dense(summary_dim)
8585

86-
def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
86+
def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
8787
"""
8888
Apply a sequence of Mamba blocks, followed by pooling, dropout, and summary statistics.
8989
@@ -93,8 +93,8 @@ def call(self, time_series: Tensor, training: bool = True, **kwargs) -> Tensor:
9393
Input tensor representing the time series data, typically of shape
9494
(batch_size, sequence_length, feature_dim).
9595
training : bool, optional
96-
Whether the model is in training mode (default is True). Affects behavior of
97-
layers like dropout.
96+
Whether the model is in training mode (default is False). Affects the behavior of
97+
the inner dropout and norm layers.
9898
**kwargs : dict
9999
Additional keyword arguments (not used in this method).
100100

0 commit comments

Comments
 (0)