Skip to content

Commit 07f7546

Browse files
committed
pass probs insteads of predictions to classifier metrics
This is what most classifier metrics expect, and contains more detail for the metrics to work with.
1 parent af3f19c commit 07f7546

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def compute_metrics(
160160

161161
if stage != "training" and any(self.classifier_network.metrics):
162162
# compute sample-based metrics
163-
predictions = keras.ops.argmax(logits, axis=-1)
163+
probs = keras.ops.softmax(logits)
164164
classifier_metrics |= {
165-
metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
165+
metric.name: metric(model_indices, probs) for metric in self.classifier_network.metrics
166166
}
167167
if "loss" in summary_metrics:
168168
loss = classifier_metrics["loss"] + summary_metrics["loss"]

0 commit comments

Comments
 (0)