Skip to content

Commit ae7f8c4

Browse files
committed
Add option to aggregate outputs of c2st
1 parent bb0cf65 commit ae7f8c4

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

bayesflow/computational_utilities.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import tensorflow as tf
2323
from scipy import stats
2424
from sklearn.calibration import calibration_curve
25+
from sklearn.model_selection import KFold, cross_val_score
2526
from sklearn.neural_network import MLPClassifier
26-
from sklearn.model_selection import cross_val_score, KFold
2727

2828
from bayesflow.default_settings import MMD_BANDWIDTH_LIST
2929
from bayesflow.exceptions import ShapeError
@@ -521,30 +521,41 @@ def aggregated_rmse(x_true, x_pred):
521521
)
522522

523523

524-
def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normalize=True, seed=123,
525-
hidden_units_per_dim=10):
526-
"""C2ST metric [1] using an sklearn MLP classifier.
524+
def c2st(
525+
source_samples,
526+
target_samples,
527+
n_folds=5,
528+
scoring="accuracy",
529+
normalize=True,
530+
seed=123,
531+
hidden_units_per_dim=16,
532+
aggregate_output=True,
533+
):
534+
"""C2ST metric [1] using an sklearn neural network classifier (i.e., MLP).
527535
Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py
528536
529537
[1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.
530538
531539
Parameters
532540
----------
533-
source_samples : np.ndarray or tf.Tensor
541+
source_samples : np.ndarray or tf.Tensor
534542
Source samples (e.g., approximate posterior samples)
535-
target_samples : np.ndarray or tf.Tensor
543+
target_samples : np.ndarray or tf.Tensor
536544
Target samples (e.g., samples from a reference posterior)
537-
n_folds : int, optional, default: 5
545+
n_folds : int, optional, default: 5
538546
Number of folds in k-fold cross-validation for the classifier evaluation
539-
scoring : str, optional, default: "accuracy"
547+
scoring : str, optional, default: "accuracy"
540548
Evaluation score of the sklearn MLP classifier
541-
normalize : bool, optional, default: True
549+
normalize : bool, optional, default: True
542550
Whether the data shall be z-standardized relative to source_samples
543-
seed : int, optional, default: 123
551+
seed : int, optional, default: 123
544552
RNG seed for the MLP and k-fold CV
545-
hidden_units_per_dim : int, optional, default: 10
553+
hidden_units_per_dim : int, optional, default: 16
546554
Number of hidden units in the MLP, relative to the input dimensions.
547-
Example: source samples are 5D, hidden_units_per_dim=10 -> 50 hidden units per layer
555+
Example: source samples are 5D, hidden_units_per_dim=16 -> 80 hidden units per layer
556+
aggregate_output : bool, optional, default: True
557+
Whether to return a single value aggregated over all cross-validation folds
558+
or the individual scores of each fold. If True (the default), the empirical mean over folds is returned
548559
549560
Returns
550561
-------
@@ -558,9 +569,11 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz
558569

559570
num_dims = x.shape[1]
560571
if not num_dims == y.shape[1]:
561-
raise ShapeError(f"source_samples and target_samples can have different number of observations (1st dim)"
562-
f"but must have the same dimensionality (2nd dim)"
563-
f"found: source_samples {source_samples.shape[1]}, target_samples {target_samples.shape[1]}")
572+
raise ShapeError(
573+
f"source_samples and target_samples can have different number of observations (1st dim)"
574+
f"but must have the same dimensionality (2nd dim)"
575+
f"found: source_samples {source_samples.shape[1]}, target_samples {target_samples.shape[1]}"
576+
)
564577

565578
if normalize:
566579
x_mean = np.mean(x, axis=0)
@@ -587,5 +600,6 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz
587600
shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
588601
scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring)
589602

590-
c2st_score = np.asarray(np.mean(scores)).astype(np.float32)
603+
if aggregate_output:
604+
c2st_score = np.asarray(np.mean(scores)).astype(np.float32)
591605
return c2st_score

0 commit comments

Comments
 (0)