Skip to content

Commit 668f6fc

Browse files
committed
seed in stochastic sampler
1 parent eb96620 commit 668f6fc

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -677,9 +677,10 @@ def diffusion(time, xz):
677677
return {"xz": self.compute_diffusion_term(xz, time=time, training=training)}
678678

679679
state = integrate_stochastic(
680-
deltas,
681-
diffusion,
682-
state,
680+
drift_fn=deltas,
681+
diffusion_fn=diffusion,
682+
state=state,
683+
seed=self.seed_generator,
683684
**integrate_kwargs,
684685
)
685686
else:

bayesflow/utils/integrate.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def euler_maruyama_step(
301301
state: dict[str, ArrayLike],
302302
time: ArrayLike,
303303
step_size: ArrayLike,
304-
noise: dict[str, ArrayLike] = None,
304+
noise: dict[str, ArrayLike],
305305
tolerance: ArrayLike = 1e-6,
306306
min_step_size: ArrayLike = -float("inf"),
307307
max_step_size: ArrayLike = float("inf"),
@@ -331,13 +331,6 @@ def euler_maruyama_step(
331331
# Compute diffusion term
332332
diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
333333

334-
# Generate noise if not provided
335-
if noise is None:
336-
noise = {}
337-
for key in diffusion.keys():
338-
shape = keras.ops.shape(diffusion[key])
339-
noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))
340-
341334
# Check if diffusion and noise have the same keys
342335
if set(diffusion.keys()) != set(noise.keys()):
343336
raise ValueError("Keys of diffusion terms and noise do not match.")
@@ -414,10 +407,6 @@ def integrate_stochastic(
414407
if steps <= 0:
415408
raise ValueError("Number of steps must be positive.")
416409

417-
# Set random seed if provided
418-
if seed is not None:
419-
keras.random.set_seed(seed)
420-
421410
# Select step function based on method
422411
match method:
423412
case "euler_maruyama":
@@ -440,7 +429,7 @@ def body(_loop_var, _loop_state):
440429
_noise = {}
441430
for key in _state.keys():
442431
shape = keras.ops.shape(_state[key])
443-
_noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))
432+
_noise[key] = keras.random.normal(shape, seed=seed) * keras.ops.sqrt(keras.ops.abs(step_size))
444433

445434
# Perform integration step
446435
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)

0 commit comments

Comments
 (0)