Commit 1cba9c6

Change default num_eval_points to 1 for reduced default memory consumption
1 parent cc089e9 commit 1cba9c6

File tree

1 file changed: +17 -11 lines changed

bayesflow/experimental/rectifiers.py

Lines changed: 17 additions & 11 deletions
@@ -178,19 +178,23 @@ def __init__(self, drift_net, summary_net=None, latent_dist=None, loss_fun=None,
         self.loss_fun = self._determine_loss(loss_fun)
         self.summary_loss = self._determine_summary_loss(summary_loss_fun)
 
-    def call(self, input_dict, return_summary=False, num_eval_points=32, **kwargs):
+    def call(self, input_dict, return_summary=False, num_eval_points=1, **kwargs):
         """Performs a forward pass through the summary and drift network given an input dictionary.
 
         Parameters
         ----------
-        input_dict     : dict
+        input_dict      : dict
             Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
             ``targets``            - the latent model parameters over which a condition density is learned
             ``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
             ``direct_conditions``  - the conditioning variables that the directly passed to the inference network
-        return_summary : bool, optional, default: False
+        return_summary  : bool, optional, default: False
             A flag which determines whether the learnable data summaries (representations) are returned or not.
-        **kwargs       : dict, optional, default: {}
+        num_eval_points : int, optional, default: 1
+            The number of time points for evaluating the noisy estimator. Values larger than the default 1
+            may reduce the variance of the estimator, but may lead to increased memory demands, since an
+            additional dimension is added at axis 1 of all tensors.
+        **kwargs        : dict, optional, default: {}
             Additional keyword arguments passed to the networks
             For instance, ``kwargs={'training': True}`` is passed automatically during training.
 
@@ -215,13 +219,15 @@ def call(self, input_dict, return_summary=False, num_eval_points=32, **kwargs):
         # Sample latent variables
         latent_vars = self.latent_dist.sample(batch_size)
 
-        # Do a little trick for less noisy estimator
-        target_vars = tf.stack([target_vars] * num_eval_points, axis=1)
-        latent_vars = tf.stack([latent_vars] * num_eval_points, axis=1)
-        full_cond = tf.stack([full_cond] * num_eval_points, axis=1)
-
-        # Sample time
-        time = tf.random.uniform((batch_size, num_eval_points, 1))
+        # Do a little trick for less noisy estimator, if evals > 1
+        if num_eval_points > 1:
+            target_vars = tf.stack([target_vars] * num_eval_points, axis=1)
+            latent_vars = tf.stack([latent_vars] * num_eval_points, axis=1)
+            full_cond = tf.stack([full_cond] * num_eval_points, axis=1)
+            # Sample time
+            time = tf.random.uniform((batch_size, num_eval_points, 1))
+        else:
+            time = tf.random.uniform((batch_size, 1))
 
         # Compute drift
         net_out = self.drift_net(target_vars, latent_vars, time, full_cond, **kwargs)
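For reference, a minimal sketch (illustrative variable names and sizes, not taken from the module; assumes only TensorFlow) of why values of num_eval_points larger than 1 increase memory use: each tensor is tiled along a new axis 1, so memory grows roughly linearly with num_eval_points.

import tensorflow as tf

# Illustrative shapes only
batch_size, num_params, num_eval_points = 64, 5, 32

target_vars = tf.random.normal((batch_size, num_params))

# Default path (num_eval_points=1): no extra axis is added
time_single = tf.random.uniform((batch_size, 1))                   # shape (64, 1)

# num_eval_points > 1: an extra dimension is added at axis 1
target_tiled = tf.stack([target_vars] * num_eval_points, axis=1)   # shape (64, 32, 5)
time_multi = tf.random.uniform((batch_size, num_eval_points, 1))   # shape (64, 32, 1)

print(target_vars.shape, target_tiled.shape, time_multi.shape)

With the new default of 1, the previous behavior can still be requested per call by passing num_eval_points=32, since it remains a keyword argument of call.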
