| 1 | +from typing import Sequence, Mapping, Any |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +import keras |
| 6 | + |
| 7 | +from bayesflow.utils.exceptions import ShapeError |
| 8 | +from bayesflow.networks import MLP |
| 9 | + |
| 10 | + |
| 11 | +def classifier_two_sample_test( |
| 12 | + estimates: np.ndarray, |
| 13 | + targets: np.ndarray, |
| 14 | + metric: str = "accuracy", |
| 15 | + patience: int = 10, |
| 16 | + max_epochs: int = 1000, |
| 17 | + batch_size: int = 64, |
| 18 | + return_metric_only: bool = True, |
| 19 | + validation_split: float = 0.5, |
| 20 | + standardize: bool = True, |
| 21 | + mlp_widths: Sequence[int] = (256, 256), |
| 22 | + **kwargs, |
| 23 | +) -> float | Mapping[str, Any]: |
| 24 | + """ |
| 25 | + C2ST metric [1] between samples from two distributions computed using a neural classifier. |
| 26 | + Can be computationally expensive, since it involves training a neural classifier. |
| 27 | + |
| 28 | + Note: works best with large numbers of samples and when averaged across different posteriors. |
| 29 | + |
| 30 | + Code adapted from https://github.com/sbi-benchmark/sbibm/blob/main/sbibm/metrics/c2st.py |
| 31 | + |
| 32 | + [1] Lopez-Paz, D., & Oquab, M. (2016). Revisiting classifier two-sample tests. arXiv:1610.06545. |
| 33 | + |
| 34 | + Parameters |
| 35 | + ---------- |
| 36 | + estimates : np.ndarray |
| 37 | + Array of shape (num_samples_est, num_variables) containing samples representing estimated quantities |
| 38 | + (e.g., approximate posterior samples). |
| 39 | + targets : np.ndarray |
| 40 | + Array of shape (num_samples_tar, num_variables) containing target samples |
| 41 | + (e.g., samples from a reference posterior). |
| 42 | + metric : str, optional |
| 43 | + Metric to evaluate the classifier performance. Default is "accuracy". |
| 44 | + patience : int, optional |
| 45 | + Number of epochs with no improvement after which training will be stopped. Default is 10. |
| 46 | + max_epochs : int, optional |
| 47 | + Maximum number of epochs to train the classifier. Default is 1000. |
| 48 | + batch_size : int, optional |
| 49 | + Number of samples per batch during training. Default is 64. |
| 50 | + return_metric_only : bool, optional |
| 51 | + If True, only the final validation metric is returned. Otherwise, a dictionary with the score, classifier, and |
| 52 | + full training history is returned. Default is True. |
| 53 | + validation_split : float, optional |
| 54 | + Fraction of the training data to be used as validation data. Default is 0.5. |
| 55 | + standardize : bool, optional |
| 56 | + If True, both estimates and targets will be standardized using the mean and standard deviation of estimates. |
| 57 | + Default is True. |
| 58 | + mlp_widths : Sequence[int], optional |
| 59 | + Sequence specifying the number of units in each hidden layer of the MLP classifier. Default is (256, 256). |
| 60 | + **kwargs |
| 61 | + Additional keyword arguments. Recognized keyword: |
| 62 | + mlp_kwargs : dict |
| 63 | + Dictionary of additional parameters to pass to the MLP constructor. |
| 64 | + |
| 65 | + Returns |
| 66 | + ------- |
| 67 | + results : float or dict |
| 68 | + If return_metric_only is True, returns the final validation metric (e.g., accuracy) as a float. |
| 69 | + Otherwise, returns a dictionary with keys "score", "classifier", and "history", where "score" |
| 70 | + is the final validation metric, "classifier" is the trained Keras model, and "history" contains the |
| 71 | + full training history. |
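| | + |
| | + Examples |
| | + -------- |
| | + A minimal usage sketch with two synthetic, identically distributed sample sets (values are purely |
| | + illustrative); a score close to 0.5 indicates that the classifier cannot distinguish the two: |
| | + |
| | + >>> rng = np.random.default_rng(seed=2024) |
| | + >>> estimates = rng.normal(size=(1000, 2)) |
| | + >>> targets = rng.normal(size=(1000, 2)) |
| | + >>> c2st_score = classifier_two_sample_test(estimates, targets) |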
| 72 | + """ |
| 73 | + |
| 74 | + # Convert tensors to numpy, if passed |
| 75 | + estimates = keras.ops.convert_to_numpy(estimates) |
| 76 | + targets = keras.ops.convert_to_numpy(targets) |
| 77 | + |
| 78 | + # Raise an error if the dimensionality of targets does not match that of estimates |
| 79 | + num_dims = estimates.shape[1] |
| 80 | + if num_dims != targets.shape[1]: |
| 81 | + raise ShapeError( |
| 82 | + "estimates and targets may have a different number of samples (1st dim), " |
| 83 | + "but must have the same dimensionality (2nd dim). " |
| 84 | + f"Found: estimates dimensionality {num_dims}, targets dimensionality {targets.shape[1]}." |
| 85 | + ) |
| 86 | + |
| 87 | + # Standardize both estimates and targets relative to estimates mean and std |
| 88 | + if standardize: |
| 89 | + estimates_mean = np.mean(estimates, axis=0) |
| 90 | + estimates_std = np.std(estimates, axis=0) |
| 91 | + estimates = (estimates - estimates_mean) / estimates_std |
| 92 | + targets = (targets - estimates_mean) / estimates_std |
| 93 | + |
| 94 | + # Create data and binary labels (0 = estimates, 1 = targets) for the classification task |
| 95 | + data = np.r_[estimates, targets] |
| 96 | + labels = np.r_[np.zeros((estimates.shape[0],)), np.ones((targets.shape[0],))] |
| 97 | + |
| 98 | + # Create and train classifier with optional stopping |
| 99 | + classifier = keras.Sequential( |
| 100 | + [MLP(widths=mlp_widths, **kwargs.get("mlp_kwargs", {})), keras.layers.Dense(1, activation="sigmoid")] |
| 101 | + ) |
| 102 | + |
| 103 | + classifier.compile(optimizer="adam", loss="binary_crossentropy", metrics=[metric]) |
| 104 | + |
| 105 | + early_stopping = keras.callbacks.EarlyStopping( |
| 106 | + monitor=f"val_{metric}", patience=patience, restore_best_weights=True |
| 107 | + ) |
| 108 | + |
| 109 | + history = classifier.fit( |
| 110 | + x=data, |
| 111 | + y=labels, |
| 112 | + epochs=max_epochs, |
| 113 | + batch_size=batch_size, |
| 114 | + verbose=0, |
| 115 | + callbacks=[early_stopping], |
| 116 | + validation_split=validation_split, |
| 117 | + ) |
| 118 | + |
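| | + # Return either the final (last recorded) validation metric alone, or the full results dictionary |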
| 119 | + if return_metric_only: |
| 120 | + return history.history[f"val_{metric}"][-1] |
| 121 | + return {"score": history.history[f"val_{metric}"][-1], "classifier": classifier, "history": history.history} |