Skip to content

Commit eeaee04

Browse files
committed
Bugfix of ranks 2
1 parent 9464595 commit eeaee04

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

bayesflow/coupling_networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,9 @@ def _semantic_spline_parameters(self, parameters):
436436
"""
437437

438438
shape = tf.shape(parameters)
439-
if tf.rank(parameters) == 2:
439+
if len(shape) == 2:
440440
new_shape = (shape[0], self.dim_out, -1)
441-
elif tf.rank(parameters) == 3:
441+
elif len(shape) == 3:
442442
new_shape = (shape[0], shape[1], self.dim_out, -1)
443443
else:
444444
raise NotImplementedError("Spline flows can currently only operate on 2D and 3D inputs!")

bayesflow/helper_networks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _forward(self, target):
228228
"""Performs a learnable generalized permutation over the last axis."""
229229

230230
shape = tf.shape(target)
231-
rank = tf.rank(target)
231+
rank = len(shape)
232232
log_det = tf.math.log(tf.math.abs(tf.linalg.det(self.W)))
233233
if rank == 2:
234234
z = tf.linalg.matmul(target, self.W)
@@ -241,7 +241,7 @@ def _inverse(self, z):
241241
"""Un-does the learnable permutation over the last axis."""
242242

243243
W_inv = tf.linalg.inv(self.W)
244-
rank = tf.rank(z)
244+
rank = len(tf.shape(z))
245245
if rank == 2:
246246
return tf.linalg.matmul(z, W_inv)
247247
return tf.tensordot(z, W_inv, [[rank - 1], [0]])
@@ -527,7 +527,7 @@ def call(self, x, **kwargs):
527527
# Example: Output dim is (batch_size, inv_dim) - > (batch_size, N, inv_dim)
528528
out_inv = self.invariant_module(x, **kwargs)
529529
out_inv = tf.expand_dims(out_inv, -2)
530-
tiler = [1] * tf.rank(x)
530+
tiler = [1] * len(shape)
531531
tiler[-2] = shape[-2]
532532
out_inv_rep = tf.tile(out_inv, tiler)
533533

0 commit comments

Comments
 (0)