|
4 | 4 | import keras |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | | -from typing import Literal |
| 7 | +from typing import Literal, Union, List |
8 | 8 |
|
9 | 9 | from bayesflow.types import Tensor |
10 | 10 | from bayesflow.utils import filter_kwargs |
@@ -293,3 +293,159 @@ def integrate( |
293 | 293 | return integrate_scheduled(fn, state, steps, method, **kwargs) |
294 | 294 | else: |
295 | 295 | raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") |
| 296 | + |
| 297 | + |
| 298 | +def euler_maruyama_step( |
| 299 | + drift_fn: Callable, |
| 300 | + diffusion_fn: Callable, |
| 301 | + state: dict[str, ArrayLike], |
| 302 | + time: ArrayLike, |
| 303 | + step_size: ArrayLike, |
| 304 | + noise: dict[str, ArrayLike] = None, |
| 305 | + tolerance: ArrayLike = 1e-6, |
| 306 | + min_step_size: ArrayLike = -float("inf"), |
| 307 | + max_step_size: ArrayLike = float("inf"), |
| 308 | + use_adaptive_step_size: bool = False, |
| 309 | +) -> (dict[str, ArrayLike], ArrayLike, ArrayLike): |
| 310 | + """ |
| 311 | + Performs a single Euler-Maruyama step for stochastic differential equations. |
| 312 | +
|
| 313 | + Args: |
| 314 | + drift_fn: Function that computes the drift term. |
| 315 | + diffusion_fn: Function that computes the diffusion term. |
| 316 | + state: Dictionary containing the current state. |
| 317 | + time: Current time. |
| 318 | + step_size: Size of the integration step. |
| 319 | + noise: Dictionary of noise terms for each state variable. |
| 320 | + tolerance: Error tolerance for adaptive step size. |
| 321 | + min_step_size: Minimum allowed step size. |
| 322 | + max_step_size: Maximum allowed step size. |
| 323 | + use_adaptive_step_size: Whether to use adaptive step sizing. |
| 324 | +
|
| 325 | + Returns: |
| 326 | + Tuple of (new_state, new_time, new_step_size). |
| 327 | + """ |
| 328 | + # Compute drift term |
| 329 | + drift = drift_fn(time, **filter_kwargs(state, drift_fn)) |
| 330 | + |
| 331 | + # Compute diffusion term |
| 332 | + diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn)) |
| 333 | + |
| 334 | + # Generate noise if not provided |
| 335 | + if noise is None: |
| 336 | + noise = {} |
| 337 | + for key in diffusion.keys(): |
| 338 | + shape = keras.ops.shape(diffusion[key]) |
| 339 | + noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size)) |
| 340 | + |
| 341 | + # Check if diffusion and noise have the same keys |
| 342 | + if set(diffusion.keys()) != set(noise.keys()): |
| 343 | + raise ValueError("Keys of diffusion terms and noise do not match.") |
| 344 | + |
| 345 | + if use_adaptive_step_size: |
| 346 | + # Perform a half-step to estimate error |
| 347 | + intermediate_state = state.copy() |
| 348 | + for key in drift.keys(): |
| 349 | + intermediate_state[key] = state[key] + (step_size * drift[key]) + (diffusion[key] * noise[key]) |
| 350 | + |
| 351 | + # Compute drift and diffusion at intermediate state |
| 352 | + intermediate_drift = drift_fn(time + step_size, **filter_kwargs(intermediate_state, drift_fn)) |
| 353 | + |
| 354 | + # Compute error estimate |
| 355 | + error_terms = [] |
| 356 | + for key in drift.keys(): |
| 357 | + error = keras.ops.norm(intermediate_drift[key] - drift[key], ord=2, axis=-1) |
| 358 | + error_terms.append(error) |
| 359 | + |
| 360 | + intermediate_error = keras.ops.stack(error_terms) |
| 361 | + new_step_size = step_size * tolerance / (intermediate_error + 1e-9) |
| 362 | + |
| 363 | + # Apply constraints to step size |
| 364 | + new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size) |
| 365 | + |
| 366 | + # Consolidate step size |
| 367 | + new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size))) |
| 368 | + else: |
| 369 | + new_step_size = step_size |
| 370 | + |
| 371 | + # Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW |
| 372 | + new_state = state.copy() |
| 373 | + for key in drift.keys(): |
| 374 | + if key in diffusion: |
| 375 | + new_state[key] = state[key] + (step_size * drift[key]) + (diffusion[key] * noise[key]) |
| 376 | + else: |
| 377 | + # If no diffusion term for this variable, apply deterministic update |
| 378 | + new_state[key] = state[key] + step_size * drift[key] |
| 379 | + |
| 380 | + new_time = time + step_size |
| 381 | + |
| 382 | + return new_state, new_time, new_step_size |
| 383 | + |
| 384 | + |
| 385 | +def integrate_stochastic( |
| 386 | + drift_fn: Callable, |
| 387 | + diffusion_fn: Callable, |
| 388 | + state: dict[str, ArrayLike], |
| 389 | + start_time: ArrayLike, |
| 390 | + stop_time: ArrayLike, |
| 391 | + steps: int, |
| 392 | + method: str = "euler_maruyama", |
| 393 | + seed: int = None, |
| 394 | + **kwargs, |
| 395 | +) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, List[ArrayLike]]]]: |
| 396 | + """ |
| 397 | + Integrates a stochastic differential equation from start_time to stop_time. |
| 398 | +
|
| 399 | + Args: |
| 400 | + drift_fn: Function that computes the drift term. |
| 401 | + diffusion_fn: Function that computes the diffusion term. |
| 402 | + state: Dictionary containing the initial state. |
| 403 | + start_time: Starting time for integration. |
| 404 | + stop_time: Ending time for integration. |
| 405 | + steps: Number of integration steps. |
| 406 | + method: Integration method to use ('euler_maruyama'). |
| 407 | + seed: Random seed for noise generation. |
| 408 | + **kwargs: Additional arguments to pass to the step function. |
| 409 | +
|
| 410 | + Returns: |
| 411 | + If return_noise is False, returns the final state dictionary. |
| 412 | + If return_noise is True, returns a tuple of (final_state, noise_history). |
| 413 | + """ |
| 414 | + if steps <= 0: |
| 415 | + raise ValueError("Number of steps must be positive.") |
| 416 | + |
| 417 | + # Set random seed if provided |
| 418 | + if seed is not None: |
| 419 | + keras.random.set_seed(seed) |
| 420 | + |
| 421 | + # Select step function based on method |
| 422 | + match method: |
| 423 | + case "euler_maruyama": |
| 424 | + step_fn = euler_maruyama_step |
| 425 | + case str() as name: |
| 426 | + raise ValueError(f"Unknown integration method name: {name!r}") |
| 427 | + case other: |
| 428 | + raise TypeError(f"Invalid integration method: {other!r}") |
| 429 | + |
| 430 | + # Prepare step function with partial application |
| 431 | + step_fn = partial(step_fn, drift_fn, diffusion_fn, **kwargs) |
| 432 | + step_size = (stop_time - start_time) / steps |
| 433 | + |
| 434 | + time = start_time |
| 435 | + |
| 436 | + def body(_loop_var, _loop_state): |
| 437 | + _state, _time = _loop_state |
| 438 | + |
| 439 | + # Generate noise for this step |
| 440 | + _noise = {} |
| 441 | + for key in _state.keys(): |
| 442 | + shape = keras.ops.shape(_state[key]) |
| 443 | + _noise[key] = keras.random.normal(shape) * keras.ops.sqrt(keras.ops.abs(step_size)) |
| 444 | + |
| 445 | + # Perform integration step |
| 446 | + _state, _time, _ = step_fn(_state, _time, step_size, noise=_noise) |
| 447 | + |
| 448 | + return _state, _time |
| 449 | + |
| 450 | + state, time = keras.ops.fori_loop(0, steps, body, (state, time)) |
| 451 | + return state |
0 commit comments