Skip to content

Commit 9941fa3

Browse files
committed
seed in stochastic sampler
1 parent ebafc5e commit 9941fa3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bayesflow/utils/integrate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def integrate_stochastic(
383383
stop_time: ArrayLike,
384384
steps: int,
385385
method: str = "euler_maruyama",
386-
seed: int | keras.random.SeedGenerator = None,
386+
seed: keras.random.SeedGenerator = None,
387387
**kwargs,
388388
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, List[ArrayLike]]]]:
389389
"""
@@ -428,8 +428,8 @@ def body(_loop_var, _loop_state):
428428
# Generate noise for this step
429429
_noise = {}
430430
for key in _state.keys():
431-
shape = keras.ops.shape(_state[key])
432-
_noise[key] = keras.random.normal(shape, seed=_seed) * keras.ops.sqrt(keras.ops.abs(step_size))
431+
_eps = keras.random.normal(keras.ops.shape(_state[key]), dtype=keras.ops.dtype(_state[key]), seed=_seed)
432+
_noise[key] = _eps * keras.ops.sqrt(keras.ops.abs(step_size))
433433

434434
# Perform integration step
435435
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)

0 commit comments

Comments
 (0)