Skip to content

Commit e45fcc2

Browse files
committed
Fix major bug affecting saving and loading of learnable permutation
1 parent a9f84dd commit e45fcc2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

bayesflow/helper_networks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def __init__(self, input_dim):
198198
super().__init__()
199199

200200
init = tf.keras.initializers.Orthogonal()
201-
self.W = init(shape=(input_dim, input_dim))
201+
self.W = tf.Variable(
202+
initial_value=init(shape=(input_dim, input_dim)), trainable=True, dtype=tf.float32, name="learnable_permute"
203+
)
202204

203205
def call(self, target, inverse=False):
204206
"""Transforms a batch of target vectors over the last axis through an approximately

0 commit comments

Comments
 (0)