Skip to content

Commit c6fc402

Browse files
committed
Transformers reloaded
1 parent 967c4e5 commit c6fc402

File tree

6 files changed

+96
-36
lines changed

6 files changed

+96
-36
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .fourier_embedding import FourierEmbedding
22
from .time2vec import Time2Vec
3+
from .recurrent_embedding import RecurrentEmbedding
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import keras
2+
from keras.saving import register_keras_serializable as serializable
3+
4+
from bayesflow.types import Tensor
5+
from bayesflow.utils import expand_tile
6+
7+
8+
@serializable(package="bayesflow.networks")
9+
class RecurrentEmbedding(keras.Layer):
10+
"""Implements a recurrent network for embedding time."""
11+
12+
def __init__(self, embed_dim: int = 8, embedding: str = "lstm"):
13+
super().__init__()
14+
15+
self.embed_dim = embed_dim
16+
self.embedding = embedding
17+
if embedding == "lstm":
18+
self.embedder = keras.layers.LSTM(embed_dim, return_sequences=True)
19+
elif embedding == "gru":
20+
self.embedder = keras.layers.GRU(embed_dim, return_sequences=True)
21+
else:
22+
raise ValueError(f"Unknown embedding type {embedding}. Must be in ['lstm', 'gru']")
23+
24+
def call(self, x: Tensor, t: Tensor = None) -> Tensor:
25+
"""Creates time representations and concatenates them to x.
26+
27+
Parameters:
28+
-----------
29+
x : Tensor of shape (batch_size, sequence_length, dim)
30+
The input sequence.
31+
t : Tensor of shape (batch_size, sequence_length)
32+
Vector of times
33+
34+
Returns:
35+
--------
36+
emb : Tensor
37+
Embedding of shape (batch_size, sequence_length, embed_dim + 1)
38+
"""
39+
40+
if t is None:
41+
t = keras.ops.linspace(0, keras.ops.shape(x)[1], keras.ops.shape(x)[1], dtype=x.dtype)
42+
t = expand_tile(t, keras.ops.shape(x)[0], axis=0)
43+
44+
emb = self.embedder(t[..., None])
45+
return keras.ops.concatenate([x, emb], axis=-1)

bayesflow/networks/embeddings/time2vec.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ def call(self, x: Tensor, t: Tensor = None) -> Tensor:
5858
Returns:
5959
--------
6060
emb : Tensor
61-
Embedding of shape (batch_size, fourier_emb_dim) if `include_identity`
62-
is False, else (batch_size, fourier_emb_dim+1)
61+
Embedding of shape (batch_size, sequence_length, num_periodic_features + 1)
6362
"""
6463

6564
if t is None:

bayesflow/networks/transformers/fusion_transformer.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from bayesflow.types import Tensor
66
from bayesflow.utils import check_lengths_same
77

8-
from ..embeddings import Time2Vec
98
from ..summary_network import SummaryNetwork
109

1110
from .mab import MultiHeadAttentionBlock
@@ -28,15 +27,17 @@ def __init__(
2827
kernel_initializer: str = "he_normal",
2928
use_bias: bool = True,
3029
layer_norm: bool = True,
31-
t2v_embed_dim: int = 8,
3230
template_type: str = "lstm",
3331
bidirectional: bool = True,
3432
template_dim: int = 128,
3533
**kwargs,
3634
):
37-
"""Creates a fusion transformer used to flexibly compress time series. If the time intervals vary across
38-
batches, it is highly recommended that your simulator also returns a "time" vector denoting absolute or
39-
relative time.
35+
"""Creates a fusion transformer used to flexibly compress time series and learn additional time embeddings
36+
using a recurrent neural network. If the time intervals vary across batches, it is highly recommended that
37+
your simulator also returns a "time" vector appended to the simulator outputs.
38+
39+
Important: This network needs at least 2 transformer blocks and will generally be slower than the
40+
corresponding TimeSeriesTransformer.
4041
4142
Parameters
4243
----------
@@ -73,6 +74,8 @@ def __init__(
7374
template_dim : int, optional (default - 128)
7475
Only used if ``template_type`` in ['lstm', 'gru']. The number of hidden
7576
units (equiv. output dimensions) of the recurrent network.
77+
time_axis : int, optional (default - None)
78+
The time axis (e.g., -1 for last axis) from which to grab the time vector that goes into t2v.
7679
**kwargs : dict
7780
Additional keyword arguments passed to the base layer.
7881
"""
@@ -82,9 +85,6 @@ def __init__(
8285
# Ensure all tuple-settings have the same length
8386
check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
8487

85-
# Initialize Time2Vec embedding layer
86-
self.time2vec = Time2Vec(t2v_embed_dim)
87-
8888
# Construct a series of set-attention blocks
8989
self.attention_blocks = []
9090
for i in range(len(embed_dims)):
@@ -121,17 +121,13 @@ def __init__(
121121

122122
self.output_projector = keras.layers.Dense(summary_dim)
123123

124-
def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
124+
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
125125
"""Compresses the input sequence into a summary vector of size `summary_dim`.
126126
127127
Parameters
128128
----------
129129
input_sequence : Tensor
130130
Input of shape (batch_size, sequence_length, input_dim)
131-
time : Tensor
132-
Time vector of shape (batch_size, sequence_length), optional (default - None)
133-
Note: time values for Time2Vec embeddings will be inferred on a linearly spaced
134-
interval between [0, sequence length], if no time vector is specified.
135131
training : boolean, optional (default - False)
136132
Passed to the optional internal dropout and spectral normalization
137133
layers to distinguish between train and test time behavior.
@@ -145,12 +141,12 @@ def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = Fal
145141
Output of shape (batch_size, set_size, output_dim)
146142
"""
147143

148-
inp = self.time2vec(input_sequence, t=time)
149-
template = self.template_net(inp, training=training)
144+
template = self.template_net(input_sequence, training=training)
150145

146+
rep = input_sequence
151147
for layer in self.attention_blocks[:-1]:
152-
inp = layer(inp, inp, training=training, **kwargs)
148+
rep = layer(rep, rep, training=training, **kwargs)
153149

154-
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), inp, training=training, **kwargs)
150+
summary = self.attention_blocks[-1](keras.ops.expand_dims(template, axis=1), rep, training=training, **kwargs)
155151
summary = self.output_projector(keras.ops.squeeze(summary, axis=1))
156152
return summary

bayesflow/networks/transformers/mab.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ def __init__(
4040

4141
self.input_projector = layers.Dense(embed_dim)
4242
self.attention = layers.MultiHeadAttention(
43-
key_dim=embed_dim,
44-
num_heads=num_heads,
45-
dropout=dropout,
46-
use_bias=use_bias,
43+
key_dim=embed_dim, num_heads=num_heads, dropout=dropout, use_bias=use_bias, output_shape=embed_dim
4744
)
4845
self.ln_pre = layers.LayerNormalization() if layer_norm else None
4946
self.mlp = MLP(

bayesflow/networks/transformers/time_series_transformer.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from bayesflow.types import Tensor
55
from bayesflow.utils import check_lengths_same
66

7-
from ..embeddings import Time2Vec
7+
from ..embeddings import Time2Vec, RecurrentEmbedding
88
from ..summary_network import SummaryNetwork
99

1010
from .mab import MultiHeadAttentionBlock
@@ -24,12 +24,14 @@ def __init__(
2424
kernel_initializer: str = "he_normal",
2525
use_bias: bool = True,
2626
layer_norm: bool = True,
27-
t2v_embed_dim: int = 8,
27+
time_embedding: str = "time2vec",
28+
time_embed_dim: int = 8,
29+
time_axis: int = None,
2830
**kwargs,
2931
):
3032
"""Creates a regular transformer coupled with Time2Vec embeddings of time used to flexibly compress time series.
3133
If the time intervals vary across batches, it is highly recommended that your simulator also returns a "time"
32-
vector denoting absolute or relative time.
34+
vector appended to the simulator outputs and specified via the "time_axis" argument.
3335
3436
Parameters
3537
----------
@@ -53,8 +55,14 @@ def __init__(
5355
Whether to include a bias term in the dense layers.
5456
layer_norm : bool, optional (default - True)
5557
Whether to apply layer normalization after the attention and MLP layers.
56-
t2v_embed_dim : int, optional (default - 8)
57-
The dimensionality of the Time2Vec embedding.
58+
time_embedding : str, optional (default - "time2vec")
59+
The type of embedding to use. Must be in ["time2vec", "lstm", "gru"]
60+
time_embed_dim : int, optional (default - 8)
61+
The dimensionality of the Time2Vec or recurrent embedding.
62+
time_axis : int, optional (default - None)
63+
The time axis (e.g., -1 for last axis) from which to grab the time vector that goes into the embedding.
64+
If an embedding is provided and time_axis is None, a uniform time interval between [0, sequence_len]
65+
will be assumed.
5866
**kwargs : dict
5967
Additional keyword arguments passed to the base layer.
6068
"""
@@ -65,7 +73,14 @@ def __init__(
6573
check_lengths_same(embed_dims, num_heads, mlp_depths, mlp_widths)
6674

6775
# Initialize Time2Vec embedding layer
68-
self.time2vec = Time2Vec(t2v_embed_dim)
76+
if time_embedding is None:
77+
self.time_embedding = None
78+
elif time_embedding == "time2vec":
79+
self.time_embedding = Time2Vec(num_periodic_features=time_embed_dim - 1)
80+
elif time_embedding in ["lstm", "gru"]:
81+
self.time_embedding = RecurrentEmbedding(time_embed_dim, time_embedding)
82+
else:
83+
raise ValueError("Embedding not found!")
6984

7085
# Construct a series of set-attention blocks
7186
self.attention_blocks = []
@@ -89,17 +104,15 @@ def __init__(
89104
self.pooling = keras.layers.GlobalAvgPool1D()
90105
self.output_projector = keras.layers.Dense(summary_dim)
91106

92-
def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = False, **kwargs) -> Tensor:
107+
self.time_axis = time_axis
108+
109+
def call(self, input_sequence: Tensor, training: bool = False, **kwargs) -> Tensor:
93110
"""Compresses the input sequence into a summary vector of size `summary_dim`.
94111
95112
Parameters
96113
----------
97114
input_sequence : Tensor
98115
Input of shape (batch_size, sequence_length, input_dim)
99-
time : Tensor
100-
Time vector of shape (batch_size, sequence_length), optional (default - None)
101-
Note: time values for Time2Vec embeddings will be inferred on a linearly spaced
102-
interval between [0, sequence length], if no time vector is specified.
103116
training : boolean, optional (default - False)
104117
Passed to the optional internal dropout and spectral normalization
105118
layers to distinguish between train and test time behavior.
@@ -113,8 +126,17 @@ def call(self, input_sequence: Tensor, time: Tensor = None, training: bool = Fal
113126
Output of shape (batch_size, set_size, output_dim)
114127
"""
115128

116-
# Concatenate learnable time embedding to input sequence
117-
inp = self.time2vec(input_sequence, t=time)
129+
if self.time_axis is not None:
130+
t = input_sequence[..., self.time_axis]
131+
indices = list(range(keras.ops.shape(input_sequence)[-1]))
132+
indices.pop(self.time_axis)
133+
inp = keras.ops.take(input_sequence, indices, axis=-1)
134+
else:
135+
t = None
136+
inp = input_sequence
137+
138+
if self.time_embedding:
139+
inp = self.time_embedding(inp, t=t)
118140

119141
# Apply self-attention blocks
120142
for layer in self.attention_blocks:

0 commit comments

Comments
 (0)