@@ -178,19 +178,23 @@ def __init__(self, drift_net, summary_net=None, latent_dist=None, loss_fun=None,
178178 self .loss_fun = self ._determine_loss (loss_fun )
179179 self .summary_loss = self ._determine_summary_loss (summary_loss_fun )
180180
181- def call (self , input_dict , return_summary = False , num_eval_points = 32 , ** kwargs ):
181+ def call (self , input_dict , return_summary = False , num_eval_points = 1 , ** kwargs ):
182182 """Performs a forward pass through the summary and drift network given an input dictionary.
183183
184184 Parameters
185185 ----------
186- input_dict : dict
186+ input_dict : dict
187187 Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
188188 ``targets`` - the latent model parameters over which a condition density is learned
189189 ``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
190190 ``direct_conditions`` - the conditioning variables that the directly passed to the inference network
191- return_summary : bool, optional, default: False
191+ return_summary : bool, optional, default: False
192192 A flag which determines whether the learnable data summaries (representations) are returned or not.
193- **kwargs : dict, optional, default: {}
193+ num_eval_points : int, optional, default: 1
194+ The number of time points for evaluating the noisy estimator. Values larger than the default 1
195+ may reduce the variance of the estimator, but may lead to increased memory demands, since an
196+ additional dimension is added at axis 1 of all tensors.
197+ **kwargs : dict, optional, default: {}
194198 Additional keyword arguments passed to the networks
195199 For instance, ``kwargs={'training': True}`` is passed automatically during training.
196200
@@ -215,13 +219,15 @@ def call(self, input_dict, return_summary=False, num_eval_points=32, **kwargs):
215219 # Sample latent variables
216220 latent_vars = self .latent_dist .sample (batch_size )
217221
218- # Do a little trick for less noisy estimator
219- target_vars = tf .stack ([target_vars ] * num_eval_points , axis = 1 )
220- latent_vars = tf .stack ([latent_vars ] * num_eval_points , axis = 1 )
221- full_cond = tf .stack ([full_cond ] * num_eval_points , axis = 1 )
222-
223- # Sample time
224- time = tf .random .uniform ((batch_size , num_eval_points , 1 ))
222+ # Do a little trick for less noisy estimator, if evals > 1
223+ if num_eval_points > 1 :
224+ target_vars = tf .stack ([target_vars ] * num_eval_points , axis = 1 )
225+ latent_vars = tf .stack ([latent_vars ] * num_eval_points , axis = 1 )
226+ full_cond = tf .stack ([full_cond ] * num_eval_points , axis = 1 )
227+ # Sample time
228+ time = tf .random .uniform ((batch_size , num_eval_points , 1 ))
229+ else :
230+ time = tf .random .uniform ((batch_size , 1 ))
225231
226232 # Compute drift
227233 net_out = self .drift_net (target_vars , latent_vars , time , full_cond , ** kwargs )
0 commit comments