Skip to content

Commit d1c6857

Browse files
Nush395Torax team
authored andcommitted
Add documentation for experimental compile mode.
PiperOrigin-RevId: 794238903
1 parent 8c4ae9e commit d1c6857

File tree

5 files changed

+123
-4
lines changed

5 files changed

+123
-4
lines changed

docs/running.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ internal TORAX functions. Used for debugging.
7777
7878
export TORAX_COMPILATION_ENABLED=<True/False>
7979
80+
``EXPERIMENTAL_COMPILE`` (default: `False`) - If `True`, trigger expanded
81+
compilation up to the TORAX step function's call method. Normally with
82+
`TORAX_COMPILATION_ENABLED` flag set to `True` TORAX will compile a more limited
83+
set of functions. This is an experimental flag at the moment and will lead to
84+
faster simulations (at the expense of larger compilation time). We are currently
85+
working on reducing compile times so that the behaviour of this flag is
86+
triggered by `TORAX_COMPILATION_ENABLED`.
87+
88+
.. code-block:: console
89+
90+
export EXPERIMENTAL_COMPILE=<True/False>
91+
8092
.. _torax_flags:
8193

8294
run_torax flags

torax/_src/jax_utils.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import functools
1919
import inspect
2020
import os
21-
from typing import Any, Callable, Literal, TypeVar
21+
from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar
2222

2323
import chex
2424
import equinox as eqx
@@ -27,7 +27,8 @@
2727
import numpy as np
2828

2929
T = TypeVar('T')
30-
BooleanNumeric = Any # A bool, or a Boolean array.
30+
BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array.
31+
_State = ParamSpec('_State')
3132

3233

3334
@functools.cache
@@ -296,3 +297,55 @@ def init_array(x):
296297
return x
297298

298299
return jax.tree_util.tree_map(init_array, t)
300+
301+
302+
@functools.partial(
303+
jit, static_argnames=['cond_fun', 'body_fun', 'max_steps', 'scan_unroll']
304+
)
305+
def while_loop_bounded(
306+
cond_fun: Callable[[_State], BooleanNumeric],
307+
body_fun: Callable[[_State], _State],
308+
init_val: _State,
309+
max_steps: int,
310+
scan_unroll: int = 1,
311+
) -> _State:
312+
"""A reverse-mode differentiable while_loop.
313+
314+
This makes use of jax.lax.scan to define a fixed size computational graph.
315+
The body_fun is called the same number of times it would be under a
316+
jax.lax.while_loop i.e. until `cond_fun` returns False (unless the `max_steps`
317+
is reached).
318+
319+
Args:
320+
cond_fun: As in jax.lax.while_loop.
321+
body_fun: As in jax.lax.while_loop.
322+
init_val: As in jax.lax.while_loop.
323+
max_steps: An integer, the maximum number of iterations the loop can
324+
perform. This is crucial for defining a fixed computational graph for
325+
scan.
326+
scan_unroll: The number of iterations to unroll the internal scan by.
327+
328+
Returns:
329+
The final state after `cond_fun` returns `False` or `max_steps` are reached.
330+
"""
331+
# Initial carry for the scan: (current_state, while_loop_condition_met)
332+
initial_scan_carry = (init_val, jnp.array(True, dtype=jnp.bool_))
333+
334+
def scan_body(carry, _):
335+
current_state, cond_met_prev = carry
336+
337+
cond_eval = cond_fun(current_state)
338+
should_execute_body = jnp.logical_and(cond_met_prev, cond_eval)
339+
340+
next_state = jax.lax.cond(
341+
should_execute_body, body_fun, lambda s: s, current_state
342+
)
343+
next_cond_met = should_execute_body
344+
345+
return (next_state, next_cond_met), None
346+
347+
(final_state, _), _ = jax.lax.scan(
348+
scan_body, initial_scan_carry, length=max_steps, unroll=scan_unroll
349+
)
350+
351+
return final_state

torax/_src/solver/pydantic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC):
5555
torax_pydantic.UnitInterval, torax_pydantic.JAX_STATIC
5656
] = 1.0
5757
use_predictor_corrector: Annotated[bool, torax_pydantic.JAX_STATIC] = False
58-
n_corrector_steps: pydantic.PositiveInt = 10
58+
n_corrector_steps: Annotated[
59+
pydantic.PositiveInt, torax_pydantic.JAX_STATIC
60+
] = 10
5961
convection_dirichlet_mode: Annotated[
6062
Literal['ghost', 'direct', 'semi-implicit'], torax_pydantic.JAX_STATIC
6163
] = 'ghost'

torax/_src/solver/runtime_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ class DynamicRuntimeParams:
3636

3737
chi_pereverzev: float
3838
D_pereverzev: float # pylint: disable=invalid-name
39-
n_corrector_steps: int
39+
n_corrector_steps: int = dataclasses.field(metadata={'static': True})

torax/_src/tests/jax_utils_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,58 @@ def f(x, z, y=2.0):
150150
x = {'temp1': jnp.array(1.3), 'temp2': jnp.array(2.6)}
151151
chex.assert_trees_all_close(f_non_inlined(x, z='left'), f(x, z='left'))
152152

153+
def test_max_steps_while_loop(self):
154+
terminating_step = 4
155+
156+
def cond_fun(state):
157+
i, _ = state
158+
return i < terminating_step
159+
160+
def body_fun(state):
161+
i, value = state
162+
next_i = i + 1
163+
next_value = jnp.sin(value)
164+
return next_i, next_value
165+
166+
init_state = (0, 0.5)
167+
max_steps = 10
168+
169+
with self.subTest('forward_agrees_with_while_loop'):
170+
output_state = jax_utils.while_loop_bounded(
171+
cond_fun, body_fun, init_state, max_steps
172+
)
173+
chex.assert_trees_all_close(
174+
output_state, jax.lax.while_loop(cond_fun, body_fun, init_state)
175+
)
176+
177+
def f_while(x, max_steps=4):
178+
init_state = (0, x)
179+
return jax_utils.while_loop_bounded(
180+
cond_fun, body_fun, init_state, max_steps=max_steps
181+
)[1]
182+
183+
def f(x, n_times=terminating_step):
184+
result = x
185+
for _ in range(n_times):
186+
result = jnp.sin(result)
187+
return result
188+
189+
with self.subTest('forward_agrees_with_explicit'):
190+
chex.assert_trees_all_close(f_while(0.5), f(0.5))
191+
with self.subTest('grad_agrees_with_explicit'):
192+
chex.assert_trees_all_close(jax.grad(f_while)(0.5), jax.grad(f)(0.5))
193+
194+
with self.subTest('max_steps_is_respected'):
195+
final_i, final_value = jax_utils.while_loop_bounded(
196+
cond_fun, body_fun, init_state, max_steps=2
197+
)
198+
self.assertEqual(final_i, 2)
199+
chex.assert_trees_all_close(final_value, f(0.5, n_times=2))
200+
chex.assert_trees_all_close(
201+
jax.grad(f_while)(0.5, max_steps=2), jax.grad(f)(0.5, n_times=2)
202+
)
203+
chex.assert_trees_all_close(f_while(0.5, max_steps=3), f(0.5, n_times=3))
204+
153205

154206
if __name__ == '__main__':
155207
absltest.main()

0 commit comments

Comments
 (0)