Commit 1b64a2c

Fixed saving/loading of set transformers with inducing points
1 parent 8c3d8da commit 1b64a2c

1 file changed (+10, -6)

bayesflow/attention.py

Lines changed: 10 additions & 6 deletions

@@ -191,8 +191,10 @@ def call(self, x, **kwargs):
             Output of shape (batch_size, set_size, input_dim)
         """
 
-        batch_size = x.shape[0]
-        h = self.mab0(tf.stack([self.I] * batch_size), x, **kwargs)
+        batch_size = tf.shape(x)[0]
+        I_expanded = self.I[None, ...]
+        I_tiled = tf.tile(I_expanded, [batch_size, 1, 1])
+        h = self.mab0(I_tiled, x, **kwargs)
         return self.mab1(x, h, **kwargs)
 
 
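Why the first hunk matters: when a saved network is rebuilt, call() is traced with a symbolic input whose static batch dimension is None, so x.shape[0] is not a usable Python integer and tf.stack([self.I] * batch_size) fails. Reading the batch size with tf.shape(x)[0] and tiling the inducing points with tf.tile keeps everything as graph ops. A minimal sketch of the pattern (the names inducing_points and tile_for_batch are illustrative, not BayesFlow API):

import tensorflow as tf

inducing_points = tf.Variable(tf.random.normal((4, 8)))  # stand-in for self.I

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 8))])
def tile_for_batch(x):
    # tf.shape(x)[0] is a tensor, so this works even when the static batch dim is None
    batch_size = tf.shape(x)[0]
    return tf.tile(inducing_points[None, ...], [batch_size, 1, 1])  # (batch_size, 4, 8)

print(tile_for_batch(tf.zeros((3, 5, 8))).shape)  # (3, 4, 8)
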
@@ -240,7 +242,7 @@ def __init__(
             summary_dim, attention_settings, num_dense_fc, dense_settings, use_layer_norm, **kwargs
         )
         init = tf.keras.initializers.GlorotUniform()
-        self.seed_vec = init(shape=(num_seeds, summary_dim))
+        self.seed_vec = tf.Variable(init(shape=(num_seeds, summary_dim)), name="seed_vec", trainable=True)
         self.fc = Sequential([Dense(**dense_settings) for _ in range(num_dense_fc)])
         self.fc.add(Dense(summary_dim))
 
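The second hunk addresses the save/load bug for the pooling layer itself: the seed vectors were a plain tensor returned by the initializer, which Keras does not track as a layer weight, so they were never written to or restored from a checkpoint. Wrapping them in tf.Variable makes them a trackable (and trainable) weight. A small illustration with hypothetical layer names:

import tensorflow as tf

class WithTensor(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        init = tf.keras.initializers.GlorotUniform()
        self.seed_vec = init(shape=(2, 8))  # plain tensor: invisible to checkpointing

class WithVariable(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        init = tf.keras.initializers.GlorotUniform()
        self.seed_vec = tf.Variable(init(shape=(2, 8)), name="seed_vec", trainable=True)

print(len(WithTensor().weights))    # 0 -> seed values are lost on save/load
print(len(WithVariable().weights))  # 1 -> seed_vec is saved and restored
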
@@ -258,7 +260,9 @@ def call(self, x, **kwargs):
             Output of shape (batch_size, num_seeds * summary_dim)
         """
 
-        batch_size = x.shape[0]
         out = self.fc(x)
-        out = self.mab(tf.stack([self.seed_vec] * batch_size), out, **kwargs)
-        return tf.reshape(out, (out.shape[0], -1))
+        batch_size = tf.shape(x)[0]
+        seed_expanded = self.seed_vec[None, ...]
+        seed_tiled = tf.tile(seed_expanded, [batch_size, 1, 1])
+        out = self.mab(seed_tiled, out, **kwargs)
+        return tf.reshape(out, (tf.shape(out)[0], -1))
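
Taken together, the two changes let a set transformer with inducing points or seed vectors be built, saved, and restored under dynamic batch sizes. A hedged end-to-end sketch, using a toy stand-in layer (ToyPooler) and tf.train.Checkpoint rather than the actual BayesFlow classes:

import tensorflow as tf

class ToyPooler(tf.keras.layers.Layer):
    """Toy attention pooler with learnable seed vectors (not the BayesFlow class)."""
    def __init__(self, num_seeds=2, dim=8, **kwargs):
        super().__init__(**kwargs)
        init = tf.keras.initializers.GlorotUniform()
        self.seed_vec = tf.Variable(init(shape=(num_seeds, dim)), name="seed_vec")
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=dim)

    def call(self, x):
        batch_size = tf.shape(x)[0]                                    # dynamic batch size
        seeds = tf.tile(self.seed_vec[None, ...], [batch_size, 1, 1])  # (batch, num_seeds, dim)
        out = self.att(seeds, x)                                       # attend over the set
        return tf.reshape(out, (batch_size, -1))                       # flatten seed outputs

pooler = ToyPooler()
_ = pooler(tf.random.normal((3, 5, 8)))        # build once so all weights exist

ckpt = tf.train.Checkpoint(pooler=pooler)
path = ckpt.save("toy_pooler")                 # seed_vec is written because it is a tf.Variable
ckpt.restore(path)                             # ...and restored on load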
