2222import tensorflow as tf
2323from scipy import stats
2424from sklearn .calibration import calibration_curve
25+ from sklearn .model_selection import KFold , cross_val_score
2526from sklearn .neural_network import MLPClassifier
26- from sklearn .model_selection import cross_val_score , KFold
2727
2828from bayesflow .default_settings import MMD_BANDWIDTH_LIST
2929from bayesflow .exceptions import ShapeError
@@ -521,30 +521,41 @@ def aggregated_rmse(x_true, x_pred):
521521 )
522522
523523
524- def c2st (source_samples , target_samples , n_folds = 5 , scoring = "accuracy" , normalize = True , seed = 123 ,
525- hidden_units_per_dim = 10 ):
526- """C2ST metric [1] using an sklearn MLP classifier.
524+ def c2st (
525+ source_samples ,
526+ target_samples ,
527+ n_folds = 5 ,
528+ scoring = "accuracy" ,
529+ normalize = True ,
530+ seed = 123 ,
531+ hidden_units_per_dim = 16 ,
532+ aggregate_output = True ,
533+ ):
534+ """C2ST metric [1] using an sklearn neural network classifier (i.e., MLP).
527535 Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py
528536
529537 [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545.
530538
531539 Parameters
532540 ----------
533- source_samples : np.ndarray or tf.Tensor
541+ source_samples : np.ndarray or tf.Tensor
534542 Source samples (e.g., approximate posterior samples)
535- target_samples : np.ndarray or tf.Tensor
543+ target_samples : np.ndarray or tf.Tensor
536544 Target samples (e.g., samples from a reference posterior)
537- n_folds : int, optional, default: 5
545+ n_folds : int, optional, default: 5
538546 Number of folds in k-fold cross-validation for the classifier evaluation
539- scoring : str, optional, default: "accuracy"
547+ scoring : str, optional, default: "accuracy"
540548 Evaluation score of the sklearn MLP classifier
541- normalize : bool, optional, default: True
549+ normalize : bool, optional, default: True
542550 Whether the data shall be z-standardized relative to source_samples
543- seed : int, optional, default: 123
551+ seed : int, optional, default: 123
544552 RNG seed for the MLP and k-fold CV
545- hidden_units_per_dim : int, optional, default: 10
553+ hidden_units_per_dim : int, optional, default: 16
546554 Number of hidden units in the MLP, relative to the input dimensions.
547- Example: source samples are 5D, hidden_units_per_dim=10 -> 50 hidden units per layer
555+ Example: source samples are 5D, hidden_units_per_dim=16 -> 80 hidden units per layer
556+ aggregate_output : bool, optional, default: True
557+ Whether to return a single value aggregated over all cross-validation runs
558+ or all values from all runs. If left at default, the empirical mean will be returned
548559
549560 Returns
550561 -------
@@ -558,9 +569,11 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz
558569
559570 num_dims = x .shape [1 ]
560571 if not num_dims == y .shape [1 ]:
561- raise ShapeError (f"source_samples and target_samples can have different number of observations (1st dim)"
562- f"but must have the same dimensionality (2nd dim)"
563- f"found: source_samples { source_samples .shape [1 ]} , target_samples { target_samples .shape [1 ]} " )
572+ raise ShapeError (
573+ f"source_samples and target_samples can have different number of observations (1st dim)"
574+ f"but must have the same dimensionality (2nd dim)"
575+ f"found: source_samples { source_samples .shape [1 ]} , target_samples { target_samples .shape [1 ]} "
576+ )
564577
565578 if normalize :
566579 x_mean = np .mean (x , axis = 0 )
@@ -587,5 +600,6 @@ def c2st(source_samples, target_samples, n_folds=5, scoring="accuracy", normaliz
587600 shuffle = KFold (n_splits = n_folds , shuffle = True , random_state = seed )
588601 scores = cross_val_score (clf , data , target , cv = shuffle , scoring = scoring )
589602
590- c2st_score = np .asarray (np .mean (scores )).astype (np .float32 )
603+ if aggregate_output :
604+ c2st_score = np .asarray (np .mean (scores )).astype (np .float32 )
591605 return c2st_score
0 commit comments