Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions bayesflow/utils/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import numpy as np
from typing import Literal, Union

from bayesflow.adapters import Adapter
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.logging import warning

from . import logging

Expand Down Expand Up @@ -265,6 +267,53 @@ def body(_loop_var, _loop_state):
return state


def integrate_scipy(
fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
scipy_kwargs: dict | None = None,
**kwargs,
) -> dict[str, ArrayLike]:
import scipy.integrate

scipy_kwargs = scipy_kwargs or {}
keys = list(state.keys())
# convert to tensor before determining the shape in case a number was passed
shapes = keras.tree.map_structure(lambda x: keras.ops.shape(keras.ops.convert_to_tensor(x)), state)
adapter = Adapter().concatenate(keys, into="x", axis=-1).convert_dtype(np.float32, np.float64)

def state_to_vector(state):
state = keras.tree.map_structure(keras.ops.convert_to_numpy, state)
# flatten state
state = keras.tree.map_structure(lambda x: keras.ops.reshape(x, (-1,)), state)
# apply concatenation
x = adapter.forward(state)["x"]
return x

def vector_to_state(x):
state = adapter.inverse({"x": x})
state = {key: keras.ops.reshape(value, shapes[key]) for key, value in state.items()}
state = keras.tree.map_structure(keras.ops.convert_to_tensor, state)
return state

def scipy_wrapper_fn(time, x):
state = vector_to_state(x)
time = keras.ops.convert_to_tensor(time, dtype="float32")
deltas = fn(time, **filter_kwargs(state, fn))
return state_to_vector(deltas)

result = scipy.integrate.solve_ivp(
scipy_wrapper_fn,
(start_time, stop_time),
state_to_vector(state),
**scipy_kwargs,
)

result = vector_to_state(result.y[:, -1])
return result


def integrate(
fn: Callable,
state: dict[str, ArrayLike],
Expand All @@ -282,6 +331,12 @@ def integrate(
"Please provide start_time and stop_time for the integration, was "
f"'start_time={start_time}', 'stop_time={stop_time}'."
)
if method == "scipy":
if min_steps != 10:
warning("Setting min_steps has no effect for method 'scipy'")
if max_steps != 10_000:
warning("Setting max_steps has no effect for method 'scipy'")
return integrate_scipy(fn, state, start_time, stop_time, **kwargs)
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
elif isinstance(steps, int):
if start_time is None or stop_time is None:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_utils/test_integrate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np


def test_scheduled_integration():
import keras
from bayesflow.utils import integrate
Expand All @@ -9,3 +12,25 @@ def fn(t, x):
approximate_result = 0.0 + 0.5**2 * 0.5
result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
assert result == approximate_result


def test_scipy_integration():
import keras
from bayesflow.utils import integrate

def fn(t, x):
return {"x": keras.ops.exp(t)}

start_time = -1.0
stop_time = 1.0
exact_result = keras.ops.exp(stop_time) - keras.ops.exp(start_time)
result = integrate(
fn,
{"x": 0.0},
start_time=start_time,
stop_time=stop_time,
steps="adaptive",
method="scipy",
scipy_kwargs={"atol": 1e-6, "rtol": 1e-6},
)["x"]
np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6)
Loading