Skip to content

Commit edf91c2

Browse files
committed
remove missing import
1 parent 5b245ff commit edf91c2

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

bayesflow/diagnostics/metrics/expected_calibration_error.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def expected_calibration_error(
4141
-------
4242
result : dict
4343
Dictionary containing:
44-
- "values" : float or np.ndarray
44+
- "values" : np.ndarray
4545
The expected calibration error per model
4646
- "metric_name" : str
4747
The name of the metric ("Expected Calibration Error").
4848
- "model_names" : str
4949
The (inferred) variable names.
50-
- "probs_true": (optional) list:
50+
- "probs_true": (optional) list[np.ndarray]:
5151
Outputs of ``sklearn.calibration.calibration_curve()`` per model
52-
- "probs_pred": (optional) list:
52+
- "probs_pred": (optional) list[np.ndarray]:
5353
Outputs of ``sklearn.calibration.calibration_curve()`` per model
5454
"""
5555

@@ -89,7 +89,7 @@ def expected_calibration_error(
8989
probs_true.append(prob_true)
9090
probs_pred.append(prob_pred)
9191

92-
output = dict(values=ece, metric_name="Expected Calibration Error", model_names=model_names)
92+
output = dict(values=np.array(ece), metric_name="Expected Calibration Error", model_names=model_names)
9393

9494
if return_probs:
9595
output["probs_true"] = probs_true

bayesflow/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
numpy_utils,
55
)
66
from .callbacks import detailed_loss_callback
7-
from .comp_utils import expected_calibration_error
87
from .devices import devices
98
from .dict_utils import (
109
convert_args,

0 commit comments

Comments
 (0)