Skip to content

Commit 31a6251

Browse files
committed
Fix softflow bug for splines in 3D
1 parent af19085 commit 31a6251

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

bayesflow/inference_networks.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,21 +183,29 @@ def forward(self, targets, condition, **kwargs):
183183
# Add noise to target if using SoftFlow, use explicitly
184184
# not in call(), since methods are public
185185
if self.soft_flow and condition is not None:
186+
# Extract shapes of tensors
187+
target_shape = tf.shape(targets)
188+
condition_shape = tf.shape(condition)
189+
186190
# Needs to be concatinable with condition
187-
shape_scale = (
188-
(condition.shape[0], 1) if len(condition.shape) == 2 else (condition.shape[0], condition.shape[1], 1)
189-
)
191+
if len(condition_shape) == 2:
192+
shape_scale = (condition_shape[0], 1)
193+
else:
194+
shape_scale = (condition_shape[0], condition_shape[1], 1)
195+
190196
# Case training mode
191197
if kwargs.get("training"):
192198
noise_scale = tf.random.uniform(shape=shape_scale, minval=self.soft_low, maxval=self.soft_high)
193199
# Case inference mode
194200
else:
195201
noise_scale = tf.zeros(shape=shape_scale) + self.soft_low
202+
196203
# Perturb data with noise (will broadcast to all dimensions)
197-
if len(shape_scale) == 2 and len(targets.shape) == 3:
198-
targets += tf.expand_dims(noise_scale, axis=1) * tf.random.normal(shape=targets.shape)
204+
if len(shape_scale) == 2 and len(target_shape) == 3:
205+
targets += tf.expand_dims(noise_scale, axis=1) * tf.random.normal(shape=target_shape)
199206
else:
200-
targets += noise_scale * tf.random.normal(shape=targets.shape)
207+
targets += noise_scale * tf.random.normal(shape=target_shape)
208+
201209
# Augment condition with noise scale variate
202210
condition = tf.concat((condition, noise_scale), axis=-1)
203211

0 commit comments

Comments
 (0)