Skip to content

Commit cddbce8

Browse files
rename diagnostics plots (#265)
* rename diagnostics plots * Slight name change --------- Co-authored-by: stefanradev93 <[email protected]>
1 parent 6fb8b83 commit cddbce8

18 files changed

+475
-348
lines changed

bayesflow/diagnostics/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from .plot_losses import plot_losses
2-
from .plot_recovery import plot_recovery
3-
from .plot_sbc_ecdf import plot_sbc_ecdf
4-
from .plot_sbc_histograms import plot_sbc_histograms
5-
from .plot_samples_2d import plot_samples_2d
6-
from .plot_z_score_contraction import plot_z_score_contraction
7-
from .plot_prior_2d import plot_prior_2d
8-
from .plot_posterior_2d import plot_posterior_2d
9-
from .plot_calibration_curves import plot_calibration_curves
1+
from .plots import calibration_ecdf
2+
from .plots import calibration_histogram
3+
from .plots import loss
4+
from .plots import mc_calibration
5+
from .plots import mc_confusion_matrix
6+
from .plots import mmd_hypothesis_test
7+
from .plots import pairs_posterior
8+
from .plots import pairs_prior
9+
from .plots import pairs_samples
10+
from .plots import recovery
11+
from .plots import z_score_contraction
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .calibration_ecdf import calibration_ecdf
2+
from .calibration_histogram import calibration_histogram
3+
from .loss import loss
4+
from .mc_calibration import mc_calibration
5+
from .mc_confusion_matrix import mc_confusion_matrix
6+
from .mmd_hypothesis_test import mmd_hypothesis_test
7+
from .pairs_posterior import pairs_posterior
8+
from .pairs_prior import pairs_prior
9+
from .pairs_samples import pairs_samples
10+
from .recovery import recovery
11+
from .z_score_contraction import z_score_contraction

bayesflow/diagnostics/plot_sbc_ecdf.py renamed to bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import matplotlib.pyplot as plt
33

44
from typing import Sequence
5-
from ..utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
6-
from ..utils.ecdf import simultaneous_ecdf_bands
7-
from ..utils.ecdf.ranks import fractional_ranks, distance_ranks
5+
from ...utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
6+
from ...utils.ecdf import simultaneous_ecdf_bands
7+
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
88

99

10-
def plot_sbc_ecdf(
10+
def calibration_ecdf(
1111
post_samples: dict[str, np.ndarray] | np.ndarray,
1212
prior_samples: dict[str, np.ndarray] | np.ndarray,
1313
filter_keys: Sequence[str] = None,
@@ -61,12 +61,15 @@ def plot_sbc_ecdf(
6161
stacked : bool, optional, default: False
6262
If `True`, all ECDFs will be plotted on the same plot.
6363
If `False`, each ECDF will have its own subplot,
64-
similar to the behavior of `plot_sbc_histograms`.
64+
similar to the behavior of `calibration_histogram`.
6565
rank_type : str, optional, default: 'fractional'
66-
If `fractional` (default), the ranks are computed as the fraction of posterior samples that are smaller than
67-
the prior. If `distance`, the ranks are computed as the fraction of posterior samples that are closer to
68-
a reference points (default here is the origin). You can pass a reference array in the same shape as the
69-
`prior_samples` array by setting `references` in the ``ranks_kwargs``. This is motivated by [2].
66+
If `fractional` (default), the ranks are computed as the fraction
67+
of posterior samples that are smaller than the prior.
68+
If `distance`, the ranks are computed as the fraction of posterior
69+
samples that are closer to a reference points (default here is the origin).
70+
You can pass a reference array in the same shape as the
71+
`prior_samples` array by setting `references` in the ``ranks_kwargs``.
72+
This is motivated by [2].
7073
variable_names : list or None, optional, default: None
7174
The parameter names for nice plot titles.
7275
Inferred if None. Only relevant if `stacked=False`.

bayesflow/diagnostics/plot_sbc_histograms.py renamed to bayesflow/diagnostics/plots/calibration_histogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from bayesflow.utils import preprocess, add_titles_and_labels, prettify_subplots
1010

1111

12-
def plot_sbc_histograms(
12+
def calibration_histogram(
1313
post_samples: dict[str, np.ndarray] | np.ndarray,
1414
prior_samples: dict[str, np.ndarray] | np.ndarray,
1515
filter_keys: Sequence[str] = None,

bayesflow/diagnostics/plot_losses.py renamed to bayesflow/diagnostics/plots/loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import matplotlib.pyplot as plt
77

88

9-
from ..utils.plot_utils import make_figure, add_titles_and_labels
9+
from ...utils.plot_utils import make_figure, add_titles_and_labels
1010

1111

12-
def plot_losses(
12+
def loss(
1313
train_losses: pd.DataFrame | np.ndarray,
1414
val_losses: pd.DataFrame | np.ndarray = None,
1515
moving_average: bool = False,

bayesflow/diagnostics/plot_calibration_curves.py renamed to bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from bayesflow.utils import expected_calibration_error, preprocess, add_titles_and_labels, add_metric, prettify_subplots
88

99

10-
def plot_calibration_curves(
11-
post_model_samples: dict[str, np.ndarray] | np.ndarray,
12-
true_model_samples: dict[str, np.ndarray] | np.ndarray,
10+
def mc_calibration(
11+
pred_models: dict[str, np.ndarray] | np.ndarray,
12+
true_models: dict[str, np.ndarray] | np.ndarray,
1313
names: Sequence[str] = None,
1414
num_bins: int = 10,
1515
label_fontsize: int = 16,
@@ -28,11 +28,11 @@ def plot_calibration_curves(
2828
2929
Parameters
3030
----------
31-
true_model_samples : np.ndarray of shape (num_data_sets, num_models)
31+
true_models : np.ndarray of shape (num_data_sets, num_models)
3232
The one-hot-encoded true model indices per data set.
33-
post_model_samples : np.ndarray of shape (num_data_sets, num_models)
33+
pred_models : np.ndarray of shape (num_data_sets, num_models)
3434
The predicted posterior model probabilities (PMPs) per data set.
35-
names : list or None, optional, default: None
35+
names : list or None, optional, default: None
3636
The model names for nice plot titles. Inferred if None.
3737
num_bins : int, optional, default: 10
3838
The number of bins to use for the calibration curves (and marginal histograms).
@@ -60,7 +60,7 @@ def plot_calibration_curves(
6060
fig : plt.Figure - the figure instance for optional saving
6161
"""
6262

63-
plot_data = preprocess(post_model_samples, true_model_samples, names, num_col, num_row, figsize, context="M")
63+
plot_data = preprocess(pred_models, true_models, names, num_col, num_row, figsize, context="M")
6464

6565
# Compute calibration
6666
cal_errors, true_probs, pred_probs = expected_calibration_error(

bayesflow/diagnostics/plot_confusion_matrix.py renamed to bayesflow/diagnostics/plots/mc_confusion_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from bayesflow.utils.plot_utils import make_figure
1313

1414

15-
def plot_confusion_matrix(
15+
def mc_confusion_matrix(
1616
true_models: dict[str, np.ndarray] | np.ndarray,
1717
pred_models: dict[str, np.ndarray] | np.ndarray,
1818
model_names: Sequence[str] = None,

bayesflow/diagnostics/plot_mmd_hypothesis_test.py renamed to bayesflow/diagnostics/plots/mmd_hypothesis_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from keras import ops
66

77

8-
def plot_mmd_hypothesis_test(
8+
def mmd_hypothesis_test(
99
mmd_null: np.ndarray,
1010
mmd_observed: float = None,
1111
alpha_level: float = 0.05,

bayesflow/diagnostics/plot_posterior_2d.py renamed to bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
from matplotlib.lines import Line2D
88

9-
from .plot_samples_2d import plot_samples_2d
9+
from .pairs_samples import pairs_samples
1010

1111

12-
def plot_posterior_2d(
12+
def pairs_posterior(
1313
post_samples: np.ndarray,
1414
prior_samples: np.ndarray = None,
1515
prior=None,
@@ -70,7 +70,7 @@ def plot_posterior_2d(
7070

7171
# Plot posterior first
7272
context = ""
73-
g = plot_samples_2d(
73+
g = pairs_samples(
7474
post_samples, context=context, variable_names=variable_names, render=False, height=height, **kwargs
7575
)
7676

bayesflow/diagnostics/plot_prior_2d.py renamed to bayesflow/diagnostics/plots/pairs_prior.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import seaborn as sns
44

55
from bayesflow.simulators import Simulator
6-
from .plot_samples_2d import plot_samples_2d
6+
from .pairs_samples import pairs_samples
77

88

9-
def plot_prior_2d(
9+
def pairs_prior(
1010
simulator: Simulator,
1111
variable_names: Sequence[str] | str = None,
1212
num_samples: int = 2000,
@@ -43,6 +43,6 @@ def plot_prior_2d(
4343
if isinstance(samples, dict):
4444
samples = samples["theta"]
4545

46-
return plot_samples_2d(
46+
return pairs_samples(
4747
samples, context="Prior", height=height, color=color, param_names=variable_names, render=True, **kwargs
4848
)

0 commit comments

Comments
 (0)