Skip to content

Commit 0d64899

Browse files
committed
Optimize time series networks
1 parent b034b21 commit 0d64899

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

bayesflow/networks/time_series_network/skip_recurrent.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,12 @@
99
@serializable(package="bayesflow.networks")
1010
class SkipRecurrentNet(keras.Model):
1111
"""
12-
Implements a Skip recurrent layer as described in [1], but allowing a more flexible
13-
recurrent backbone and a more flexible implementation.
12+
Implements a Skip recurrent layer as described in [1], allowing a more flexible recurrent backbone
13+
and a more efficient implementation.
1414
1515
[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
1616
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
1717
Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
18-
19-
TODO: Add proper docstring
20-
2118
"""
2219

2320
def __init__(
@@ -30,6 +27,32 @@ def __init__(
3027
dropout: float = 0.05,
3128
**kwargs,
3229
):
30+
"""
31+
Creates a skip recurrent neural network layer that extends a traditional recurrent backbone with
32+
skip connections implemented via convolution and an additional recurrent path. This allows
33+
more efficient modeling of long-term dependencies by combining local and non-local temporal
34+
features.
35+
36+
Parameters
37+
----------
38+
hidden_dim : int, optional
39+
Dimensionality of the hidden state in the recurrent layers. Default is 256.
40+
recurrent_type : str, optional
41+
Type of recurrent unit to use. Should correspond to a supported type in `find_recurrent_net`,
42+
such as "gru" or "lstm". Default is "gru".
43+
bidirectional : bool, optional
44+
If True, uses bidirectional wrappers for both recurrent and skip recurrent layers. Default is True.
45+
input_channels : int, optional
46+
Number of input channels for the 1D convolution used in skip connections. Default is 64.
47+
skip_steps : int, optional
48+
Step size and kernel size used in the skip convolution. Determines how many steps are skipped.
49+
Also determines the multiplier for the number of filters. Default is 4.
50+
dropout : float, optional
51+
Dropout rate applied within the recurrent layers. Default is 0.05.
52+
**kwargs
53+
Additional keyword arguments passed to the parent class constructor.
54+
"""
55+
3356
super().__init__(**keras_kwargs(kwargs))
3457

3558
self.skip_conv = keras.layers.Conv1D(
@@ -64,4 +87,4 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
6487

6588
@sanitize_input_shape
6689
def build(self, input_shape):
67-
self.call(keras.ops.zeros(input_shape))
90+
super().build(input_shape)

bayesflow/networks/time_series_network/time_series_network.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
strides: int | list | tuple = 1,
2626
activation: str = "mish",
2727
kernel_initializer: str = "glorot_uniform",
28-
groups: int = 8,
28+
groups: int = None,
2929
recurrent_type: str = "gru",
3030
recurrent_dim: int = 128,
3131
bidirectional: bool = True,
@@ -62,7 +62,7 @@ def __init__(
6262
Default is "glorot_uniform".
6363
groups : int, optional
6464
Number of groups for group normalization applied after each convolutional layer.
65-
Default is 8.
65+
Default is None.
6666
recurrent_type : str, optional
6767
Type of recurrent layer used for sequence modeling, such as "gru" or "lstm".
6868
Default is "gru".
@@ -99,7 +99,8 @@ def __init__(
9999
padding="same",
100100
)
101101
)
102-
self.conv_blocks.append(keras.layers.GroupNormalization(groups=groups))
102+
if groups is not None:
103+
self.conv_blocks.append(keras.layers.GroupNormalization(groups=groups))
103104

104105
# Recurrent and feedforward backbones
105106
self.recurrent = SkipRecurrentNet(
@@ -149,4 +150,3 @@ def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
149150
@sanitize_input_shape
150151
def build(self, input_shape):
151152
super().build(input_shape)
152-
self.call(keras.ops.zeros(input_shape))

0 commit comments

Comments (0)