Skip to content

Commit cfb5920

Browse files
Nush395 and Torax team
authored and committed
Add reverse-mode differentiable jax.lax.while_loop to jax_utils.
This will be useful for differentiating through the TORAX simulation in reverse mode. PiperOrigin-RevId: 796482656
1 parent 4711f67 commit cfb5920

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

torax/_src/jax_utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import functools
1919
import inspect
2020
import os
21-
from typing import Any, Callable, Literal, TypeVar
21+
from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar
2222

2323
import chex
2424
import equinox as eqx
@@ -27,7 +27,8 @@
2727
import numpy as np
2828

2929
T = TypeVar('T')
30-
BooleanNumeric = Any # A bool, or a Boolean array.
30+
BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array.
31+
_State = ParamSpec('_State')
3132

3233

3334
@functools.cache
@@ -296,3 +297,55 @@ def init_array(x):
296297
return x
297298

298299
return jax.tree_util.tree_map(init_array, t)
300+
301+
302+
@functools.partial(
    jit, static_argnames=['cond_fun', 'body_fun', 'max_steps', 'scan_unroll']
)
def while_loop_bounded(
    cond_fun: Callable[[_State], BooleanNumeric],
    body_fun: Callable[[_State], _State],
    init_val: _State,
    max_steps: int,
    scan_unroll: int = 1,
) -> _State:
  """Runs a while loop as a fixed-length scan so it can be reverse-differentiated.

  The loop is expressed as a `jax.lax.scan` over exactly `max_steps` steps.
  Each step applies `body_fun` only while `cond_fun` still holds; once
  `cond_fun` has returned False the state is passed through unchanged for the
  remaining steps. The result therefore matches `jax.lax.while_loop` whenever
  that loop would stop within `max_steps` iterations, and otherwise is the
  state after `max_steps` applications of `body_fun`.

  Args:
    cond_fun: Loop condition, as in `jax.lax.while_loop`.
    body_fun: Loop body, as in `jax.lax.while_loop`.
    init_val: Initial loop state, as in `jax.lax.while_loop`.
    max_steps: Upper bound on the number of iterations. It sets the length of
      the underlying scan and hence the size of the computational graph.
    scan_unroll: Forwarded to `jax.lax.scan` as its `unroll` argument.

  Returns:
    The state reached when `cond_fun` first returns False, or after
    `max_steps` iterations, whichever happens first.
  """

  def _already_stopped(_):
    # The loop has exited on an earlier step, so `cond_fun` is not evaluated.
    return False

  def _keep_state(state):
    # No-op branch used once the while loop would have terminated.
    return state

  def _step(carry, _):
    state, still_running = carry
    # Only query `cond_fun` if the loop was still live after the previous step.
    run_body = jax.lax.cond(still_running, cond_fun, _already_stopped, state)
    new_state = jax.lax.cond(run_body, body_fun, _keep_state, state)
    return (new_state, run_body), None

  # Scan carry is (loop state, flag that the loop condition still holds).
  first_carry = (init_val, jnp.array(True, dtype=jnp.bool_))
  last_carry, _ = jax.lax.scan(
      _step, first_carry, length=max_steps, unroll=scan_unroll
  )
  return last_carry[0]

torax/_src/tests/jax_utils_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,58 @@ def f(x, z, y=2.0):
150150
x = {'temp1': jnp.array(1.3), 'temp2': jnp.array(2.6)}
151151
chex.assert_trees_all_close(f_non_inlined(x, z='left'), f(x, z='left'))
152152

153+
def test_max_steps_while_loop(self):
  """Checks `while_loop_bounded` values and gradients against references.

  The references are `jax.lax.while_loop` and a plain Python loop that applies
  `jnp.sin` a fixed number of times.
  """
  stop_index = 4
  step_budget = 10
  start_state = (0, 0.5)

  def keep_going(state):
    counter, _ = state
    return counter < stop_index

  def advance(state):
    counter, value = state
    return counter + 1, jnp.sin(value)

  def bounded_value(x, max_steps=step_budget):
    # Value component of the final state produced by the bounded loop.
    final_state = jax_utils.while_loop_bounded(
        keep_going, advance, (0, x), max_steps=max_steps
    )
    return final_state[1]

  def unrolled_value(x, n_times=stop_index):
    # Explicit reference: apply sin `n_times` times.
    out = x
    for _ in range(n_times):
      out = jnp.sin(out)
    return out

  with self.subTest('forward_agrees_with_while_loop'):
    bounded_state = jax_utils.while_loop_bounded(
        keep_going, advance, start_state, step_budget
    )
    reference_state = jax.lax.while_loop(keep_going, advance, start_state)
    chex.assert_trees_all_close(bounded_state, reference_state)

  with self.subTest('forward_agrees_with_explicit'):
    chex.assert_trees_all_close(bounded_value(0.5), unrolled_value(0.5))

  with self.subTest('grad_agrees_with_explicit'):
    chex.assert_trees_all_close(
        jax.grad(bounded_value)(0.5), jax.grad(unrolled_value)(0.5)
    )

  with self.subTest('max_steps_is_respected'):
    # With a budget smaller than the stopping index the loop is cut short.
    last_index, last_value = jax_utils.while_loop_bounded(
        keep_going, advance, start_state, max_steps=2
    )
    self.assertEqual(last_index, 2)
    chex.assert_trees_all_close(last_value, unrolled_value(0.5, n_times=2))
    chex.assert_trees_all_close(
        jax.grad(bounded_value)(0.5, max_steps=2),
        jax.grad(unrolled_value)(0.5, n_times=2),
    )
    chex.assert_trees_all_close(
        bounded_value(0.5, max_steps=3), unrolled_value(0.5, n_times=3)
    )
204+
153205

154206
if __name__ == '__main__':
155207
absltest.main()

0 commit comments

Comments
 (0)