Skip to content

Commit 1c6b90d

Browse files
committed
Enable integration with fixed, non-uniform schedule.
This enables specialized, manually specified schedules, as for example required in diffusion models for inference. Closes #402.
1 parent c8ea3dd commit 1c6b90d

File tree

4 files changed

+64
-11
lines changed

4 files changed

+64
-11
lines changed

bayesflow/utils/integrate.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Sequence
22
from functools import partial
33

44
import keras
55

6+
import numpy as np
67
from typing import Literal
78

89
from bayesflow.types import Tensor
@@ -233,21 +234,62 @@ def body(_state, _time, _step_size, _step):
233234
return state
234235

235236

237+
def integrate_scheduled(
238+
fn: Callable,
239+
state: dict[str, ArrayLike],
240+
steps: Tensor | np.ndarray,
241+
method: str = "rk45",
242+
**kwargs,
243+
) -> dict[str, ArrayLike]:
244+
match method:
245+
case "euler":
246+
step_fn = euler_step
247+
case "rk45":
248+
step_fn = rk45_step
249+
case str() as name:
250+
raise ValueError(f"Unknown integration method name: {name!r}")
251+
case other:
252+
raise TypeError(f"Invalid integration method: {other!r}")
253+
254+
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
255+
256+
def body(_loop_var, _loop_state):
257+
_time = steps[_loop_var]
258+
step_size = steps[_loop_var + 1] - steps[_loop_var]
259+
260+
_loop_state, _, _ = step_fn(_loop_state, _time, step_size)
261+
return _loop_state
262+
263+
state = keras.ops.fori_loop(0, len(steps) - 1, body, state)
264+
return state
265+
266+
236267
def integrate(
237268
fn: Callable,
238269
state: dict[str, ArrayLike],
239-
start_time: ArrayLike,
240-
stop_time: ArrayLike,
270+
start_time: ArrayLike | None = None,
271+
stop_time: ArrayLike | None = None,
241272
min_steps: int = 10,
242273
max_steps: int = 10_000,
243-
steps: int | Literal["adaptive"] = 100,
274+
steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100,
244275
method: str = "euler",
245276
**kwargs,
246277
) -> dict[str, ArrayLike]:
247-
match steps:
248-
case "adaptive" | "dynamic":
249-
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
250-
case int():
251-
return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
252-
case _:
253-
raise RuntimeError("Type or value of `steps` not understood.")
278+
if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
279+
if start_time is None or stop_time is None:
280+
raise ValueError(
281+
"Please provide start_time and stop_time for the integration, was "
282+
f"'start_time={start_time}', 'stop_time={stop_time}'."
283+
)
284+
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
285+
elif isinstance(steps, int):
286+
if start_time is None or stop_time is None:
287+
raise ValueError(
288+
"Please provide start_time and stop_time for the integration, was "
289+
f"'start_time={start_time}', 'stop_time={stop_time}'."
290+
)
291+
return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
292+
elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps):
293+
return integrate_scheduled(fn, state, steps, method, **kwargs)
294+
else:
295+
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")

tests/test_utils/__init__.py

Whitespace-only changes.

tests/test_utils/conftest.py

Whitespace-only changes.

tests/test_utils/test_integrate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
def test_scheduled_integration():
2+
import keras
3+
from bayesflow.utils import integrate
4+
5+
def fn(t, x):
6+
return {"x": t**2}
7+
8+
steps = keras.ops.convert_to_tensor([0.0, 0.5, 1.0])
9+
approximate_result = 0.0 + 0.5**2 * 0.5
10+
result = integrate(fn, {"x": 0.0}, steps=steps)["x"]
11+
assert result == approximate_result

0 commit comments

Comments
 (0)