Skip to content

Commit 5742f5e

Browse files
sbodensteinTorax team
authored andcommitted
Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.
The `jax` and `jaxlib` versions are incremented to `0.7` in order to support the new `--xla_backend_extra_options=xla_cpu_flatten_after_fusion` flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a `jax.jit` within a larger `jax.jit` as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise). Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times. This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time PiperOrigin-RevId: 782947706
1 parent 68d7481 commit 5742f5e

File tree

7 files changed

+97
-318
lines changed

7 files changed

+97
-318
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ dependencies = [
1818
"absl-py>=2.0.0",
1919
"typing_extensions>=4.2.0",
2020
"immutabledict>=1.0.0",
21-
"jax>=0.4.32",
22-
"jaxlib>=0.4.32",
21+
"jax>=0.7.0",
22+
"jaxlib>=0.7.0",
2323
"jaxopt>=0.8.2",
2424
"flax>=0.10.0",
2525
"fusion_surrogates==0.1.0",
2626
"matplotlib>=3.3.0",
2727
"numpy>2",
28-
"setuptools;python_version>='3.10'",
28+
"setuptools;python_version>='3.11'",
2929
"chex>=0.1.88",
3030
"equinox>=0.11.3",
3131
"PyYAML>=6.0.1",

torax/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535

3636
# pylint: enable=g-importing-member
3737

38+
39+
os.environ['XLA_FLAGS'] = (
40+
os.environ.get('XLA_FLAGS', '')
41+
+ ' --xla_backend_extra_options=xla_cpu_flatten_after_fusion'
42+
)
43+
3844
__version__ = version.TORAX_VERSION
3945
__version_info__ = version.TORAX_VERSION_INFO
4046

torax/_src/fvm/jax_root_finding.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def root_newton_raphson(
4747
delta_reduction_factor: float = 0.5,
4848
tau_min: float = 0.01,
4949
log_iterations: bool = False,
50+
use_jax_custom_root: bool = True,
5051
) -> tuple[jax.Array, RootMetadata]:
5152
"""A differentiable Newton-Raphson root finder.
5253
@@ -66,6 +67,9 @@ def root_newton_raphson(
6667
routine resets at a lower timestep.
6768
log_iterations: If true, output diagnostic information from within iteration
6869
loop.
70+
use_jax_custom_root: If true, use jax.lax.custom_root to allow for
71+
differentiable solving. This can increase compile times even when no
72+
derivatives are requested.
6973
7074
Returns:
7175
A tuple `(x_root, RootMetadata(...))`.
@@ -113,13 +117,16 @@ def _newton_raphson(f, x):
113117
def back(g, y):
114118
return jnp.linalg.solve(jax.jacfwd(g)(y), y)
115119

116-
x_out, metadata = jax.lax.custom_root(
117-
f=fun,
118-
initial_guess=x0,
119-
solve=_newton_raphson,
120-
tangent_solve=back,
121-
has_aux=True,
122-
)
120+
if use_jax_custom_root:
121+
x_out, metadata = jax.lax.custom_root(
122+
f=fun,
123+
initial_guess=x0,
124+
solve=_newton_raphson,
125+
tangent_solve=back,
126+
has_aux=True,
127+
)
128+
else:
129+
x_out, metadata = _newton_raphson(fun, x0)
123130

124131
# Tell the caller whether or not x_new successfully reduces the residual below
125132
# the tolerance by providing an extra output, error.

0 commit comments

Comments
 (0)