Skip to content

Commit 29e4ba0

Browse files
committed
Bayes Estimators: Lossfunctions with flexible signature
1 parent 65961be commit 29e4ba0

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

bayesflow/amortizers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from bayesflow.default_settings import DEFAULT_KEYS
3232
from bayesflow.exceptions import ConfigurationError, SummaryStatsError
3333
from bayesflow.helper_functions import check_tensor_sanity
34-
from bayesflow.losses import log_loss, mmd_summary_space
34+
from bayesflow.losses import log_loss, mmd_summary_space, norm_diff
3535
from bayesflow.networks import EvidentialNetwork
3636

3737

@@ -1337,7 +1337,7 @@ def compute_loss(self, input_dict, **kwargs):
13371337
"""
13381338

13391339
net_out = self(input_dict, **kwargs)
1340-
loss = tf.reduce_mean(self.loss_fn(net_out - input_dict[DEFAULT_KEYS["parameters"]]))
1340+
loss = tf.reduce_mean(self.loss_fn(net_out, input_dict[DEFAULT_KEYS["parameters"]]))
13411341
return loss
13421342

13431343
def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
@@ -1366,4 +1366,4 @@ def _determine_loss(self, loss_fun, norm_ord):
13661366
# In case of user-provided loss, override norm order
13671367
if loss_fun is not None:
13681368
return loss_fun
1369-
return partial(tf.norm, ord=norm_ord, axis=-1)
1369+
return partial(norm_diff, ord=norm_ord, axis=-1)

bayesflow/losses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,20 @@ def log_loss(model_indices, preds, evidential=False, label_smoothing=0.01):
184184
# Actual loss + regularization (if given)
185185
loss = -tf.reduce_mean(tf.reduce_sum(model_indices * tf.math.log(preds), axis=1))
186186
return loss
187+
188+
189+
def norm_diff(tensor_a, tensor_b, axis=None, ord='euclidean'):
190+
"""
191+
Wrapper around tf.norm that computes the norm of the difference between two tensors along the specified axis.
192+
193+
Parameters
194+
----------
195+
tensor_a : A Tensor.
196+
tensor_b : A Tensor. Must be the same shape as tensor_a.
197+
axis : Any or None
198+
Axis along which to compute the norm of the difference. Default is None.
199+
ord : int or str
200+
Order of the norm. Supports 'euclidean' and other norms supported by tf.norm. Default is 'euclidean'.
201+
"""
202+
difference = tensor_a - tensor_b
203+
return tf.norm(difference, ord=ord, axis=axis)

0 commit comments

Comments
 (0)