2121from warnings import warn
2222
2323import tensorflow as tf
24- from tensorflow.keras.layers import GRU, LSTM, Dense
24+ from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Dense
2525from tensorflow.keras.models import Sequential
2626
2727from bayesflow import default_settings as defaults
@@ -53,6 +53,7 @@ def __init__(
5353 summary_dim=10,
5454 num_attention_blocks=2,
5555 template_type="lstm",
56+ bidirectional=False,
5657 template_dim=64,
5758 **kwargs,
5859 ):
@@ -104,6 +105,10 @@ def __init__(
104105 if ``lstm``, an LSTM network will be used.
105106 if ``gru``, a GRU unit will be used.
106107 if callable, a reference to ``template_type`` will be stored as an attribute.
108+ bidirectional : bool, optional, default: False
109+ Indicates whether the involved LSTM template network is bidirectional (i.e., forward
110+ and backward in time) or unidirectional (forward in time). Defaults to False, but may
111+ increase performance in some applications.
107112 template_dim : int, optional, default: 64
108113 Only used if ``template_type`` in ['lstm', 'gru']. The number of hidden
109114 units (equiv. output dimensions) of the recurrent network.
@@ -134,9 +139,9 @@ def __init__(
134139
135140 # A recurrent network will learn the dynamic many-to-one template
136141 if template_type.upper() == "LSTM":
137- self.template_net = LSTM(template_dim)
142+ self.template_net = Bidirectional(LSTM(template_dim)) if bidirectional else LSTM(template_dim)
138143 elif template_type.upper() == "GRU":
139- self.template_net = GRU(template_dim)
144+ self.template_net = Bidirectional(GRU(template_dim)) if bidirectional else GRU(template_dim)
140145 else:
141146 assert callable(template_type), "Argument `template_dim` should be callable or in ['lstm', 'gru']"
142147 self.template_net = template_type
@@ -418,7 +423,9 @@ class SequentialNetwork(tf.keras.Model):
418423 https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1009472
419424 """
420425
421- def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_settings=None, **kwargs):
426+ def __init__(
427+ self, summary_dim=10, num_conv_layers=2, lstm_units=128, bidirectional=False, conv_settings=None, **kwargs
428+ ):
422429 """Creates a stack of inception-like layers followed by an LSTM network, with the idea
423430 of learning vector representations from multivariate time series data.
424431
@@ -437,6 +444,9 @@ def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_setti
437444 - layer_args (dict) : arguments for `tf.keras.layers.Conv1D` without kernel_size
438445 - min_kernel_size (int) : the minimum kernel size (>= 1)
439446 - max_kernel_size (int) : the maximum kernel size
447+ bidirectional : bool, optional, default: False
448+ Indicates whether the involved LSTM network is bidirectional (forward and backward in time)
449+ or unidirectional (forward in time). Defaults to False, but may increase performance.
440450 **kwargs : dict
441451 Optional keyword arguments passed to the __init__() method of tf.keras.Model
442452 """
@@ -449,7 +459,7 @@ def __init__(self, summary_dim=10, num_conv_layers=2, lstm_units=128, conv_setti
449459
450460 self.net = Sequential([MultiConv1D(conv_settings) for _ in range(num_conv_layers)])
451461
452- self.lstm = LSTM(lstm_units)
462+ self.lstm = Bidirectional(LSTM(lstm_units)) if bidirectional else LSTM(lstm_units)
453463 self.out_layer = Dense(summary_dim, activation="linear")
454464 self.summary_dim = summary_dim
454464 self .summary_dim = summary_dim
455465
0 commit comments