diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py index 3f2d7f5c0..b197ea975 100644 --- a/bayesflow/utils/integrate.py +++ b/bayesflow/utils/integrate.py @@ -6,8 +6,10 @@ import numpy as np from typing import Literal, Union +from bayesflow.adapters import Adapter from bayesflow.types import Tensor from bayesflow.utils import filter_kwargs +from bayesflow.utils.logging import warning from . import logging @@ -265,6 +267,53 @@ def body(_loop_var, _loop_state): return state +def integrate_scipy( + fn: Callable, + state: dict[str, ArrayLike], + start_time: ArrayLike, + stop_time: ArrayLike, + scipy_kwargs: dict | None = None, + **kwargs, +) -> dict[str, ArrayLike]: + import scipy.integrate + + scipy_kwargs = scipy_kwargs or {} + keys = list(state.keys()) + # convert to tensor before determining the shape in case a number was passed + shapes = keras.tree.map_structure(lambda x: keras.ops.shape(keras.ops.convert_to_tensor(x)), state) + adapter = Adapter().concatenate(keys, into="x", axis=-1).convert_dtype(np.float32, np.float64) + + def state_to_vector(state): + state = keras.tree.map_structure(keras.ops.convert_to_numpy, state) + # flatten state + state = keras.tree.map_structure(lambda x: keras.ops.reshape(x, (-1,)), state) + # apply concatenation + x = adapter.forward(state)["x"] + return x + + def vector_to_state(x): + state = adapter.inverse({"x": x}) + state = {key: keras.ops.reshape(value, shapes[key]) for key, value in state.items()} + state = keras.tree.map_structure(keras.ops.convert_to_tensor, state) + return state + + def scipy_wrapper_fn(time, x): + state = vector_to_state(x) + time = keras.ops.convert_to_tensor(time, dtype="float32") + deltas = fn(time, **filter_kwargs(state, fn)) + return state_to_vector(deltas) + + result = scipy.integrate.solve_ivp( + scipy_wrapper_fn, + (start_time, stop_time), + state_to_vector(state), + **scipy_kwargs, + ) + + result = vector_to_state(result.y[:, -1]) + return result + + def integrate( fn: Callable, state: dict[str, ArrayLike], @@ -282,6 +331,12 @@ def integrate( "Please provide start_time and stop_time for the integration, was " f"'start_time={start_time}', 'stop_time={stop_time}'." ) + if method == "scipy": + if min_steps != 10: + warning("Setting min_steps has no effect for method 'scipy'") + if max_steps != 10_000: + warning("Setting max_steps has no effect for method 'scipy'") + return integrate_scipy(fn, state, start_time, stop_time, **kwargs) return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) elif isinstance(steps, int): if start_time is None or stop_time is None: diff --git a/tests/test_utils/test_integrate.py b/tests/test_utils/test_integrate.py index 8147d0589..db5c448d7 100644 --- a/tests/test_utils/test_integrate.py +++ b/tests/test_utils/test_integrate.py @@ -1,3 +1,6 @@ +import numpy as np + + def test_scheduled_integration(): import keras from bayesflow.utils import integrate @@ -9,3 +12,25 @@ def fn(t, x): approximate_result = 0.0 + 0.5**2 * 0.5 result = integrate(fn, {"x": 0.0}, steps=steps)["x"] assert result == approximate_result + + +def test_scipy_integration(): + import keras + from bayesflow.utils import integrate + + def fn(t, x): + return {"x": keras.ops.exp(t)} + + start_time = -1.0 + stop_time = 1.0 + exact_result = keras.ops.exp(stop_time) - keras.ops.exp(start_time) + result = integrate( + fn, + {"x": 0.0}, + start_time=start_time, + stop_time=stop_time, + steps="adaptive", + method="scipy", + scipy_kwargs={"atol": 1e-6, "rtol": 1e-6}, + )["x"] + np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6)