Skip to content

Commit 93ca6a9

Browse files
Merge pull request #334 from Kucharssim/fix-mc_calibration
`expected_calibration_error` does not crash with `VariableArray` objects
2 parents 07ec251 + 3744793 commit 93ca6a9

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

bayesflow/utils/comp_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from keras import ops
23

34
from sklearn.calibration import calibration_curve
45

@@ -16,9 +17,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
1617
1718
Parameters
1819
----------
19-
m_true : np.ndarray of shape (num_sim, num_models)
20+
m_true : array of shape (num_sim, num_models)
2021
The one-hot-encoded true model indices.
21-
m_pred : tf.tensor of shape (num_sim, num_models)
22+
m_pred : array of shape (num_sim, num_models)
2223
The predicted posterior model probabilities.
2324
num_bins : int, optional, default: 10
2425
The number of bins to use for the calibration curves (and marginal histograms).
@@ -32,11 +33,9 @@ def expected_calibration_error(m_true, m_pred, num_bins=10):
3233
Each list contains two arrays of length (num_bins) with the predicted and true probabilities for each bin.
3334
"""
3435

35-
# Convert tf.Tensors to numpy, if passed
36-
if type(m_true) is not np.ndarray:
37-
m_true = m_true.numpy()
38-
if type(m_pred) is not np.ndarray:
39-
m_pred = m_pred.numpy()
36+
# Convert tensors to numpy, if passed
37+
m_true = ops.convert_to_numpy(m_true)
38+
m_pred = ops.convert_to_numpy(m_pred)
4039

4140
# Extract number of models and prepare containers
4241
n_models = m_true.shape[1]

0 commit comments

Comments
 (0)