Skip to content

Commit 2d932e1

Browse files
Nush395Torax team
authored andcommitted
Add reverse-mode differentiable jax.lax.while_loop to jax_utils.
This will be useful for being able to reverse mode differentiate through the TORAX simulation. PiperOrigin-RevId: 794095925
1 parent 68d7481 commit 2d932e1

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

torax/_src/jax_utils.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import functools
1919
import inspect
2020
import os
21-
from typing import Any, Callable, Literal, TypeVar
21+
from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar
2222

2323
import chex
2424
import equinox as eqx
@@ -27,7 +27,8 @@
2727
import numpy as np
2828

2929
T = TypeVar('T')
30-
BooleanNumeric = Any # A bool, or a Boolean array.
30+
BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array.
31+
_State = ParamSpec('_State')
3132

3233

3334
@functools.cache
@@ -296,3 +297,46 @@ def init_array(x):
296297
return x
297298

298299
return jax.tree_util.tree_map(init_array, t)
300+
301+
302+
@functools.partial(jit, static_argnames=['cond_fun', 'body_fun', 'max_steps'])
303+
def max_steps_while_loop(
304+
cond_fun: Callable[[_State], BooleanNumeric],
305+
body_fun: Callable[[_State], _State],
306+
init_val: _State,
307+
max_steps: int,
308+
) -> _State:
309+
"""A reverse-mode differentiable while_loop using jax.lax.scan.
310+
311+
Args:
312+
cond_fun: As in jax.lax.while_loop.
313+
body_fun: As in jax.lax.while_loop.
314+
init_val: As in jax.lax.while_loop.
315+
max_steps: An integer, the maximum number of iterations the loop can
316+
perform. This is crucial for defining a fixed computational graph for
317+
scan.
318+
319+
Returns:
320+
The final state after the loop terminates or `max_steps` are reached.
321+
"""
322+
# Initial carry for the scan: (current_state, while_loop_condition_met)
323+
initial_scan_carry = (init_val, jnp.array(True, dtype=jnp.bool_))
324+
325+
def scan_body(carry, _):
326+
current_state, cond_met_prev = carry
327+
328+
cond_eval = cond_fun(current_state)
329+
should_execute_body = jnp.logical_and(cond_met_prev, cond_eval)
330+
331+
next_state = jax.lax.cond(
332+
should_execute_body, body_fun, lambda s: s, current_state
333+
)
334+
next_cond_met = should_execute_body
335+
336+
return (next_state, next_cond_met), None
337+
338+
dummy_xs = jnp.arange(max_steps)
339+
340+
(final_state, _), _ = jax.lax.scan(scan_body, initial_scan_carry, dummy_xs)
341+
342+
return final_state

torax/_src/tests/jax_utils_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,58 @@ def f(x, z, y=2.0):
150150
x = {'temp1': jnp.array(1.3), 'temp2': jnp.array(2.6)}
151151
chex.assert_trees_all_close(f_non_inlined(x, z='left'), f(x, z='left'))
152152

153+
def test_max_steps_while_loop(self):
154+
terminating_step = 4
155+
156+
def cond_fun(state):
157+
i, _ = state
158+
return i < terminating_step
159+
160+
def body_fun(state):
161+
i, value = state
162+
next_i = i + 1
163+
next_value = jnp.sin(value)
164+
return next_i, next_value
165+
166+
init_state = (0, 0.5)
167+
max_steps = 10
168+
169+
with self.subTest('forward_agrees_with_while_loop'):
170+
output_state = jax_utils.max_steps_while_loop(
171+
cond_fun, body_fun, init_state, max_steps
172+
)
173+
chex.assert_trees_all_close(
174+
output_state, jax.lax.while_loop(cond_fun, body_fun, init_state)
175+
)
176+
177+
def f_while(x, max_steps=4):
178+
init_state = (0, x)
179+
return jax_utils.max_steps_while_loop(
180+
cond_fun, body_fun, init_state, max_steps=max_steps
181+
)[1]
182+
183+
def f(x, n_times=terminating_step):
184+
result = x
185+
for _ in range(n_times):
186+
result = jnp.sin(result)
187+
return result
188+
189+
with self.subTest('forward_agrees_with_explicit'):
190+
chex.assert_trees_all_close(f_while(0.5), f(0.5))
191+
with self.subTest('grad_agrees_with_explicit'):
192+
chex.assert_trees_all_close(jax.grad(f_while)(0.5), jax.grad(f)(0.5))
193+
194+
with self.subTest('max_steps_is_respected'):
195+
final_i, final_value = jax_utils.max_steps_while_loop(
196+
cond_fun, body_fun, init_state, max_steps=2
197+
)
198+
self.assertEqual(final_i, 2)
199+
chex.assert_trees_all_close(final_value, f(0.5, n_times=2))
200+
chex.assert_trees_all_close(
201+
jax.grad(f_while)(0.5, max_steps=2), jax.grad(f)(0.5, n_times=2)
202+
)
203+
chex.assert_trees_all_close(f_while(0.5, max_steps=3), f(0.5, n_times=3))
204+
153205

154206
if __name__ == '__main__':
155207
absltest.main()

0 commit comments

Comments
 (0)