 from bayesflow.types import Tensor
 from bayesflow.utils import check_lengths_same

-from ..embeddings import Time2Vec
+from ..embeddings import Time2Vec, RecurrentEmbedding
 from ..summary_network import SummaryNetwork

 from .mab import MultiHeadAttentionBlock
@@ -24,12 +24,14 @@ def __init__(
         kernel_initializer: str = "he_normal",
         use_bias: bool = True,
         layer_norm: bool = True,
-        t2v_embed_dim: int = 8,
+        time_embedding: str = "time2vec",
+        time_embed_dim: int = 8,
+        time_axis: int = None,
         **kwargs,
     ):
3032 """Creates a regular transformer coupled with Time2Vec embeddings of time used to flexibly compress time series.
3133 If the time intervals vary across batches, it is highly recommended that your simulator also returns a "time"
32- vector denoting absolute or relative time .
34+ vector appended to the simulator outputs and specified via the "time_axis" argument .
3335
3436 Parameters
3537 ----------
@@ -53,8 +55,14 @@ def __init__(
             Whether to include a bias term in the dense layers.
         layer_norm : bool, optional (default - True)
             Whether to apply layer normalization after the attention and MLP layers.
-        t2v_embed_dim : int, optional (default - 8)
-            The dimensionality of the Time2Vec embedding.
+        time_embedding : str, optional (default - "time2vec")
+            The type of time embedding to use. Must be one of ["time2vec", "lstm", "gru"] or None for no embedding.
+        time_embed_dim : int, optional (default - 8)
+            The dimensionality of the Time2Vec or recurrent embedding.
+        time_axis : int, optional (default - None)
+            The axis of the input (e.g., -1 for the last axis) from which to take the time vector that is passed
+            to the embedding. If an embedding is provided and time_axis is None, uniformly spaced time points on
+            [0, sequence_len] are assumed.
         **kwargs : dict
             Additional keyword arguments passed to the base layer.
         """
@@ -65,7 +73,14 @@ def __init__(
         check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)

-        # Initialize Time2Vec embedding layer
-        self.time2vec = Time2Vec(t2v_embed_dim)
+        # Initialize the time embedding layer (Time2Vec, recurrent, or none)
+        if time_embedding is None:
+            self.time_embedding = None
+        elif time_embedding == "time2vec":
+            self.time_embedding = Time2Vec(num_periodic_features=time_embed_dim - 1)
+        elif time_embedding in ["lstm", "gru"]:
+            self.time_embedding = RecurrentEmbedding(time_embed_dim, time_embedding)
+        else:
+            raise ValueError(f"Unknown time embedding: {time_embedding!r}. Must be None, 'time2vec', 'lstm', or 'gru'.")

         # Construct a series of set-attention blocks
         self.attention_blocks = []
@@ -89,17 +104,15 @@ def __init__(
         self.pooling = keras.layers.GlobalAvgPool1D()
         self.output_projector = keras.layers.Dense(summary_dim)

-    def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
+        self.time_axis = time_axis
+
+    def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
         """Compresses the input sequence into a summary vector of size `summary_dim`.

         Parameters
         ----------
         input_sequence : Tensor
             Input of shape (batch_size, sequence_length, input_dim)
-        time : Tensor
-            Time vector of shape (batch_size, sequence_length), optional (default - None)
-            Note: time values for Time2Vec embeddings will be inferred on a linearly spaced
-            interval between [0, sequence length], if no time vector is specified.
         training : boolean, optional (default - False)
             Passed to the optional internal dropout and spectral normalization
             layers to distinguish between train and test time behavior.
@@ -113,8 +126,17 @@ def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = Fal
             Output of shape (batch_size, set_size, output_dim)
         """

-        # Concatenate learnable time embedding to input sequence
-        inp = self.time2vec(input_sequence, t=time)
+        if self.time_axis is not None:
+            t = input_sequence[..., self.time_axis]
+            indices = list(range(keras.ops.shape(input_sequence)[-1]))
+            indices.pop(self.time_axis)
+            inp = keras.ops.take(input_sequence, indices, axis=-1)
+        else:
+            t = None
+            inp = input_sequence
+
+        if self.time_embedding:
+            inp = self.time_embedding(inp, t=t)

         # Apply self-attention blocks
         for layer in self.attention_blocks:
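
The time-axis handling added to call can be illustrated standalone. The snippet below is only a sketch of the same slicing the new code performs, using keras.ops directly on a dummy array; shapes and the choice of time_axis=-1 are illustrative:

import keras
import numpy as np

x = np.random.normal(size=(2, 5, 4)).astype("float32")   # (batch, sequence, features), last column = time
time_axis = -1

t = x[..., time_axis]                                     # (2, 5) time vector handed to the embedding
indices = list(range(keras.ops.shape(x)[-1]))             # [0, 1, 2, 3]
indices.pop(time_axis)                                     # drop the time column -> [0, 1, 2]
inp = keras.ops.take(x, indices, axis=-1)                  # (2, 5, 3) remaining features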
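
For end-to-end usage, a minimal construction sketch follows. The class name TimeSeriesTransformer is an assumption (the diff does not show the class definition) and should be replaced by the actual summary network defined in this file; summary_dim is taken from the Dense(summary_dim) projector in the diff:

import numpy as np

# Hypothetical construction; class name assumed, not shown in the diff.
summary_net = TimeSeriesTransformer(
    summary_dim=16,
    time_embedding="time2vec",   # or "lstm", "gru", or None
    time_embed_dim=8,
    time_axis=-1,                # last feature column of the input holds the time stamps
)

# Simulator outputs with a time column appended as the last feature
x = np.random.normal(size=(32, 128, 4)).astype("float32")
summary = summary_net(x)         # expected shape: (32, 16)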