Skip to content

Commit de532c7

Browse files
committed
abs step size
1 parent 612b17b commit de532c7

File tree

1 file changed

+3
-17
lines changed

1 file changed

+3
-17
lines changed

bayesflow/utils/integrate.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def euler_maruyama_step(
336336
noise = {}
337337
for key in diffusion.keys():
338338
shape = keras.ops.shape(diffusion[key])
339-
noise[key] = keras.random.normal(shape) * keras.ops.sqrt(step_size)
339+
noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))
340340

341341
# Check if diffusion and noise have the same keys
342342
if set(diffusion.keys()) != set(noise.keys()):
@@ -391,7 +391,6 @@ def integrate_stochastic(
391391
steps: int,
392392
method: str = "euler_maruyama",
393393
seed: int = None,
394-
return_noise: bool = False,
395394
**kwargs,
396395
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, List[ArrayLike]]]]:
397396
"""
@@ -406,7 +405,6 @@ def integrate_stochastic(
406405
steps: Number of integration steps.
407406
method: Integration method to use ('euler_maruyama').
408407
seed: Random seed for noise generation.
409-
return_noise: Whether to return the generated noise terms.
410408
**kwargs: Additional arguments to pass to the step function.
411409
412410
Returns:
@@ -435,31 +433,19 @@ def integrate_stochastic(
435433

436434
time = start_time
437435

438-
# Store noise history if requested
439-
noise_history = {key: [] for key in state.keys()} if return_noise else None
440-
441436
def body(_loop_var, _loop_state):
442437
_state, _time = _loop_state
443438

444439
# Generate noise for this step
445440
_noise = {}
446441
for key in _state.keys():
447442
shape = keras.ops.shape(_state[key])
448-
_noise[key] = keras.random.normal(shape) * keras.ops.sqrt(step_size)
449-
450-
# Store noise if requested
451-
if return_noise:
452-
for key in _noise:
453-
noise_history[key].append(_noise[key])
443+
_noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size))
454444

455445
# Perform integration step
456446
_state, _time, _ = step_fn(_state, _time, step_size, noise=_noise)
457447

458448
return _state, _time
459449

460450
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
461-
462-
if return_noise:
463-
return state, noise_history
464-
else:
465-
return state
451+
return state

0 commit comments

Comments
 (0)