Skip to content

Commit 7cabf17

Browse files
committed
Add wrapper around scipy.integrate.solve_ivp for integration
1 parent d2ac255 commit 7cabf17

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

bayesflow/utils/integrate.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import numpy as np
77
from typing import Literal, Union
88

9+
from bayesflow.adapters import Adapter
910
from bayesflow.types import Tensor
1011
from bayesflow.utils import filter_kwargs
12+
from bayesflow.utils.logging import warning
1113

1214
from . import logging
1315

@@ -265,6 +267,53 @@ def body(_loop_var, _loop_state):
265267
return state
266268

267269

270+
def integrate_scipy(
271+
fn: Callable,
272+
state: dict[str, ArrayLike],
273+
start_time: ArrayLike,
274+
stop_time: ArrayLike,
275+
scipy_kwargs: dict | None = None,
276+
**kwargs,
277+
) -> dict[str, ArrayLike]:
278+
import scipy.integrate
279+
280+
scipy_kwargs = scipy_kwargs or {}
281+
keys = list(state.keys())
282+
# convert to tensor before determining the shape in case a number was passed
283+
shapes = keras.tree.map_structure(lambda x: keras.ops.shape(keras.ops.convert_to_tensor(x)), state)
284+
adapter = Adapter().concatenate(keys, into="x", axis=-1).convert_dtype(np.float32, np.float64)
285+
286+
def state_to_vector(state):
287+
state = keras.tree.map_structure(keras.ops.convert_to_numpy, state)
288+
# flatten state
289+
state = keras.tree.map_structure(lambda x: keras.ops.reshape(x, (-1,)), state)
290+
# apply concatenation
291+
x = adapter.forward(state)["x"]
292+
return x
293+
294+
def vector_to_state(x):
295+
state = adapter.inverse({"x": x})
296+
state = {key: keras.ops.reshape(value, shapes[key]) for key, value in state.items()}
297+
state = keras.tree.map_structure(keras.ops.convert_to_tensor, state)
298+
return state
299+
300+
def scipy_wrapper_fn(time, x):
301+
state = vector_to_state(x)
302+
time = keras.ops.convert_to_tensor(time, dtype="float32")
303+
deltas = fn(time, **filter_kwargs(state, fn))
304+
return state_to_vector(deltas)
305+
306+
result = scipy.integrate.solve_ivp(
307+
scipy_wrapper_fn,
308+
(start_time, stop_time),
309+
state_to_vector(state),
310+
**scipy_kwargs,
311+
)
312+
313+
result = vector_to_state(result.y[:, -1])
314+
return result
315+
316+
268317
def integrate(
269318
fn: Callable,
270319
state: dict[str, ArrayLike],
@@ -282,6 +331,12 @@ def integrate(
282331
"Please provide start_time and stop_time for the integration, was "
283332
f"'start_time={start_time}', 'stop_time={stop_time}'."
284333
)
334+
if method == "scipy":
335+
if min_steps != 10:
336+
warning("Setting min_steps has no effect for method 'scipy'")
337+
if max_steps != 10_000:
338+
warning("Setting max_steps has no effect for method 'scipy'")
339+
return integrate_scipy(fn, state, start_time, stop_time, **kwargs)
285340
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
286341
elif isinstance(steps, int):
287342
if start_time is None or stop_time is None:

tests/test_utils/test_integrate.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import numpy as np
2+
3+
14
def test_scheduled_integration():
25
import keras
36
from bayesflow.utils import integrate
@@ -9,3 +12,25 @@ def fn(t, x):
912
approximate_result = 0.0 + 0.5**2 * 0.5
1013
result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
1114
assert result == approximate_result
15+
16+
17+
def test_scipy_integration():
18+
import keras
19+
from bayesflow.utils import integrate
20+
21+
def fn(t, x):
22+
return {"x": keras.ops.exp(t)}
23+
24+
start_time = -1.0
25+
stop_time = 1.0
26+
exact_result = keras.ops.exp(stop_time) - keras.ops.exp(start_time)
27+
result = integrate(
28+
fn,
29+
{"x": 0.0},
30+
start_time=start_time,
31+
stop_time=stop_time,
32+
steps="adaptive",
33+
method="scipy",
34+
scipy_kwargs={"atol": 1e-6, "rtol": 1e-6},
35+
)["x"]
36+
np.testing.assert_allclose(exact_result, result, atol=1e-6, rtol=1e-6)

0 commit comments

Comments
 (0)