Skip to content

Commit 717c941

Browse files
committed
Update model misspecification toolkit, add tutorial notebook
1 parent de37bc1 commit 717c941

File tree

6 files changed

+1108
-281
lines changed

6 files changed

+1108
-281
lines changed

bayesflow/diagnostics.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,92 @@ def plot_confusion_matrix(
11751175
)
11761176
if title:
11771177
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
1178+
1179+
1180+
def plot_mmd_hypothesis_test(mmd_null,
1181+
mmd_observed=None,
1182+
alpha_level=0.05,
1183+
null_color=(0.16407, 0.020171, 0.577478),
1184+
observed_color="red",
1185+
alpha_color="orange",
1186+
truncate_vlines_at_kde=False,
1187+
xmin=None,
1188+
xmax=None,
1189+
bw_factor=1.5):
1190+
"""
1191+
1192+
Parameters
1193+
----------
1194+
mmd_null: np.ndarray
1195+
samples from the MMD sampling distribution under the null hypothesis "the model is well-specified"
1196+
mmd_observed: float
1197+
observed MMD value
1198+
alpha_level: float
1199+
rejection probability (type I error)
1200+
null_color: color
1201+
color for the H0 sampling distribution
1202+
observed_color: color
1203+
color for the observed MMD
1204+
alpha_color: color
1205+
color for the rejection area
1206+
truncate_vlines_at_kde: bool
1207+
true: cut off the vlines at the kde
1208+
false: continue kde lines across the plot
1209+
xmin: float
1210+
lower x axis limit
1211+
xmax: float
1212+
upper x axis limit
1213+
bw_factor: float, default: 1.5
1214+
bandwidth (aka. smoothing parameter) of the kernel density estimate
1215+
1216+
Returns
1217+
-------
1218+
f : plt.Figure - the figure instance for optional saving
1219+
1220+
"""
1221+
1222+
def draw_vline_to_kde(x, kde_object, color, label=None, **kwargs):
1223+
kde_x, kde_y = kde_object.lines[0].get_data()
1224+
idx = np.argmin(np.abs(kde_x - x))
1225+
plt.vlines(x=x, ymin=0, ymax=kde_y[idx], color=color, linewidth=3, label=label, **kwargs)
1226+
1227+
def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):
1228+
kde_x, kde_y = kde_object.lines[0].get_data()
1229+
if x_end is not None:
1230+
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end),
1231+
interpolate=True, **kwargs)
1232+
else:
1233+
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start),
1234+
interpolate=True, **kwargs)
1235+
1236+
f = plt.figure(figsize=(8, 4))
1237+
1238+
kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor)
1239+
sns.kdeplot(mmd_null, fill=True, alpha=.12, color=null_color, bw_adjust=bw_factor)
1240+
1241+
if truncate_vlines_at_kde:
1242+
draw_vline_to_kde(x=mmd_observed, kde_object=kde, color=observed_color, label=r"Observed data")
1243+
else:
1244+
plt.vlines(x=mmd_observed, ymin=0, ymax=plt.gca().get_ylim()[1], color=observed_color, linewidth=3,
1245+
label=r"Observed data")
1246+
1247+
mmd_critical = np.quantile(mmd_null, 1 - alpha_level)
1248+
fill_area_under_kde(kde, mmd_critical, color=alpha_color, alpha=0.5, label=fr"{int(alpha_level*100)}% rejection area")
1249+
1250+
if truncate_vlines_at_kde:
1251+
draw_vline_to_kde(x=mmd_critical, kde_object=kde, color=alpha_color)
1252+
else:
1253+
plt.vlines(x=mmd_critical, color=alpha_color, linewidth=3, ymin=0, ymax=plt.gca().get_ylim()[1])
1254+
1255+
sns.kdeplot(mmd_null, fill=False, linewidth=3, color=null_color, label=r"$H_0$", bw_adjust=bw_factor)
1256+
1257+
plt.xlabel(r"MMD", fontsize=20)
1258+
plt.ylabel("")
1259+
plt.yticks([])
1260+
plt.xlim(xmin, xmax)
1261+
plt.tick_params(axis='both', which='major', labelsize=16)
1262+
1263+
plt.legend(fontsize=20)
1264+
sns.despine()
1265+
1266+
return f

bayesflow/exceptions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class LossError(Exception):
3838

3939

4040
class ShapeError(Exception):
41-
"""Class for error in expected shappes."""
41+
"""Class for error in expected shapes."""
4242

4343
pass
4444

@@ -61,3 +61,8 @@ class OperationNotSupportedError(Exception):
6161
"""
6262

6363
pass
64+
65+
66+
class ArgumentError(Exception):
67+
"""Class for error that occurs as a result of a function call which is invalid due to the input arguments."""
68+
pass

bayesflow/trainers.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from bayesflow.configuration import *
3939
from bayesflow.default_settings import DEFAULT_KEYS, OPTIMIZER_DEFAULTS
4040
from bayesflow.diagnostics import plot_latent_space_2d, plot_sbc_histograms
41-
from bayesflow.exceptions import SimulationError
41+
from bayesflow.exceptions import SimulationError, ArgumentError
4242
from bayesflow.helper_classes import (
4343
EarlyStopper,
4444
LossHistory,
@@ -49,6 +49,7 @@
4949
)
5050
from bayesflow.helper_functions import backprop_step, extract_current_lr, format_loss_string, loss_to_string
5151
from bayesflow.simulation import GenerativeModel, MultiGenerativeModel
52+
from bayesflow.computational_utilities import maximum_mean_discrepancy
5253

5354

5455
class Trainer:
@@ -1009,6 +1010,72 @@ def train_rounds(
10091010
self.optimizer = None
10101011
return self.loss_history.get_plottable()
10111012

1013+
def mmd_hypothesis_test(self,
1014+
observed_data,
1015+
reference_data=None,
1016+
num_reference_simulations=1000,
1017+
num_null_samples=100,
1018+
bootstrap=False):
1019+
"""
1020+
1021+
Parameters
1022+
----------
1023+
observed_data: np.ndarray
1024+
Observed data, shape (num_observed, ...)
1025+
reference_data: np.ndarray
1026+
Reference data representing samples from the "well-specified model", shape (num_reference, ...)
1027+
num_reference_simulations: int, default: 1000
1028+
Number of reference simulations (M) simulated from the trainer's generative model
1029+
if no `reference_data` are provided.
1030+
num_null_samples: int, default: 100
1031+
Number of draws from the MMD sampling distribution under the null hypothesis "the trainer's generative
1032+
model is well-specified"
1033+
bootstrap: bool, default: False
1034+
If true, the reference data (see above) are bootstrapped for each sample from the MMD sampling distribution.
1035+
If false, a new data set is simulated for computing each draw from the MMD sampling distribution.
1036+
1037+
Returns
1038+
-------
1039+
mmd_null_samples: np.ndarray
1040+
samples from the H0 sampling distribution ("well-specified model")
1041+
mmd_observed: float
1042+
summary MMD estimate for the observed data sets
1043+
"""
1044+
1045+
if reference_data is None:
1046+
if self.generative_model is None:
1047+
raise ArgumentError("If you do not provide reference data, your trainer must have a generative model!")
1048+
1049+
reference_data = self.configurator(self.generative_model(num_reference_simulations))
1050+
1051+
if type(reference_data) == dict and 'summary_conditions' in reference_data.keys():
1052+
reference_summary = self.amortizer.summary_net(reference_data["summary_conditions"])
1053+
else:
1054+
reference_summary = self.amortizer.summary_net(reference_data)
1055+
1056+
if type(observed_data) == dict and 'summary_conditions' in observed_data.keys():
1057+
observed_summary = self.amortizer.summary_net(observed_data["summary_conditions"])
1058+
else:
1059+
observed_summary = self.amortizer.summary_net(observed_data)
1060+
1061+
num_observed = observed_summary.shape[0]
1062+
num_reference = reference_summary.shape[0]
1063+
1064+
mmd_null_samples = np.empty(num_null_samples, dtype=np.float32)
1065+
for i in tqdm(range(num_null_samples)):
1066+
if bootstrap:
1067+
bootstrap_idx = np.random.randint(0, num_reference, size=num_observed)
1068+
simulated_summary = tf.gather(reference_summary, bootstrap_idx, axis=0)
1069+
else:
1070+
simulated_data = self.configurator(self.generative_model(num_observed))
1071+
simulated_summary = self.amortizer.summary_net(simulated_data["summary_conditions"])
1072+
1073+
mmd_null_samples[i] = np.sqrt(maximum_mean_discrepancy(reference_summary, simulated_summary).numpy())
1074+
1075+
mmd_observed = np.sqrt(maximum_mean_discrepancy(reference_summary, observed_summary).numpy())
1076+
1077+
return mmd_null_samples, mmd_observed
1078+
10121079
def _config_validation(self, validation_sims, **kwargs):
10131080
"""Helper method to prepare validation set based on user input."""
10141081

0 commit comments

Comments
 (0)