Skip to content

Commit 387abe8

Browse files
authored
Merge pull request #389 from bayesflow-org/c2st
Add c2st
2 parents cf83bd3 + 8b9b16f commit 387abe8

File tree

13 files changed

+322
-24
lines changed

13 files changed

+322
-24
lines changed

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .posterior_contraction import posterior_contraction
33
from .root_mean_squared_error import root_mean_squared_error
44
from .expected_calibration_error import expected_calibration_error
5+
from .classifier_two_sample_test import classifier_two_sample_test
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Sequence, Mapping, Any
2+
3+
import numpy as np
4+
5+
import keras
6+
7+
from bayesflow.utils.exceptions import ShapeError
8+
from bayesflow.networks import MLP
9+
10+
11+
def classifier_two_sample_test(
    estimates: np.ndarray,
    targets: np.ndarray,
    metric: str = "accuracy",
    patience: int = 5,
    max_epochs: int = 1000,
    batch_size: int = 128,
    return_metric_only: bool = True,
    validation_split: float = 0.5,
    standardize: bool = True,
    mlp_widths: Sequence = (64, 64),
    **kwargs,
) -> float | Mapping[str, Any]:
    """
    C2ST metric [1] between samples from two distributions computed using a neural classifier.
    Can be computationally expensive if called in a loop, since it needs to train the model
    for each set of samples.

    Note: works best for large numbers of samples and averaged across different posteriors.

    [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.

    Parameters
    ----------
    estimates : np.ndarray
        Array of shape (num_samples_est, num_variables) containing samples representing estimated quantities
        (e.g., approximate posterior samples).
    targets : np.ndarray
        Array of shape (num_samples_tar, num_variables) containing target samples
        (e.g., samples from a reference posterior).
    metric : str, optional
        Metric to evaluate the classifier performance. Default is "accuracy".
    patience : int, optional
        Number of epochs with no improvement after which training will be stopped. Default is 5.
    max_epochs : int, optional
        Maximum number of epochs to train the classifier. Default is 1000.
    batch_size : int, optional
        Number of samples per batch during training. Default is 128.
    return_metric_only : bool, optional
        If True, only the final validation metric is returned. Otherwise, a dictionary with the score, classifier, and
        full training history is returned. Default is True.
    validation_split : float, optional
        Fraction of the training data to be used as validation data. Default is 0.5.
    standardize : bool, optional
        If True, both estimates and targets will be standardized using the mean and standard deviation of estimates.
        Dimensions in which the estimates have zero standard deviation are only centered, not rescaled.
        Default is True.
    mlp_widths : Sequence[int], optional
        Sequence specifying the number of units in each hidden layer of the MLP classifier. Default is (64, 64).
    **kwargs
        Additional keyword arguments. Recognized keyword:
            mlp_kwargs : dict
                Dictionary of additional parameters to pass to the MLP constructor.

    Returns
    -------
    results : float or dict
        If return_metric_only is True, returns the final validation metric (e.g., accuracy) as a float.
        Otherwise, returns a dictionary with keys "score", "classifier", and "history", where "score"
        is the final validation metric, "classifier" is the trained Keras model, and "history" contains the
        full training history.

    Raises
    ------
    ShapeError
        If estimates or targets are not two-dimensional, or if their second dimensions differ.
    """
    # Local import keeps the block self-contained without touching file-level imports
    from contextlib import nullcontext

    # Error, if inputs are not 2D or if targets dim does not match estimates dim
    if len(estimates.shape) != 2 or len(targets.shape) != 2 or estimates.shape[1] != targets.shape[1]:
        raise ShapeError(
            "estimates and targets can have different number of samples (1st dim) "
            "but must have the same dimensionality (2nd dim); "
            f"found: estimates shape {estimates.shape}, targets shape {targets.shape}"
        )

    # Standardize both estimates and targets relative to estimates mean and std
    if standardize:
        estimates_mean = np.mean(estimates, axis=0)
        estimates_std = np.std(estimates, axis=0)
        # A constant dimension has std == 0; dividing by it would inject NaN/inf into the
        # training data, so such dimensions are centered only.
        estimates_std = np.where(estimates_std > 0, estimates_std, 1.0)
        estimates = (estimates - estimates_mean) / estimates_std
        targets = (targets - estimates_mean) / estimates_std

    # Create data for classification task: label 0 = estimates, label 1 = targets
    data = np.r_[estimates, targets]
    labels = np.r_[np.zeros((estimates.shape[0],)), np.ones((targets.shape[0],))]

    # Important: needed, since keras does not shuffle before selecting validation split
    shuffle_idx = np.random.permutation(data.shape[0])
    data = data[shuffle_idx]
    labels = labels[shuffle_idx]

    # Create and train classifier with early stopping
    classifier = keras.Sequential(
        [MLP(widths=mlp_widths, **kwargs.get("mlp_kwargs", {})), keras.layers.Dense(1, activation="sigmoid")]
    )

    classifier.compile(optimizer="adam", loss="binary_crossentropy", metrics=[metric])

    early_stopping = keras.callbacks.EarlyStopping(
        monitor=f"val_{metric}", patience=patience, restore_best_weights=True
    )

    # For now, we need to enable grads, since we turn them off by default
    if keras.backend.backend() == "torch":
        import torch

        grad_context = torch.enable_grad()
    else:
        grad_context = nullcontext()

    with grad_context:
        history = classifier.fit(
            x=data,
            y=labels,
            epochs=max_epochs,
            batch_size=batch_size,
            verbose=0,
            callbacks=[early_stopping],
            validation_split=validation_split,
        )

    # NOTE(review): the score is the validation metric of the last trained epoch, while
    # restore_best_weights=True means the returned classifier may carry weights from an
    # earlier epoch - confirm this is the intended reporting.
    score = history.history[f"val_{metric}"][-1]

    if return_metric_only:
        return score
    return {"score": score, "classifier": classifier, "history": history.history}

bayesflow/diagnostics/metrics/expected_calibration_error.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from typing import Sequence, Any, Mapping
44

55
from ...utils.exceptions import ShapeError
6-
from sklearn.calibration import calibration_curve
6+
from ...utils.classification import calibration_curve
77

88

99
def expected_calibration_error(
1010
estimates: np.ndarray,
1111
targets: np.ndarray,
1212
model_names: Sequence[str] = None,
13-
n_bins: int = 10,
13+
num_bins: int = 10,
1414
return_probs: bool = False,
1515
) -> Mapping[str, Any]:
1616
"""
@@ -31,11 +31,11 @@ def expected_calibration_error(
3131
The one-hot-encoded true model indices.
3232
model_names : Sequence[str], optional (default = None)
3333
Optional model names to show in the output. By default, models are called "M_" + model index.
34-
n_bins : int, optional, default: 10
34+
num_bins : int, optional, default: 10
3535
The number of bins to use for the calibration curves (and marginal histograms).
36-
Passed into ``sklearn.calibration.calibration_curve()``.
36+
Passed into ``bayesflow.utils.calibration_curve()``.
3737
return_probs : bool (default = False)
38-
Do you want to obtain the output of ``sklearn.calibration.calibration_curve()``?
38+
Do you want to obtain the output of ``bayesflow.utils.calibration_curve()``?
3939
4040
Returns
4141
-------
@@ -48,9 +48,9 @@ def expected_calibration_error(
4848
- "model_names" : str
4949
The (inferred) variable names.
5050
- "probs_true": (optional) list[np.ndarray]:
51-
Outputs of ``sklearn.calibration.calibration_curve()`` per model
51+
Outputs of ``bayesflow.utils.calibration_curve()`` per model
5252
- "probs_pred": (optional) list[np.ndarray]:
53-
Outputs of ``sklearn.calibration.calibration_curve()`` per model
53+
Outputs of ``bayesflow.utils.calibration_curve()`` per model
5454
"""
5555

5656
# Convert tensors to numpy, if passed
@@ -76,10 +76,10 @@ def expected_calibration_error(
7676
for model_index in range(estimates.shape[-1]):
7777
y_true = (targets == model_index).astype(np.float32)
7878
y_prob = estimates[..., model_index]
79-
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)
79+
prob_true, prob_pred = calibration_curve(y_true, y_prob, num_bins=num_bins)
8080

8181
# Compute ECE by weighting bin errors by bin size
82-
bins = np.linspace(0.0, 1.0, n_bins + 1)
82+
bins = np.linspace(0.0, 1.0, num_bins + 1)
8383
binids = np.searchsorted(bins[1:-1], y_prob)
8484
bin_total = np.bincount(binids, minlength=len(bins))
8585
nonzero = bin_total != 0

bayesflow/diagnostics/plots/calibration_histogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def calibration_histogram(
102102
"Confidence intervals might be unreliable!"
103103
)
104104

105-
# Set n_bins automatically, if nothing provided
105+
# Set num_bins automatically, if nothing provided
106106
if num_bins is None:
107107
num_bins = int(ratio / 2)
108108
# Attempt a fix if a single bin is determined so plot still makes sense

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def mc_calibration(
1818
pred_models: dict[str, np.ndarray] | np.ndarray,
1919
true_models: dict[str, np.ndarray] | np.ndarray,
2020
model_names: Sequence[str] = None,
21-
n_bins: int = 10,
21+
num_bins: int = 10,
2222
label_fontsize: int = 16,
2323
title_fontsize: int = 18,
2424
metric_fontsize: int = 14,
@@ -41,12 +41,12 @@ def mc_calibration(
4141
The one-hot-encoded true model indices per data set.
4242
model_names : list or None, optional, default: None
4343
The model names for nice plot titles. Inferred if None.
44-
n_bins : int, optional, default: 10
44+
num_bins : int, optional, default: 10
4545
The number of bins to use for the calibration curves (and marginal histograms).
4646
label_fontsize : int, optional, default: 16
4747
The font size of the x-label and y-label texts
48-
legend_fontsize : int, optional, default: 14
49-
The font size of the legend text (ECE value)
48+
metric_fontsize : int, optional, default: 14
49+
The font size of the metric (e.g., ECE)
5050
title_fontsize : int, optional, default: 18
5151
The font size of the title text. Only relevant if `stacked=False`
5252
tick_fontsize : int, optional, default: 12
@@ -83,7 +83,7 @@ def mc_calibration(
8383
estimates=pred_models,
8484
targets=true_models,
8585
model_names=plot_data["variable_names"],
86-
n_bins=n_bins,
86+
num_bins=num_bins,
8787
return_probs=True,
8888
)
8989

@@ -92,7 +92,7 @@ def mc_calibration(
9292
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)
9393

9494
# Plot PMP distribution over bins
95-
uniform_bins = np.linspace(0.0, 1.0, n_bins + 1)
95+
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
9696
norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
9797
ax.hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
9898

bayesflow/diagnostics/plots/mc_confusion_matrix.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from matplotlib.colors import LinearSegmentedColormap
66
import numpy as np
77

8-
from sklearn.metrics import confusion_matrix
9-
10-
from bayesflow.utils.plot_utils import make_figure
8+
from ...utils.plot_utils import make_figure
9+
from ...utils.classification import confusion_matrix
1110

1211

1312
def mc_confusion_matrix(
@@ -50,10 +49,9 @@ def mc_confusion_matrix(
5049
ytick_rotation: int, optional, default: None
5150
Rotation of y-axis tick labels (helps with long model names).
5251
normalize : {'true', 'pred', 'all'}, default=None
53-
Passed to sklearn.metrics.confusion_matrix.
54-
Normalizes confusion matrix over the true (rows), predicted (columns)
55-
conditions or all the population. If None, confusion matrix will not be
56-
normalized.
52+
Passed to confusion matrix. Normalizes confusion matrix over the true (rows),
53+
predicted (columns) conditions or all the population. If None, confusion matrix
54+
will not be normalized.
5755
cmap : matplotlib.colors.Colormap or str, optional, default: None
5856
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
5957
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
fill_triangular_matrix,
7474
weighted_sum,
7575
)
76+
from .classification import calibration_curve, confusion_matrix
7677
from .validators import check_lengths_same
7778
from .workflow_utils import find_inference_network, find_summary_network
7879

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .confusion_matrix import confusion_matrix
2+
from .calibration_curve import calibration_curve
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import numpy as np
2+
3+
4+
def calibration_curve(
5+
targets: np.ndarray,
6+
estimates: np.ndarray,
7+
*,
8+
pos_label: int | float | bool | str = 1,
9+
num_bins: int = 5,
10+
strategy: str = "uniform",
11+
):
12+
"""Compute true and predicted probabilities for a calibration curve.
13+
14+
The method assumes the inputs come from a binary classifier, and
15+
discretize the [0, 1] interval into bins.
16+
17+
Code from: https://github.com/scikit-learn/scikit-learn/blob/98ed9dc73/sklearn/calibration.py#L927
18+
19+
Parameters
20+
----------
21+
targets : array-like of shape (n_samples,)
22+
True targets.
23+
estimates : array-like of shape (n_samples,)
24+
Probabilities of the positive class.
25+
pos_label : int, float, bool or str, default = 1
26+
The label of the positive class.
27+
num_bins : int, default=5
28+
Number of bins to discretize the [0, 1] interval. A bigger number
29+
requires more data. Bins with no samples (i.e. without
30+
corresponding values in `estimates`) will not be returned, thus the
31+
returned arrays may have less than `num_bins` values.
32+
strategy : {'uniform', 'quantile'}, default='uniform'
33+
Strategy used to define the widths of the bins.
34+
35+
uniform
36+
The bins have identical widths.
37+
quantile
38+
The bins have the same number of samples and depend on `y_prob`.
39+
40+
Returns
41+
-------
42+
prob_true : ndarray of shape (num_bins,) or smaller
43+
The proportion of samples whose class is the positive class, in each
44+
bin (fraction of positives).
45+
46+
prob_pred : ndarray of shape (num_bins,) or smaller
47+
The mean estimated probability in each bin.
48+
49+
References
50+
----------
51+
Alexandru Niculescu-Mizil and Rich Caruana (2005) Predicting Good
52+
Probabilities With Supervised Learning, in Proceedings of the 22nd
53+
International Conference on Machine Learning (ICML).
54+
"""
55+
56+
if estimates.min() < 0 or estimates.max() > 1:
57+
raise ValueError("y_prob has values outside [0, 1].")
58+
59+
labels = np.unique(targets)
60+
if len(labels) > 2:
61+
raise ValueError(f"Only binary classification is supported. Provided labels {labels}.")
62+
targets = targets == pos_label
63+
64+
if strategy == "quantile": # Determine bin edges by distribution of data
65+
quantiles = np.linspace(0, 1, num_bins + 1)
66+
bins = np.percentile(estimates, quantiles * 100)
67+
elif strategy == "uniform":
68+
bins = np.linspace(0.0, 1.0, num_bins + 1)
69+
else:
70+
raise ValueError("Invalid entry to 'strategy' input. Strategy must be either 'quantile' or 'uniform'.")
71+
72+
binids = np.searchsorted(bins[1:-1], estimates)
73+
74+
bin_sums = np.bincount(binids, weights=estimates, minlength=len(bins))
75+
bin_true = np.bincount(binids, weights=targets, minlength=len(bins))
76+
bin_total = np.bincount(binids, minlength=len(bins))
77+
78+
nonzero = bin_total != 0
79+
prob_true = bin_true[nonzero] / bin_total[nonzero]
80+
prob_pred = bin_sums[nonzero] / bin_total[nonzero]
81+
82+
return prob_true, prob_pred

0 commit comments

Comments
 (0)