
Commit ee201d5: Remove sklearn as dep
Parent: 1a2e35d

10 files changed, +176 −17 lines

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from .posterior_contraction import posterior_contraction
 from .root_mean_squared_error import root_mean_squared_error
 from .expected_calibration_error import expected_calibration_error
+from .classifier_two_sample_test import classifier_two_sample_test

bayesflow/diagnostics/metrics/expected_calibration_error.py

Lines changed: 9 additions & 9 deletions
@@ -3,14 +3,14 @@
 from typing import Sequence, Any, Mapping

 from ...utils.exceptions import ShapeError
-from sklearn.calibration import calibration_curve
+from ...utils.classification import calibration_curve


 def expected_calibration_error(
     estimates: np.ndarray,
     targets: np.ndarray,
     model_names: Sequence[str] = None,
-    n_bins: int = 10,
+    num_bins: int = 10,
     return_probs: bool = False,
 ) -> Mapping[str, Any]:
     """
@@ -31,11 +31,11 @@ def expected_calibration_error(
         The one-hot-encoded true model indices.
     model_names : Sequence[str], optional (default = None)
         Optional model names to show in the output. By default, models are called "M_" + model index.
-    n_bins : int, optional, default: 10
+    num_bins : int, optional, default: 10
         The number of bins to use for the calibration curves (and marginal histograms).
-        Passed into ``sklearn.calibration.calibration_curve()``.
+        Passed into ``bayesflow.utils.calibration_curve()``.
     return_probs : bool (default = False)
-        Do you want to obtain the output of ``sklearn.calibration.calibration_curve()``?
+        Whether to also return the outputs of ``bayesflow.utils.calibration_curve()``.

     Returns
     -------
@@ -48,9 +48,9 @@ def expected_calibration_error(
     - "model_names" : str
         The (inferred) model names.
     - "probs_true": (optional) list[np.ndarray]:
-        Outputs of ``sklearn.calibration.calibration_curve()`` per model
+        Outputs of ``bayesflow.utils.classification.calibration_curve()`` per model
     - "probs_pred": (optional) list[np.ndarray]:
-        Outputs of ``sklearn.calibration.calibration_curve()`` per model
+        Outputs of ``bayesflow.utils.classification.calibration_curve()`` per model
     """

     # Convert tensors to numpy, if passed
@@ -76,10 +76,10 @@
     for model_index in range(estimates.shape[-1]):
         y_true = (targets == model_index).astype(np.float32)
         y_prob = estimates[..., model_index]
-        prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)
+        prob_true, prob_pred = calibration_curve(y_true, y_prob, num_bins=num_bins)

         # Compute ECE by weighting bin errors by bin size
-        bins = np.linspace(0.0, 1.0, n_bins + 1)
+        bins = np.linspace(0.0, 1.0, num_bins + 1)
         binids = np.searchsorted(bins[1:-1], y_prob)
         bin_total = np.bincount(binids, minlength=len(bins))
         nonzero = bin_total != 0
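
For reference, a call to the updated metric with the renamed keyword might look as follows. This is a minimal sketch: the array shapes, the softmax construction of the predicted probabilities, and the random data are illustrative assumptions; only the argument names and output keys come from the diff and the test suite below.

import numpy as np
import bayesflow as bf

rng = np.random.default_rng(0)

# Hypothetical predicted model probabilities for 500 data sets and 3 candidate models
logits = rng.normal(size=(500, 3))
estimates = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Hypothetical one-hot-encoded true model indices
targets = np.eye(3)[rng.integers(0, 3, size=500)]

# Note the renamed keyword: num_bins (previously n_bins)
out = bf.diagnostics.metrics.expected_calibration_error(
    estimates, targets, model_names=["M_0", "M_1", "M_2"], num_bins=10, return_probs=True
)
print(out["values"], out["model_names"])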

bayesflow/diagnostics/plots/mc_confusion_matrix.py

Lines changed: 5 additions & 7 deletions
@@ -5,9 +5,8 @@
 from matplotlib.colors import LinearSegmentedColormap
 import numpy as np

-from sklearn.metrics import confusion_matrix
-
-from bayesflow.utils.plot_utils import make_figure
+from ...utils.plot_utils import make_figure
+from ...utils.classification import confusion_matrix


 def mc_confusion_matrix(
@@ -50,10 +49,9 @@ def mc_confusion_matrix(
     ytick_rotation: int, optional, default: None
         Rotation of y-axis tick labels (helps with long model names).
     normalize : {'true', 'pred', 'all'}, default=None
-        Passed to sklearn.metrics.confusion_matrix.
-        Normalizes confusion matrix over the true (rows), predicted (columns)
-        conditions or all the population. If None, confusion matrix will not be
-        normalized.
+        Passed to the confusion matrix computation. Normalizes the confusion matrix over
+        the true (rows), predicted (columns) conditions or over the whole population.
+        If None, the confusion matrix will not be normalized.
     cmap : matplotlib.colors.Colormap or str, optional, default: None
         Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
         e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@
     tree_stack,
     fill_triangular_matrix,
 )
+from .classification import calibration_curve, confusion_matrix
 from .validators import check_lengths_same
 from .workflow_utils import find_inference_network, find_summary_network

bayesflow/utils/classification/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .confusion_matrix import confusion_matrix
+from .calibration_curve import calibration_curve
bayesflow/utils/classification/calibration_curve.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import numpy as np
+
+
+def calibration_curve(
+    targets: np.ndarray,
+    estimates: np.ndarray,
+    *,
+    pos_label: int | float | bool | str = 1,
+    num_bins: int = 5,
+    strategy: str = "uniform",
+):
+    """Compute true and predicted probabilities for a calibration curve.
+
+    The method assumes the inputs come from a binary classifier and
+    discretizes the [0, 1] interval into bins.
+
+    Code from: https://github.com/scikit-learn/scikit-learn/blob/98ed9dc73/sklearn/calibration.py#L927
+
+    Parameters
+    ----------
+    targets : array-like of shape (n_samples,)
+        True targets.
+    estimates : array-like of shape (n_samples,)
+        Probabilities of the positive class.
+    pos_label : int, float, bool or str, default=1
+        The label of the positive class.
+    num_bins : int, default=5
+        Number of bins to discretize the [0, 1] interval. A bigger number
+        requires more data. Bins with no samples (i.e., without
+        corresponding values in ``estimates``) will not be returned, so the
+        returned arrays may have fewer than ``num_bins`` values.
+    strategy : {'uniform', 'quantile'}, default='uniform'
+        Strategy used to define the widths of the bins.
+
+        uniform
+            The bins have identical widths.
+        quantile
+            The bins have the same number of samples and depend on ``estimates``.
+
+    Returns
+    -------
+    prob_true : ndarray of shape (num_bins,) or smaller
+        The proportion of samples whose class is the positive class, in each
+        bin (fraction of positives).
+
+    prob_pred : ndarray of shape (num_bins,) or smaller
+        The mean estimated probability in each bin.
+
+    References
+    ----------
+    Alexandru Niculescu-Mizil and Rich Caruana (2005) Predicting Good
+    Probabilities With Supervised Learning, in Proceedings of the 22nd
+    International Conference on Machine Learning (ICML).
+    """
+
+    if estimates.min() < 0 or estimates.max() > 1:
+        raise ValueError("estimates contains values outside [0, 1].")
+
+    labels = np.unique(targets)
+    if len(labels) > 2:
+        raise ValueError(f"Only binary classification is supported. Provided labels {labels}.")
+    targets = targets == pos_label
+
+    if strategy == "quantile":  # Determine bin edges by the distribution of the data
+        quantiles = np.linspace(0, 1, num_bins + 1)
+        bins = np.percentile(estimates, quantiles * 100)
+    elif strategy == "uniform":
+        bins = np.linspace(0.0, 1.0, num_bins + 1)
+    else:
+        raise ValueError("Invalid entry to 'strategy' input. Strategy must be either 'quantile' or 'uniform'.")
+
+    binids = np.searchsorted(bins[1:-1], estimates)
+
+    bin_sums = np.bincount(binids, weights=estimates, minlength=len(bins))
+    bin_true = np.bincount(binids, weights=targets, minlength=len(bins))
+    bin_total = np.bincount(binids, minlength=len(bins))
+
+    nonzero = bin_total != 0
+    prob_true = bin_true[nonzero] / bin_total[nonzero]
+    prob_pred = bin_sums[nonzero] / bin_total[nonzero]
+
+    return prob_true, prob_pred
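
A small worked example of the binning behavior, assuming the module path bayesflow.utils.classification inferred from the imports above; the toy probabilities and labels are made up:

import numpy as np
from bayesflow.utils.classification import calibration_curve

estimates = np.array([0.1, 0.1, 0.3, 0.3, 0.7, 0.7, 0.9, 0.9, 0.9, 0.9])  # made-up probabilities
targets = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0])                        # made-up binary labels

prob_true, prob_pred = calibration_curve(targets, estimates, num_bins=5, strategy="uniform")
# The bin covering (0.4, 0.6] receives no samples and is dropped, so only four values remain:
# prob_true -> [0.0, 0.5, 1.0, 0.75], prob_pred -> [0.1, 0.3, 0.7, 0.9]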
bayesflow/utils/classification/confusion_matrix.py

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+from typing import Sequence
+
+import numpy as np
+
+
+def confusion_matrix(targets: np.ndarray, estimates: np.ndarray, labels: Sequence = None, normalize: str = None):
+    """
+    Compute a confusion matrix to evaluate the accuracy of a classification or model comparison setting.
+
+    Code inspired by: https://github.com/scikit-learn/scikit-learn/blob/98ed9dc73/sklearn/metrics/_classification.py
+
+    Parameters
+    ----------
+    targets : np.ndarray
+        Ground truth (correct) target values.
+    estimates : np.ndarray
+        Estimated targets as returned by a classifier.
+    labels : Sequence, optional
+        List of labels to index the matrix. This may be used to reorder or select a subset of labels.
+        If None, labels that appear at least once in ``targets`` or ``estimates`` are used in sorted order.
+    normalize : {'true', 'pred', 'all'}, optional
+        Normalizes the confusion matrix over the true (rows), predicted (columns)
+        conditions or all the population. If None, no normalization is applied.
+
+    Returns
+    -------
+    cm : np.ndarray of shape (num_labels, num_labels)
+        Confusion matrix. Rows represent true classes, columns represent predicted classes.
+    """
+
+    # Get unique labels if none are provided
+    if labels is None:
+        labels = np.unique(np.concatenate((targets, estimates)))
+    else:
+        labels = np.asarray(labels)
+
+    label_to_index = {label: i for i, label in enumerate(labels)}
+    num_labels = len(labels)
+
+    # Initialize the confusion matrix
+    cm = np.zeros((num_labels, num_labels), dtype=np.int64)
+
+    # Fill confusion matrix
+    for t, p in zip(targets, estimates):
+        if t in label_to_index and p in label_to_index:
+            cm[label_to_index[t], label_to_index[p]] += 1
+
+    # Normalize if required; rows or columns with a zero total are left at 0
+    if normalize == "true":
+        cm = cm.astype(np.float64)
+        row_sums = cm.sum(axis=1, keepdims=True)
+        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
+    elif normalize == "pred":
+        cm = cm.astype(np.float64)
+        col_sums = cm.sum(axis=0, keepdims=True)
+        cm = np.divide(cm, col_sums, out=np.zeros_like(cm), where=col_sums != 0)
+    elif normalize == "all":
+        cm = cm.astype(np.float64)
+        cm /= cm.sum()
+
+    return cm
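
A small worked example with made-up model indices, again assuming the bayesflow.utils.classification module path:

import numpy as np
from bayesflow.utils.classification import confusion_matrix

targets = np.array([0, 0, 1, 1, 2, 2])    # made-up true model indices
estimates = np.array([0, 1, 1, 1, 2, 0])  # made-up predicted model indices

cm = confusion_matrix(targets, estimates)
# [[1, 1, 0],
#  [0, 2, 0],
#  [1, 0, 1]]

cm_norm = confusion_matrix(targets, estimates, normalize="true")
# Each row now sums to 1: [[0.5, 0.5, 0.0], [0.0, 1.0, 0.0], [0.5, 0.0, 0.5]]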

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ dependencies = [
     "matplotlib",
     "numpy >= 1.24, <2.0",
     "pandas",
-    "scikit-learn",
     "scipy",
     "seaborn",
     "tqdm",

tests/test_diagnostics/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,16 @@ def var_names():
     return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"]


+@pytest.fixture()
+def random_samples_a():
+    return np.random.normal(loc=0, scale=1, size=(1000, 8))
+
+
+@pytest.fixture()
+def random_samples_b():
+    return np.random.normal(loc=0, scale=3, size=(1000, 8))
+
+
 @pytest.fixture()
 def random_estimates():
     return {

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 8 additions & 0 deletions
@@ -50,6 +50,14 @@ def test_root_mean_squared_error(random_estimates, random_targets):
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]


+def test_classifier_two_sample_test(random_samples_a, random_samples_b):
+    metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, targets=random_samples_a)
+    assert 0.6 > metric > 0.4
+
+    metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, targets=random_samples_b)
+    assert metric > 0.6
+
+
 def test_expected_calibration_error(pred_models, true_models, model_names):
     out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=model_names)
     assert list(out.keys()) == ["values", "metric_name", "model_names"]
