99@serializable (package = "bayesflow.networks" )
1010class SkipRecurrentNet (keras .Model ):
1111 """
12- Implements a Skip recurrent layer as described in [1], but allowing a more flexible
13- recurrent backbone and a more flexible implementation.
12+ Implements a Skip recurrent layer as described in [1], allowing a more flexible recurrent backbone
13+ and a more efficient implementation.
1414
1515 [1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow,
1616 2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM),
1717 Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190.
18-
19- TODO: Add proper docstring
20-
2118 """
2219
2320 def __init__ (
@@ -30,6 +27,32 @@ def __init__(
3027 dropout : float = 0.05 ,
3128 ** kwargs ,
3229 ):
30+ """
31+ Creates a skip recurrent neural network layer that extends a traditional recurrent backbone with
32+ skip connections implemented via convolution and an additional recurrent path. This allows
33+ more efficient modeling of long-term dependencies by combining local and non-local temporal
34+ features.
35+
36+ Parameters
37+ ----------
38+ hidden_dim : int, optional
39+ Dimensionality of the hidden state in the recurrent layers. Default is 256.
40+ recurrent_type : str, optional
41+ Type of recurrent unit to use. Should correspond to a supported type in `find_recurrent_net`,
42+ such as "gru" or "lstm". Default is "gru".
43+ bidirectional : bool, optional
44+ If True, uses bidirectional wrappers for both recurrent and skip recurrent layers. Default is True.
45+ input_channels : int, optional
46+ Number of input channels for the 1D convolution used in skip connections. Default is 64.
47+ skip_steps : int, optional
48+ Step size and kernel size used in the skip convolution. Determines how many steps are skipped.
49+ Also determines the multiplier for the number of filters. Default is 4.
50+ dropout : float, optional
51+ Dropout rate applied within the recurrent layers. Default is 0.05.
52+ **kwargs
53+ Additional keyword arguments passed to the parent class constructor.
54+ """
55+
3356 super ().__init__ (** keras_kwargs (kwargs ))
3457
3558 self .skip_conv = keras .layers .Conv1D (
@@ -64,4 +87,4 @@ def call(self, time_series: Tensor, training: bool = False, **kwargs) -> Tensor:
6487
6588 @sanitize_input_shape
6689 def build (self , input_shape ):
67- self . call ( keras . ops . zeros (input_shape ) )
90+ super (). build (input_shape )
0 commit comments