
Commit effa62f

Slight doc changes
1 parent fd092ab commit effa62f

3 files changed: 18 additions & 14 deletions

bayesflow/amortizers.py

Lines changed: 6 additions & 7 deletions
@@ -20,7 +20,6 @@
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras import Model
 
 from bayesflow.exceptions import ConfigurationError, SummaryStatsError
 from bayesflow.losses import log_loss, mmd_summary_space, kl_dirichlet
@@ -59,7 +58,7 @@ def log_prob(input_dict, **kwargs):
         pass
 
 
-class AmortizedPosterior(Model, AmortizedTarget):
+class AmortizedPosterior(tf.keras.Model, AmortizedTarget):
     """A wrapper to connect an inference network for parameter estimation with an optional summary network
     as in the original BayesFlow set-up described in the paper:
@@ -115,7 +114,7 @@ def __init__(self, inference_net, summary_net=None, latent_dist=None, latent_is_
             any `sumamry_conditions`, i.e., `summary_conditions` should be set to None, otherwise these will be ignored.
         """
 
-        Model.__init__(self, **kwargs)
+        tf.keras.Model.__init__(self, **kwargs)
 
         self.inference_net = inference_net
         self.summary_net = summary_net
@@ -409,7 +408,7 @@ def _determine_summary_loss(self, loss_fun):
             raise NotImplementedError("Could not infer summary_loss_fun, argument should be of type (None, callable, or str)!")
 
 
-class AmortizedLikelihood(Model, AmortizedTarget):
+class AmortizedLikelihood(tf.keras.Model, AmortizedTarget):
     """An interface for a surrogate model of a simulator, or an implicit likelihood
     ``p(params | data, context).''
     """
@@ -427,7 +426,7 @@ def __init__(self, surrogate_net, latent_dist=None, **kwargs):
             a multivariate unit Gaussian.
        """
 
-        Model.__init__(self, **kwargs)
+        tf.keras.Model.__init__(self, **kwargs)
 
         self.surrogate_net = surrogate_net
         self.latent_dim = self.surrogate_net.latent_dim
@@ -616,7 +615,7 @@ def _determine_latent_dist(self, latent_dist):
         return latent_dist
 
 
-class AmortizedPosteriorLikelihood(Model, AmortizedTarget):
+class AmortizedPosteriorLikelihood(tf.keras.Model, AmortizedTarget):
     """An interface for jointly learning a surrogate model of the simulator and an approximate
     posterior given a generative model.
     """
@@ -632,7 +631,7 @@ def __init__(self, amortized_posterior, amortized_likelihood, **kwargs):
             The generative neural likelihood approximator.
        """
 
-        Model.__init__(self, **kwargs)
+        tf.keras.Model.__init__(self, **kwargs)
 
         self.amortized_posterior = amortized_posterior
        self.amortized_likelihood = amortized_likelihood
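
The change above drops the `from tensorflow.keras import Model` import and makes the three amortizer classes inherit from the fully qualified `tf.keras.Model`. A minimal sketch of that pattern in isolation (not BayesFlow code; the class and layer below are hypothetical stand-ins):

import tensorflow as tf

# Hypothetical stand-in class illustrating the subclassing pattern used in the diff:
# reference tf.keras.Model by its fully qualified name instead of importing Model.
class MyAmortizer(tf.keras.Model):
    def __init__(self, net, **kwargs):
        tf.keras.Model.__init__(self, **kwargs)  # explicit base-class init, as in the diff
        self.net = net

    def call(self, inputs):
        return self.net(inputs)

amortizer = MyAmortizer(tf.keras.layers.Dense(2))
print(isinstance(amortizer, tf.keras.Model))  # True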

bayesflow/diagnostics.py

Lines changed: 7 additions & 4 deletions
@@ -409,7 +409,7 @@ def plot_sbc_histograms(post_samples, prior_samples, param_names=None, fig_size=
     return f
 
 
-def plot_posterior_2d(posterior_draws, prior=None, prior_draws=None, param_names=None, height=2,
+def plot_posterior_2d(posterior_draws, prior=None, prior_draws=None, param_names=None, height=3,
                       legend_fontsize=14, post_color='#8f2727', prior_color='gray', post_alpha=0.9,
                       prior_alpha=0.7):
     """Generates a bivariate pairplot given posterior draws and optional prior or prior draws.
@@ -423,7 +423,7 @@ def plot_posterior_2d(posterior_draws, prior=None, prior_draws=None, param_names
         will be used.
     param_names : list or None, optional, default: None
         The parameter names for nice plot titles. Inferred if None
-    height : float, optional, default: 2.
+    height : float, optional, default: 3.
         The height of the pairplot.
     legend_fontsize : int, optional, default: 14
         The font size of the legend text.
@@ -466,9 +466,12 @@ def plot_posterior_2d(posterior_draws, prior=None, prior_draws=None, param_names
     # Attempt to determine parameter names
     if param_names is None:
         if hasattr(prior, 'param_names'):
-            param_names = prior.param_names
+            if prior.param_names is not None:
+                param_names = prior.param_names
+            else:
+                param_names = [f'$p_{i}$' for i in range(1, n_params+1)]
         else:
-            param_names = [f'Param. {p}' for p in range(1, n_params+1)]
+            param_names = [f'$p_{i}$' for i in range(1, n_params+1)]
 
     # Pack posterior draws into a dataframe
     posterior_draws_df = pd.DataFrame(posterior_draws, columns=param_names)
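
With the new fallback, a call without parameter names (and without a prior object that exposes them) should label the marginals $p_1$, $p_2$, ... and use the larger default height of 3. A minimal usage sketch, assuming `posterior_draws` is an array of shape (n_draws, n_params):

import numpy as np
from bayesflow.diagnostics import plot_posterior_2d

# Hypothetical posterior draws for a 3-parameter model
posterior_draws = np.random.default_rng(42).normal(size=(500, 3))

# No param_names and no prior: titles fall back to $p_1$, $p_2$, $p_3$; height defaults to 3
plot_posterior_2d(posterior_draws)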

bayesflow/trainers.py

Lines changed: 5 additions & 3 deletions
@@ -85,7 +85,7 @@ class Trainer:
     """
 
     def __init__(self, amortizer, generative_model=None, configurator=None, checkpoint_path=None,
-                 max_to_keep=3, default_lr=0.0005, skip_checks=False, memory=True, **kwargs):
+                 max_to_keep=3, default_lr=0.001, skip_checks=False, memory=True, **kwargs):
         """Creates a trainer which will use a generative model (or data simulated from it) to optimize
         a neural arhcitecture (amortizer) for amortized posterior inference, likelihood inference, or both.
@@ -102,7 +102,7 @@ def __init__(self, amortizer, generative_model=None, configurator=None, checkpoi
             Optional file path for storing the trained amortizer, loss history and optional memory.
         max_to_keep : int, optional, default: 3
             Number of checkpoints and loss history snapshots to keep.
-        default_lr : float, optional, default: 0.0005
+        default_lr : float, optional, default: 0.001
             The default learning rate to use for default optimizers.
         skip_checks : boolean, optional, default: False
             If True, do not perform consistency checks, i.e., simulator runs and passed through nets
@@ -271,7 +271,7 @@ def diagnose_sbc_histograms(self, inputs=None, n_samples=None, **kwargs):
 
         # Check for prior names and override keyword if available
         plot_kwargs = kwargs.pop('plot_args', {})
-        if type(self.generative_model) is GenerativeModel and plot_kwargs.get('param_names') is not None:
+        if type(self.generative_model) is GenerativeModel and plot_kwargs.get('param_names') is None:
            plot_kwargs['param_names'] = self.generative_model.param_names
 
        return plot_sbc_histograms(post_samples, prior_samples, **plot_kwargs)
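
The flipped condition means the generative model's parameter names are injected only when the caller has not supplied any, rather than only when they already had. A standalone sketch of that logic, using made-up stand-in values outside the Trainer class:

# Stand-in values, not BayesFlow objects
plot_kwargs = {}                              # nothing passed by the caller
model_param_names = ['theta1', 'theta2']      # stand-in for generative_model.param_names

# Corrected check: only fill in names when the user supplied none
if plot_kwargs.get('param_names') is None:
    plot_kwargs['param_names'] = model_param_names

print(plot_kwargs)  # {'param_names': ['theta1', 'theta2']}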
@@ -414,7 +414,9 @@ def train_offline(self, simulations_dict, epochs, batch_size, save_checkpoint=Tr
 
         # Convert to custom data set
         data_set = SimulationDataset(simulations_dict, batch_size)
+        # Prepare optimizer and initislize loss history
         self._setup_optimizer(optimizer, epochs, len(data_set.data))
+
         self.loss_history.start_new_run()
         for ep in range(1, epochs+1):
