
Commit bb0cf65

add C2ST metric
1 parent 8866be6 commit bb0cf65

File tree: 2 files changed (+115, −1 lines)

bayesflow/computational_utilities.py

Lines changed: 72 additions & 0 deletions
@@ -22,6 +22,8 @@
 import tensorflow as tf
 from scipy import stats
 from sklearn.calibration import calibration_curve
+from sklearn.neural_network import MLPClassifier
+from sklearn.model_selection import cross_val_score, KFold
 
 from bayesflow.default_settings import MMD_BANDWIDTH_LIST
 from bayesflow.exceptions import ShapeError
@@ -517,3 +519,73 @@ def aggregated_rmse(x_true, x_pred):
     return aggregated_error(
         x_true=x_true, x_pred=x_pred, inner_error_fun=root_mean_squared_error, outer_aggregation_fun=np.mean
     )
+
+
+def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normalize=True, seed=123,
+         hidden_units_per_dim=10):
+    """C2ST metric [1] using an sklearn MLP classifier.
+    Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py
+
+    [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.
+
+    Parameters
+    ----------
+    source_samples : np.ndarray or tf.Tensor
+        Source samples (e.g., approximate posterior samples)
+    target_samples : np.ndarray or tf.Tensor
+        Target samples (e.g., samples from a reference posterior)
+    n_folds : int, optional, default: 5
+        Number of folds in the k-fold cross-validation used to evaluate the classifier
+    scoring : str, optional, default: "accuracy"
+        Evaluation score of the sklearn MLP classifier
+    normalize : bool, optional, default: True
+        Whether the data should be z-standardized relative to source_samples
+    seed : int, optional, default: 123
+        RNG seed for the MLP and the k-fold CV
+    hidden_units_per_dim : int, optional, default: 10
+        Number of hidden units in the MLP per input dimension.
+        Example: 5-dimensional source samples and hidden_units_per_dim=10 -> 50 hidden units per layer
+
+    Returns
+    -------
+    c2st_score : float
+        The resulting C2ST score
+    """
+
+    x = np.array(source_samples)
+    y = np.array(target_samples)
+
+    num_dims = x.shape[1]
+    if num_dims != y.shape[1]:
+        raise ShapeError(f"source_samples and target_samples can have a different number of observations (1st dim) "
+                         f"but must have the same dimensionality (2nd dim). "
+                         f"Found: source_samples {source_samples.shape[1]}, target_samples {target_samples.shape[1]}")
+
+    if normalize:
+        x_mean = np.mean(x, axis=0)
+        x_std = np.std(x, axis=0)
+        x = (x - x_mean) / x_std
+        y = (y - x_mean) / x_std
+
+    clf = MLPClassifier(
+        activation="relu",
+        hidden_layer_sizes=(hidden_units_per_dim * num_dims, hidden_units_per_dim * num_dims),
+        max_iter=10000,
+        solver="adam",
+        random_state=seed,
+    )
+
+    data = np.concatenate((x, y))
+    target = np.concatenate(
+        (
+            np.zeros((x.shape[0],)),
+            np.ones((y.shape[0],)),
+        )
+    )
+
+    shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
+    scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring)
+
+    c2st_score = np.asarray(np.mean(scores)).astype(np.float32)
+    return c2st_score
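
For context beyond the diff: the classifier two-sample test trains a classifier to distinguish the two sample sets, so a score near 0.5 (chance level) indicates indistinguishable distributions, while a score near 1.0 indicates clearly separable ones. A minimal usage sketch (not part of the commit), assuming the new function is imported from bayesflow.computational_utilities as the tests below do:

import numpy as np
from bayesflow.computational_utilities import c2st

rng = np.random.default_rng(2023)

# Two sample sets from the same 3D standard normal -> the classifier
# cannot separate them, so the score should hover around 0.5
samples_a = rng.normal(size=(500, 3))
samples_b = rng.normal(size=(500, 3))
print(c2st(samples_a, samples_b))

# Shifted target distribution -> easily separable, score close to 1.0
samples_shifted = rng.normal(loc=3.0, size=(500, 3))
print(c2st(samples_a, samples_shifted))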

tests/test_computational_utilities.py

Lines changed: 43 additions & 1 deletion
@@ -3,8 +3,9 @@
 import pytest
 import numpy as np
 from bayesflow import computational_utilities
-from bayesflow.exceptions import ArgumentError
+from bayesflow.exceptions import ArgumentError, ShapeError
 from bayesflow.trainers import Trainer
+import tensorflow as tf
 
 
 @pytest.mark.parametrize("x_true, x_pred, output",
@@ -93,3 +94,44 @@ def test_aggregated_error(x_true, x_pred, inner_error_fun, outer_aggregation_fun
         outer_aggregation_fun=outer_aggregation_fun
     )
     assert aggregated_error_result == pytest.approx(output)
+
+
+def test_c2st_shape_error():
+    source_samples = np.random.random(size=(5, 2))
+    target_samples = np.random.random(size=(5, 3))
+    with pytest.raises(ShapeError):
+        computational_utilities.c2st(source_samples, target_samples)
+
+
+@pytest.mark.parametrize(
+    "source_samples, target_samples",
+    [
+        (np.random.random((5, 2)), np.random.random((5, 2))),
+        (np.random.random((10, 2)), np.random.random((5, 2))),
+        (tf.constant(np.random.random((5, 2))), tf.constant(np.random.random((5, 2))))
+    ]
+)
+def test_c2st(source_samples, target_samples):
+    c2st_score = computational_utilities.c2st(source_samples, target_samples)
+    assert 0.0 <= c2st_score <= 1.0
+
+
+@pytest.mark.parametrize(
+    "n_folds, scoring, normalize, seed, hidden_units_per_dim",
+    [
+        (3, "accuracy", False, 42, 5),
+        (7, "f1", True, 12, 10)
+    ]
+)
+def test_c2st_params(n_folds, scoring, normalize, seed, hidden_units_per_dim):
+    source_samples = np.random.random((5, 2))
+    target_samples = np.random.random((10, 2))
+    _ = computational_utilities.c2st(
+        source_samples=source_samples,
+        target_samples=target_samples,
+        n_folds=n_folds,
+        scoring=scoring,
+        normalize=normalize,
+        seed=seed,
+        hidden_units_per_dim=hidden_units_per_dim
+    )
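
To try only the new tests locally, one option is to invoke pytest programmatically with a keyword filter; a small sketch, assuming pytest is installed and the command is run from the repository root:

import pytest

# Select only the C2ST tests added in this commit by keyword
pytest.main(["tests/test_computational_utilities.py", "-k", "c2st"])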
