Skip to content

Commit 7b55a37

Browse files
committed
fix stochastic sampler
1 parent e8d34d7 commit 7b55a37

File tree

1 file changed

+36
-43
lines changed

1 file changed

+36
-43
lines changed

bayesflow/utils/integrate.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -302,50 +302,39 @@ def euler_maruyama_step(
302302
state: dict[str, ArrayLike],
303303
time: ArrayLike,
304304
step_size: ArrayLike,
305-
seed: keras.random.SeedGenerator,
305+
noise: dict[str, ArrayLike],
306306
) -> (dict[str, ArrayLike], ArrayLike, ArrayLike):
307307
"""
308308
Performs a single Euler-Maruyama step for stochastic differential equations.
309309
310310
Args:
311-
drift_fn: Function that computes the drift term.
312-
diffusion_fn: Function that computes the diffusion term.
313-
state: Dictionary containing the current state.
314-
time: Current time.
315-
step_size: Size of the integration step.
316-
seed: Random seed for noise generation.
311+
drift_fn: Function computing the drift term f(t, **state).
312+
diffusion_fn: Function computing the diffusion term g(t, **state).
313+
state: Current state, mapping variable names to tensors.
314+
time: Current time scalar tensor.
315+
step_size: Time increment dt.
316+
noise: Mapping of variable names to dW noise tensors.
317317
318318
Returns:
319-
Tuple of (new_state, new_time, new_step_size).
319+
new_state: Updated state after one Euler-Maruyama step.
320+
new_time: time + dt.
320321
"""
321-
# Compute drift term
322+
# Compute drift and diffusion
322323
drift = drift_fn(time, **filter_kwargs(state, drift_fn))
323-
324-
# Compute diffusion term
325324
diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
326325

327-
# Generate noise for this step
328-
noise = {}
329-
for key in state.keys():
330-
eps = keras.random.normal(keras.ops.shape(state[key]), dtype=keras.ops.dtype(state[key]), seed=seed)
331-
noise[key] = eps * keras.ops.sqrt(keras.ops.abs(step_size))
332-
333-
# Check if diffusion and noise have the same keys
326+
# Check noise keys
334327
if set(diffusion.keys()) != set(noise.keys()):
335328
raise ValueError("Keys of diffusion terms and noise do not match.")
336329

337-
# Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW
338-
new_state = state.copy()
339-
for key in drift.keys():
340-
if key in diffusion:
341-
new_state[key] = state[key] + (step_size * drift[key]) + (diffusion[key] * noise[key])
342-
else:
343-
# If no diffusion term for this variable, apply deterministic update
344-
new_state[key] = state[key] + step_size * drift[key]
330+
new_state = {}
331+
for key, d in drift.items():
332+
base = state[key] + step_size * d
333+
if key in diffusion: # stochastic update
334+
base = base + diffusion[key] * noise[key]
335+
new_state[key] = base
345336

346-
new_time = time + step_size
347-
348-
return new_state, new_time
337+
return new_state, time + step_size
349338

350339

351340
def integrate_stochastic(
@@ -356,7 +345,7 @@ def integrate_stochastic(
356345
stop_time: ArrayLike,
357346
steps: int,
358347
seed: keras.random.SeedGenerator,
359-
method: Literal["euler_maruyama"] = "euler_maruyama",
348+
method: str = "euler_maruyama",
360349
**kwargs,
361350
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
362351
"""
@@ -370,7 +359,7 @@ def integrate_stochastic(
370359
stop_time: Ending time for integration.
371360
steps: Number of integration steps.
372361
seed: Random seed for noise generation.
373-
method: Integration method to use ('euler_maruyama').
362+
method: Integration method to use, e.g., 'euler_maruyama'.
374363
**kwargs: Additional arguments to pass to the step function.
375364
376365
Returns:
@@ -390,18 +379,22 @@ def integrate_stochastic(
390379
# Prepare step function with partial application
391380
step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs)
392381

382+
# Time step
393383
step_size = (stop_time - start_time) / steps
394-
time = start_time
395-
current_state = state.copy()
396-
397-
# keras.ops.fori_loop does not support keras seed generator in jax
398-
for i in range(steps):
399-
# Execute the step with the specific seed for this step
400-
current_state, time = step_fn(
401-
state=current_state,
402-
time=time,
403-
step_size=step_size,
404-
seed=seed,
384+
sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
385+
386+
# Pre-generate noise history: shape = (steps, *state_shape)
387+
noise_history = {}
388+
for key, val in state.items():
389+
noise_history[key] = (
390+
keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
405391
)
406392

407-
return current_state
393+
def body(_loop_var, _loop_state):
394+
_current_state, _current_time = _loop_state
395+
_noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
396+
new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
397+
return new_state, new_time
398+
399+
final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))
400+
return final_state

0 commit comments

Comments
 (0)