Skip to content

Commit 9464595

Browse files
committed
Fix ndim -> rank bug
1 parent eb911c7 commit 9464595

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

bayesflow/coupling_networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,9 @@ def _semantic_spline_parameters(self, parameters):
436436
"""
437437

438438
shape = tf.shape(parameters)
439-
if parameters.ndim == 2:
439+
if tf.rank(parameters) == 2:
440440
new_shape = (shape[0], self.dim_out, -1)
441-
elif parameters.ndim == 3:
441+
elif tf.rank(parameters) == 3:
442442
new_shape = (shape[0], shape[1], self.dim_out, -1)
443443
else:
444444
raise NotImplementedError("Spline flows can currently only operate on 2D and 3D inputs!")

bayesflow/helper_networks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def call(self, target, condition, **kwargs):
112112

113113
# Handle 3D case for a set-flow and repeat condition over
114114
# the second `time` or `n_observations` axis of `target``
115-
if target.ndim == 3 and condition.ndim == 2:
115+
if tf.rank(target) == 3 and tf.rank(condition) == 2:
116116
shape = tf.shape(target)
117117
condition = tf.expand_dims(condition, 1)
118118
condition = tf.tile(condition, [1, shape[1], 1])
@@ -228,7 +228,7 @@ def _forward(self, target):
228228
"""Performs a learnable generalized permutation over the last axis."""
229229

230230
shape = tf.shape(target)
231-
rank = target.ndim
231+
rank = tf.rank(target)
232232
log_det = tf.math.log(tf.math.abs(tf.linalg.det(self.W)))
233233
if rank == 2:
234234
z = tf.linalg.matmul(target, self.W)
@@ -241,7 +241,7 @@ def _inverse(self, z):
241241
"""Un-does the learnable permutation over the last axis."""
242242

243243
W_inv = tf.linalg.inv(self.W)
244-
rank = z.ndim
244+
rank = tf.rank(z)
245245
if rank == 2:
246246
return tf.linalg.matmul(z, W_inv)
247247
return tf.tensordot(z, W_inv, [[rank - 1], [0]])
@@ -402,11 +402,11 @@ def _initalize_parameters_data_dependent(self, init_data):
402402
"""
403403

404404
# 2D Tensor case, assume first batch dimension
405-
if init_data.ndim == 2:
405+
if tf.rank(init_data) == 2:
406406
mean = tf.math.reduce_mean(init_data, axis=0)
407407
std = tf.math.reduce_std(init_data, axis=0)
408408
# 3D Tensor case, assume first batch dimension, second number of observations dimension
409-
elif init_data.ndim == 3:
409+
elif tf.rank(init_data) == 3:
410410
mean = tf.math.reduce_mean(init_data, axis=(0, 1))
411411
std = tf.math.reduce_std(init_data, axis=(0, 1))
412412
# Raise other cases
@@ -527,7 +527,7 @@ def call(self, x, **kwargs):
527527
# Example: Output dim is (batch_size, inv_dim) - > (batch_size, N, inv_dim)
528528
out_inv = self.invariant_module(x, **kwargs)
529529
out_inv = tf.expand_dims(out_inv, -2)
530-
tiler = [1] * x.ndim
530+
tiler = [1] * tf.rank(x)
531531
tiler[-2] = shape[-2]
532532
out_inv_rep = tf.tile(out_inv, tiler)
533533

bayesflow/inference_networks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def forward(self, targets, condition, **kwargs):
188188
condition_shape = tf.shape(condition)
189189

190190
# Needs to be concatinable with condition
191-
if condition.ndim == 2:
191+
if tf.rank(condition) == 2:
192192
shape_scale = (condition_shape[0], 1)
193193
else:
194194
shape_scale = (condition_shape[0], condition_shape[1], 1)
@@ -201,7 +201,7 @@ def forward(self, targets, condition, **kwargs):
201201
noise_scale = tf.zeros(shape=shape_scale) + self.soft_low
202202

203203
# Perturb data with noise (will broadcast to all dimensions)
204-
if len(shape_scale) == 2 and targets.ndim == 3:
204+
if len(shape_scale) == 2 and tf.rank(targets) == 3:
205205
targets += tf.expand_dims(noise_scale, axis=1) * tf.random.normal(shape=target_shape)
206206
else:
207207
targets += noise_scale * tf.random.normal(shape=target_shape)
@@ -228,7 +228,7 @@ def inverse(self, z, condition, **kwargs):
228228
if self.soft_flow and condition is not None:
229229
# Needs to be concatinable with condition
230230
shape_scale = (
231-
(condition.shape[0], 1) if condition.ndim == 2 else (condition.shape[0], condition.shape[1], 1)
231+
(condition.shape[0], 1) if tf.rank(condition) == 2 else (condition.shape[0], condition.shape[1], 1)
232232
)
233233
noise_scale = tf.zeros(shape=shape_scale) + 2.0 * self.soft_low
234234

0 commit comments

Comments
 (0)