diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index a932424f5..5e3b407ec 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import partial import keras +import numpy as np from typing import Literal from bayesflow.types import Tensor @@ -233,21 +234,62 @@ def body(_state, _time, _step_size, _step): return state +def integrate_scheduled( + fn: Callable, + state: dict[str, ArrayLike], + steps: Tensor | np.ndarray, + method: str = "rk45", + **kwargs, +) -> dict[str, ArrayLike]: + match method: + case "euler": + step_fn = euler_step + case "rk45": + step_fn = rk45_step + case str() as name: + raise ValueError(f"Unknown integration method name: {name!r}") + case other: + raise TypeError(f"Invalid integration method: {other!r}") + + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) + + def body(_loop_var, _loop_state): + _time = steps[_loop_var] + step_size = steps[_loop_var + 1] - steps[_loop_var] + + _loop_state, _, _ = step_fn(_loop_state, _time, step_size) + return _loop_state + + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) + return state + + def integrate( fn: Callable, state: dict[str, ArrayLike], - start_time: ArrayLike, - stop_time: ArrayLike, + start_time: ArrayLike | None = None, + stop_time: ArrayLike | None = None, min_steps: int = 10, max_steps: int = 10_000, - steps: int | Literal["adaptive"] = 100, + steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, method: str = "euler", **kwargs, ) -> dict[str, ArrayLike]: - match steps: - case "adaptive" | "dynamic": - return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) - case int(): - return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) - case _: - raise RuntimeError("Type or value of `steps` not understood.") + if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." + ) + return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) + elif isinstance(steps, int): + if start_time is None or stop_time is None: + raise ValueError( + "Please provide start_time and stop_time for the integration, was " + f"'start_time={start_time}', 'stop_time={stop_time}'." + ) + return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) + elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps): + return integrate_scheduled(fn, state, steps, method, **kwargs) + else: + raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/conftest.py b/tests/test_utils/conftest.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py new file mode 100644 index 000000000..8147d0589 --- /dev/null +++ b/tests/test_utils/test_integrate.py @@ -0,0 +1,11 @@ +def test_scheduled_integration(): + import keras + from bayesflow.utils import integrate + + def fn(t, x): + return {"x": t**2} + + steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0]) + approximate_result = 0.0 + 0.5**2 * 0.5 + result = integrate(fn, {"x": 0.0}, steps=steps)["x"] + assert result == approximate_result