
Commit 1a2e35d

Add c2st
1 parent 635071e commit 1a2e35d

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
from typing import Sequence, Mapping, Any

import numpy as np

import keras

from bayesflow.utils.exceptions import ShapeError
from bayesflow.networks import MLP


def classifier_two_sample_test(
    estimates: np.ndarray,
    targets: np.ndarray,
    metric: str = "accuracy",
    patience: int = 10,
    max_epochs: int = 1000,
    batch_size: int = 64,
    return_metric_only: bool = True,
    validation_split: float = 0.5,
    standardize: bool = True,
    mlp_widths: Sequence = (256, 256),
    **kwargs,
) -> float | Mapping[str, Any]:
    """
    C2ST metric [1] between samples from two distributions computed using a neural classifier.
    Can be computationally expensive, since it involves training a neural classifier.

    Note: works best for large numbers of samples and averaged across different posteriors.

    Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py

    [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.

    Parameters
    ----------
    estimates : np.ndarray
        Array of shape (num_samples_est, num_variables) containing samples representing estimated quantities
        (e.g., approximate posterior samples).
    targets : np.ndarray
        Array of shape (num_samples_tar, num_variables) containing target samples
        (e.g., samples from a reference posterior).
    metric : str, optional
        Metric to evaluate the classifier performance. Default is "accuracy".
    patience : int, optional
        Number of epochs with no improvement after which training will be stopped. Default is 10.
    max_epochs : int, optional
        Maximum number of epochs to train the classifier. Default is 1000.
    batch_size : int, optional
        Number of samples per batch during training. Default is 64.
    return_metric_only : bool, optional
        If True, only the final validation metric is returned. Otherwise, a dictionary with the score, classifier,
        and full training history is returned. Default is True.
    validation_split : float, optional
        Fraction of the training data to be used as validation data. Default is 0.5.
    standardize : bool, optional
        If True, both estimates and targets will be standardized using the mean and standard deviation of estimates.
        Default is True.
    mlp_widths : Sequence[int], optional
        Sequence specifying the number of units in each hidden layer of the MLP classifier. Default is (256, 256).
    **kwargs
        Additional keyword arguments. Recognized keyword:
        mlp_kwargs : dict
            Dictionary of additional parameters to pass to the MLP constructor.

    Returns
    -------
    results : float or dict
        If return_metric_only is True, returns the final validation metric (e.g., accuracy) as a float.
        Otherwise, returns a dictionary with keys "score", "classifier", and "history", where "score"
        is the final validation metric, "classifier" is the trained Keras model, and "history" contains the
        full training history.
    """

    # Convert tensors to numpy, if passed
    estimates = keras.ops.convert_to_numpy(estimates)
    targets = keras.ops.convert_to_numpy(targets)

    # Raise an error if the dimensionality of targets does not match that of estimates
    num_dims = estimates.shape[1]
    if num_dims != targets.shape[1]:
        raise ShapeError(
            "estimates and targets can have a different number of samples (1st dim) "
            "but must have the same dimensionality (2nd dim). "
            f"Found: estimates dimensionality {num_dims}, targets dimensionality {targets.shape[1]}."
        )

    # Standardize both estimates and targets relative to estimates mean and std
    if standardize:
        estimates_mean = np.mean(estimates, axis=0)
        estimates_std = np.std(estimates, axis=0)
        estimates = (estimates - estimates_mean) / estimates_std
        targets = (targets - estimates_mean) / estimates_std

    # Create data for the binary classification task: label 0 for estimates, label 1 for targets
    data = np.r_[estimates, targets]
    labels = np.r_[np.zeros((estimates.shape[0],)), np.ones((targets.shape[0],))]

    # Create and train classifier with early stopping
    classifier = keras.Sequential(
        [MLP(widths=mlp_widths, **kwargs.get("mlp_kwargs", {})), keras.layers.Dense(1, activation="sigmoid")]
    )

    classifier.compile(optimizer="adam", loss="binary_crossentropy", metrics=[metric])

    early_stopping = keras.callbacks.EarlyStopping(
        monitor=f"val_{metric}", patience=patience, restore_best_weights=True
    )

    history = classifier.fit(
        x=data,
        y=labels,
        epochs=max_epochs,
        batch_size=batch_size,
        verbose=0,
        callbacks=[early_stopping],
        validation_split=validation_split,
    )

    if return_metric_only:
        return history.history[f"val_{metric}"][-1]
    return {"score": history.history[f"val_{metric}"][-1], "classifier": classifier, "history": history.history}
