Skip to content

Commit 8b9b16f

Browse files
LarsKueCopilot
andauthored
Update bayesflow/utils/classification/confusion_matrix.py
Co-authored-by: Copilot <[email protected]>
1 parent 4c34065 commit 8b9b16f

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

bayesflow/utils/classification/confusion_matrix.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def confusion_matrix(targets: np.ndarray, estimates: np.ndarray, labels: Sequenc
2929
"""
3030

3131
# Get unique labels
32-
labels = np.asarray(labels) or np.unique(np.concatenate((targets, estimates)))
32+
if labels is None:
33+
labels = np.unique(np.concatenate((targets, estimates)))
34+
else:
35+
labels = np.asarray(labels)
3336

3437
label_to_index = {label: i for i, label in enumerate(labels)}
3538
num_labels = len(labels)

0 commit comments

Comments
 (0)