Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions bayesflow/utils/integrate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import partial

import keras

import numpy as np
from typing import Literal

from bayesflow.types import Tensor
Expand Down Expand Up @@ -233,21 +234,62 @@
return state


def integrate_scheduled(
fn: Callable,
state: dict[str, ArrayLike],
steps: Tensor | np.ndarray,
method: str = "rk45",
**kwargs,
) -> dict[str, ArrayLike]:
match method:
case "euler":
step_fn = euler_step
case "rk45":
step_fn = rk45_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")

Check warning on line 252 in bayesflow/utils/integrate.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/integrate.py#L247-L252

Added lines #L247 - L252 were not covered by tests

step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)

def body(_loop_var, _loop_state):
_time = steps[_loop_var]
step_size = steps[_loop_var + 1] - steps[_loop_var]

_loop_state, _, _ = step_fn(_loop_state, _time, step_size)
return _loop_state

state = keras.ops.fori_loop(0, len(steps) - 1, body, state)
return state


def integrate(
fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
start_time: ArrayLike | None = None,
stop_time: ArrayLike | None = None,
min_steps: int = 10,
max_steps: int = 10_000,
steps: int | Literal["adaptive"] = 100,
steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100,
method: str = "euler",
**kwargs,
) -> dict[str, ArrayLike]:
match steps:
case "adaptive" | "dynamic":
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
case int():
return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
case _:
raise RuntimeError("Type or value of `steps` not understood.")
if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
if start_time is None or stop_time is None:
raise ValueError(

Check warning on line 280 in bayesflow/utils/integrate.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/integrate.py#L279-L280

Added lines #L279 - L280 were not covered by tests
"Please provide start_time and stop_time for the integration, was "
f"'start_time={start_time}', 'stop_time={stop_time}'."
)
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)

Check warning on line 284 in bayesflow/utils/integrate.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/integrate.py#L284

Added line #L284 was not covered by tests
elif isinstance(steps, int):
if start_time is None or stop_time is None:
raise ValueError(

Check warning on line 287 in bayesflow/utils/integrate.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/integrate.py#L287

Added line #L287 was not covered by tests
"Please provide start_time and stop_time for the integration, was "
f"'start_time={start_time}', 'stop_time={stop_time}'."
)
return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps):
return integrate_scheduled(fn, state, steps, method, **kwargs)
else:
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")

Check warning on line 295 in bayesflow/utils/integrate.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/integrate.py#L295

Added line #L295 was not covered by tests
Empty file added tests/test_utils/__init__.py
Empty file.
Empty file added tests/test_utils/conftest.py
Empty file.
11 changes: 11 additions & 0 deletions tests/test_utils/test_integrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def test_scheduled_integration():
import keras
from bayesflow.utils import integrate

def fn(t, x):
return {"x": t**2}

steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0])
approximate_result = 0.0 + 0.5**2 * 0.5
result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
assert result == approximate_result