Skip to content

Commit e5e9437

Browse files
committed
Add bidirectional option to recurrent nets
1 parent 733afef commit e5e9437

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

bayesflow/summary_networks.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from warnings import warn
2222

2323
import tensorflow as tf
24-
from tensorflow.keras.layers import GRU, LSTM, Dense
24+
from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Dense
2525
from tensorflow.keras.models import Sequential
2626

2727
from bayesflow import default_settings as defaults
@@ -53,6 +53,7 @@ def __init__(
5353
summary_dim=10,
5454
num_attention_blocks=2,
5555
template_type="lstm",
56+
bidirectional=False,
5657
template_dim=64,
5758
**kwargs,
5859
):
@@ -104,6 +105,10 @@ def __init__(
104105
if ``lstm``, an LSTM network will be used.
105106
if ``gru``, a GRU unit will be used.
106107
if callable, a reference to ``template_type`` will be stored as an attribute.
108+
bidirectional : bool, optional, default: False
109+
Indicates whether the involved recurrent template network (LSTM or GRU) is bidirectional
110+
(i.e., forward and backward in time) or unidirectional (forward in time). Defaults to
111+
False; setting it to True may increase performance in some applications.
107112
template_dim : int, optional, default: 64
108113
Only used if ``template_type`` in ['lstm', 'gru']. The number of hidden
109114
units (equiv. output dimensions) of the recurrent network.
@@ -134,9 +139,9 @@ def __init__(
134139

135140
# A recurrent network will learn the dynamic many-to-one template
136141
if template_type.upper() == "LSTM":
137-
self.template_net = LSTM(template_dim)
142+
self.template_net = Bidirectional(LSTM(template_dim)) if bidirectional else LSTM(template_dim)
138143
elif template_type.upper() == "GRU":
139-
self.template_net = GRU(template_dim)
144+
self.template_net = Bidirectional(GRU(template_dim)) if bidirectional else GRU(template_dim)
140145
else:
141146
assert callable(template_type), "Argument `template_dim` should be callable or in ['lstm', 'gru']"
142147
self.template_net = template_type
@@ -418,7 +423,9 @@ class SequentialNetwork(tf.keras.Model):
418423
https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009472
419424
"""
420425

421-
def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_settings=None, **kwargs):
426+
def __init__(
427+
self, summary_dim=10, num_conv_layers=2, lstm_units=128, bidirectional=False, conv_settings=None, **kwargs
428+
):
422429
"""Creates a stack of inception-like layers followed by an LSTM network, with the idea
423430
of learning vector representations from multivariate time series data.
424431
@@ -437,6 +444,9 @@ def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_setti
437444
- layer_args (dict) : arguments for `tf.keras.layers.Conv1D` without kernel_size
438445
- min_kernel_size (int) : the minimum kernel size (>= 1)
439446
- max_kernel_size (int) : the maximum kernel size
447+
bidirectional : bool, optional, default: False
448+
Indicates whether the involved LSTM network is bidirectional (forward and backward in time)
449+
or unidirectional (forward in time). Defaults to False; setting it to True may increase performance.
440450
**kwargs : dict
441451
Optional keyword arguments passed to the __init__() method of tf.keras.Model
442452
"""
@@ -449,7 +459,7 @@ def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_setti
449459

450460
self.net = Sequential([MultiConv1D(conv_settings) for _ in range(num_conv_layers)])
451461

452-
self.lstm = LSTM(lstm_units)
462+
self.lstm = Bidirectional(LSTM(lstm_units)) if bidirectional else LSTM(lstm_units)
453463
self.out_layer = Dense(summary_dim, activation="linear")
454464
self.summary_dim = summary_dim
455465

0 commit comments

Comments
 (0)