We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent af3f19c commit 07f7546Copy full SHA for 07f7546
bayesflow/approximators/model_comparison_approximator.py
@@ -160,9 +160,9 @@ def compute_metrics(
160
161
if stage != "training" and any(self.classifier_network.metrics):
162
# compute sample-based metrics
163
- predictions = keras.ops.argmax(logits, axis=-1)
+ probs = keras.ops.softmax(logits)
164
classifier_metrics |= {
165
- metric.name: metric(model_indices, predictions) for metric in self.classifier_network.metrics
+ metric.name: metric(model_indices, probs) for metric in self.classifier_network.metrics
166
}
167
if "loss" in summary_metrics:
168
loss = classifier_metrics["loss"] + summary_metrics["loss"]
0 commit comments