Skip to content

Commit 376b20d

Browse files
Merge pull request #95 from stefanradev93/Development
Development
2 parents 98d895c + 733afef commit 376b20d

File tree

5 files changed

+48
-70
lines changed

5 files changed

+48
-70
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -161,7 +161,8 @@ amortized inference if the generative model is a poor representation of reality?
161161
A modified loss function optimizes the learned summary statistics towards a unit
162162
Gaussian and reliably detects model misspecification during inference time.
163163

164-
![](https://github.com/stefanradev93/BayesFlow/blob/master/docs/source/images/model_misspecification_amortized_sbi.png?raw=true)
164+
165+
<img src="https://github.com/stefanradev93/BayesFlow/blob/master/examples/img/model_misspecification_amortized_sbi.png" width=100% height=100%>
165166

166167
In order to use this method, you should only provide the `summary_loss_fun` argument
167168
to the `AmortizedPosterior` instance:

bayesflow/coupling_networks.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -436,9 +436,9 @@ def _semantic_spline_parameters(self, parameters):
436436
"""
437437

438438
shape = tf.shape(parameters)
439-
if len(shape) == 2:
439+
if parameters.ndim == 2:
440440
new_shape = (shape[0], self.dim_out, -1)
441-
elif len(shape) == 3:
441+
elif parameters.ndim == 3:
442442
new_shape = (shape[0], shape[1], self.dim_out, -1)
443443
else:
444444
raise NotImplementedError("Spline flows can currently only operate on 2D and 3D inputs!")

bayesflow/helper_networks.py

Lines changed: 6 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -112,7 +112,7 @@ def call(self, target, condition, **kwargs):
112112

113113
# Handle 3D case for a set-flow and repeat condition over
114114
# the second `time` or `n_observations` axis of `target``
115-
if len(tf.shape(target)) == 3 and len(tf.shape(condition)) == 2:
115+
if target.ndim == 3 and condition.ndim == 2:
116116
shape = tf.shape(target)
117117
condition = tf.expand_dims(condition, 1)
118118
condition = tf.tile(condition, [1, shape[1], 1])
@@ -228,7 +228,7 @@ def _forward(self, target):
228228
"""Performs a learnable generalized permutation over the last axis."""
229229

230230
shape = tf.shape(target)
231-
rank = len(shape)
231+
rank = target.ndim
232232
log_det = tf.math.log(tf.math.abs(tf.linalg.det(self.W)))
233233
if rank == 2:
234234
z = tf.linalg.matmul(target, self.W)
@@ -241,7 +241,7 @@ def _inverse(self, z):
241241
"""Un-does the learnable permutation over the last axis."""
242242

243243
W_inv = tf.linalg.inv(self.W)
244-
rank = len(tf.shape(z))
244+
rank = z.ndim
245245
if rank == 2:
246246
return tf.linalg.matmul(z, W_inv)
247247
return tf.tensordot(z, W_inv, [[rank - 1], [0]])
@@ -402,11 +402,11 @@ def _initalize_parameters_data_dependent(self, init_data):
402402
"""
403403

404404
# 2D Tensor case, assume first batch dimension
405-
if len(init_data.shape) == 2:
405+
if init_data.ndim == 2:
406406
mean = tf.math.reduce_mean(init_data, axis=0)
407407
std = tf.math.reduce_std(init_data, axis=0)
408408
# 3D Tensor case, assume first batch dimension, second number of observations dimension
409-
elif len(init_data.shape) == 3:
409+
elif init_data.ndim == 3:
410410
mean = tf.math.reduce_mean(init_data, axis=(0, 1))
411411
std = tf.math.reduce_std(init_data, axis=(0, 1))
412412
# Raise other cases
@@ -527,7 +527,7 @@ def call(self, x, **kwargs):
527527
# Example: Output dim is (batch_size, inv_dim) - > (batch_size, N, inv_dim)
528528
out_inv = self.invariant_module(x, **kwargs)
529529
out_inv = tf.expand_dims(out_inv, -2)
530-
tiler = [1] * len(shape)
530+
tiler = [1] * x.ndim
531531
tiler[-2] = shape[-2]
532532
out_inv_rep = tf.tile(out_inv, tiler)
533533

bayesflow/inference_networks.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -188,7 +188,7 @@ def forward(self, targets, condition, **kwargs):
188188
condition_shape = tf.shape(condition)
189189

190190
# Needs to be concatinable with condition
191-
if len(condition_shape) == 2:
191+
if condition.ndim == 2:
192192
shape_scale = (condition_shape[0], 1)
193193
else:
194194
shape_scale = (condition_shape[0], condition_shape[1], 1)
@@ -201,7 +201,7 @@ def forward(self, targets, condition, **kwargs):
201201
noise_scale = tf.zeros(shape=shape_scale) + self.soft_low
202202

203203
# Perturb data with noise (will broadcast to all dimensions)
204-
if len(shape_scale) == 2 and len(target_shape) == 3:
204+
if len(shape_scale) == 2 and targets.ndim == 3:
205205
targets += tf.expand_dims(noise_scale, axis=1) * tf.random.normal(shape=target_shape)
206206
else:
207207
targets += noise_scale * tf.random.normal(shape=target_shape)
@@ -228,7 +228,7 @@ def inverse(self, z, condition, **kwargs):
228228
if self.soft_flow and condition is not None:
229229
# Needs to be concatinable with condition
230230
shape_scale = (
231-
(condition.shape[0], 1) if len(condition.shape) == 2 else (condition.shape[0], condition.shape[1], 1)
231+
(condition.shape[0], 1) if condition.ndim == 2 else (condition.shape[0], condition.shape[1], 1)
232232
)
233233
noise_scale = tf.zeros(shape=shape_scale) + 2.0 * self.soft_low
234234

examples/Hierarchical_Model_Comparison_MPT.ipynb

Lines changed: 35 additions & 58 deletions
Large diffs are not rendered by default.

0 commit comments

Comments (0)