|
1 | | -from collections.abc import Callable |
| 1 | +from collections.abc import Callable, Sequence |
2 | 2 | from functools import partial |
3 | 3 |
|
4 | 4 | import keras |
5 | 5 |
|
| 6 | +import numpy as np |
6 | 7 | from typing import Literal |
7 | 8 |
|
8 | 9 | from bayesflow.types import Tensor |
@@ -233,21 +234,62 @@ def body(_state, _time, _step_size, _step): |
233 | 234 | return state |
234 | 235 |
|
235 | 236 |
|
| 237 | +def integrate_scheduled( |
| 238 | + fn: Callable, |
| 239 | + state: dict[str, ArrayLike], |
| 240 | + steps: Tensor | np.ndarray, |
| 241 | + method: str = "rk45", |
| 242 | + **kwargs, |
| 243 | +) -> dict[str, ArrayLike]: |
| 244 | + match method: |
| 245 | + case "euler": |
| 246 | + step_fn = euler_step |
| 247 | + case "rk45": |
| 248 | + step_fn = rk45_step |
| 249 | + case str() as name: |
| 250 | + raise ValueError(f"Unknown integration method name: {name!r}") |
| 251 | + case other: |
| 252 | + raise TypeError(f"Invalid integration method: {other!r}") |
| 253 | + |
| 254 | + step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False) |
| 255 | + |
| 256 | + def body(_loop_var, _loop_state): |
| 257 | + _time = steps[_loop_var] |
| 258 | + step_size = steps[_loop_var + 1] - steps[_loop_var] |
| 259 | + |
| 260 | + _loop_state, _, _ = step_fn(_loop_state, _time, step_size) |
| 261 | + return _loop_state |
| 262 | + |
| 263 | + state = keras.ops.fori_loop(0, len(steps) - 1, body, state) |
| 264 | + return state |
| 265 | + |
| 266 | + |
236 | 267 | def integrate( |
237 | 268 | fn: Callable, |
238 | 269 | state: dict[str, ArrayLike], |
239 | | - start_time: ArrayLike, |
240 | | - stop_time: ArrayLike, |
| 270 | + start_time: ArrayLike | None = None, |
| 271 | + stop_time: ArrayLike | None = None, |
241 | 272 | min_steps: int = 10, |
242 | 273 | max_steps: int = 10_000, |
243 | | - steps: int | Literal["adaptive"] = 100, |
| 274 | + steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100, |
244 | 275 | method: str = "euler", |
245 | 276 | **kwargs, |
246 | 277 | ) -> dict[str, ArrayLike]: |
247 | | - match steps: |
248 | | - case "adaptive" | "dynamic": |
249 | | - return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) |
250 | | - case int(): |
251 | | - return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) |
252 | | - case _: |
253 | | - raise RuntimeError("Type or value of `steps` not understood.") |
| 278 | + if isinstance(steps, str) and steps in ["adaptive", "dynamic"]: |
| 279 | + if start_time is None or stop_time is None: |
| 280 | + raise ValueError( |
| 281 | + "Please provide start_time and stop_time for the integration, was " |
| 282 | + f"'start_time={start_time}', 'stop_time={stop_time}'." |
| 283 | + ) |
| 284 | + return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs) |
| 285 | + elif isinstance(steps, int): |
| 286 | + if start_time is None or stop_time is None: |
| 287 | + raise ValueError( |
| 288 | + "Please provide start_time and stop_time for the integration, was " |
| 289 | + f"'start_time={start_time}', 'stop_time={stop_time}'." |
| 290 | + ) |
| 291 | + return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs) |
| 292 | + elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps): |
| 293 | + return integrate_scheduled(fn, state, steps, method, **kwargs) |
| 294 | + else: |
| 295 | + raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})") |
0 commit comments