|
18 | 18 | import functools |
19 | 19 | import inspect |
20 | 20 | import os |
21 | | -from typing import Any, Callable, Literal, TypeVar |
| 21 | +from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar |
22 | 22 |
|
23 | 23 | import chex |
24 | 24 | import equinox as eqx |
|
27 | 27 | import numpy as np |
28 | 28 |
|
29 | 29 | T = TypeVar('T') |
30 | | -BooleanNumeric = Any # A bool, or a Boolean array. |
| 30 | +BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array. |
| 31 | +_State = ParamSpec('_State') |
31 | 32 |
|
32 | 33 |
|
33 | 34 | @functools.cache |
@@ -296,3 +297,55 @@ def init_array(x): |
296 | 297 | return x |
297 | 298 |
|
298 | 299 | return jax.tree_util.tree_map(init_array, t) |
| 300 | + |
| 301 | + |
| 302 | +@functools.partial( |
| 303 | + jit, static_argnames=['cond_fun', 'body_fun', 'max_steps', 'scan_unroll'] |
| 304 | +) |
| 305 | +def while_loop_bounded( |
| 306 | + cond_fun: Callable[[_State], BooleanNumeric], |
| 307 | + body_fun: Callable[[_State], _State], |
| 308 | + init_val: _State, |
| 309 | + max_steps: int, |
| 310 | + scan_unroll: int = 1, |
| 311 | +) -> _State: |
| 312 | + """A reverse-mode differentiable while_loop. |
| 313 | +
|
| 314 | + This makes use of jax.lax.scan and `max_steps` to define a fixed size |
| 315 | + computational graph. The body_fun is called the same number of times it would |
| 316 | + be under a jax.lax.while_loop i.e. until `cond_fun` returns False (unless the |
| 317 | + `max_steps` is reached). |
| 318 | +
|
| 319 | + Args: |
| 320 | + cond_fun: As in jax.lax.while_loop. |
| 321 | + body_fun: As in jax.lax.while_loop. |
| 322 | + init_val: As in jax.lax.while_loop. |
| 323 | + max_steps: An integer, the maximum number of iterations the loop can |
| 324 | + perform. This is crucial for defining a fixed computational graph for |
| 325 | + scan. |
| 326 | + scan_unroll: The number of iterations to unroll the internal scan by. |
| 327 | +
|
| 328 | + Returns: |
| 329 | + The final state after `cond_fun` returns `False` or `max_steps` are reached. |
| 330 | + """ |
| 331 | + # Initial carry for the scan: (current_state, while_loop_condition_met) |
| 332 | + initial_scan_carry = (init_val, jnp.array(True, dtype=jnp.bool_)) |
| 333 | + |
| 334 | + def scan_body(carry, _): |
| 335 | + current_state, cond_prev = carry |
| 336 | + # Only execute cond if the previous cond was True. |
| 337 | + should_execute_body = jax.lax.cond( |
| 338 | + cond_prev, cond_fun, lambda _: False, current_state |
| 339 | + ) |
| 340 | + # If the `while_loop` would have terminated, we no-op. |
| 341 | + next_state = jax.lax.cond( |
| 342 | + should_execute_body, body_fun, lambda s: s, current_state |
| 343 | + ) |
| 344 | + |
| 345 | + return (next_state, should_execute_body), None |
| 346 | + |
| 347 | + (final_state, _), _ = jax.lax.scan( |
| 348 | + scan_body, initial_scan_carry, length=max_steps, unroll=scan_unroll |
| 349 | + ) |
| 350 | + |
| 351 | + return final_state |
0 commit comments