3535 AmortizedPosterior ,
3636 AmortizedPosteriorLikelihood ,
3737)
38+ from bayesflow .computational_utilities import maximum_mean_discrepancy
3839from bayesflow .configuration import *
3940from bayesflow .default_settings import DEFAULT_KEYS , OPTIMIZER_DEFAULTS
4041from bayesflow .diagnostics import plot_latent_space_2d , plot_sbc_histograms
41- from bayesflow .exceptions import SimulationError , ArgumentError
42+ from bayesflow .exceptions import ArgumentError , SimulationError
4243from bayesflow .helper_classes import (
4344 EarlyStopper ,
4445 LossHistory ,
4950)
5051from bayesflow .helper_functions import backprop_step , extract_current_lr , format_loss_string , loss_to_string
5152from bayesflow .simulation import GenerativeModel , MultiGenerativeModel
52- from bayesflow .computational_utilities import maximum_mean_discrepancy
5353
5454
5555class Trainer :
@@ -116,7 +116,7 @@ def __init__(
116116 max_to_keep = 3 ,
117117 default_lr = 0.0005 ,
118118 skip_checks = False ,
119- memory = True ,
119+ memory = False ,
120120 ** kwargs ,
121121 ):
122122 """Creates a trainer which will use a generative model (or data simulated from it) to optimize
@@ -139,7 +139,7 @@ def __init__(
139139 The default learning rate to use for default optimizers.
140140 skip_checks : bool, optional, default: False
141141 If True, do not perform consistency checks, i.e., simulator runs and passed through nets
142- memory : bool or bayesflow.SimulationMemory, optional, default: True
142+ memory : bool or bayesflow.SimulationMemory, optional, default: False
143143 If ``True``, store a pre-defined amount of simulations for later use (validation, etc.).
144144 If ``SimulationMemory`` instance provided, stores a reference to the instance.
145145 Otherwise the corresponding attribute will be set to None.
@@ -1010,12 +1010,9 @@ def train_rounds(
10101010 self .optimizer = None
10111011 return self .loss_history .get_plottable ()
10121012
1013- def mmd_hypothesis_test (self ,
1014- observed_data ,
1015- reference_data = None ,
1016- num_reference_simulations = 1000 ,
1017- num_null_samples = 100 ,
1018- bootstrap = False ):
1013+ def mmd_hypothesis_test (
1014+ self , observed_data , reference_data = None , num_reference_simulations = 1000 , num_null_samples = 100 , bootstrap = False
1015+ ):
10191016 """
10201017
10211018 Parameters
@@ -1048,12 +1045,12 @@ def mmd_hypothesis_test(self,
10481045
10491046 reference_data = self .configurator (self .generative_model (num_reference_simulations ))
10501047
1051- if type (reference_data ) == dict and ' summary_conditions' in reference_data .keys ():
1048+ if type (reference_data ) == dict and " summary_conditions" in reference_data .keys ():
10521049 reference_summary = self .amortizer .summary_net (reference_data ["summary_conditions" ])
10531050 else :
10541051 reference_summary = self .amortizer .summary_net (reference_data )
10551052
1056- if type (observed_data ) == dict and ' summary_conditions' in observed_data .keys ():
1053+ if type (observed_data ) == dict and " summary_conditions" in observed_data .keys ():
10571054 observed_summary = self .amortizer .summary_net (observed_data ["summary_conditions" ])
10581055 else :
10591056 observed_summary = self .amortizer .summary_net (observed_data )
0 commit comments