Skip to content

Commit b9d6e86

Browse files
sbodensteinTorax team
authored andcommitted
Switch the newton_raphson_solve_block to use jax_root_finding.root_newton_raphson, and JIT the entire newton_raphson_solve_block.
The `jax` and `jaxlib` versions are incremented to `0.7` in order to support the new `--xla_backend_extra_options=xla_cpu_flatten_after_fusion` flag, and Python incremented to 3.11 to support jaxlib 0.7.0. This flag treats all functions wrapped with a `jax.jit` within a larger `jax.jit` as independently compilable subfunctions (ie. prevents compiler inlining). It is critical for preventing blowups of compile-times (XLA:CPU compile-times scale roughly quadratically with the number of ops otherwise). Use of jax.lax.custom_root is temporarily disabled until a workaround is found for it significantly increasing compile times. This increases simulation speed by around 20%, but at the cost of ~10% increases in compile-time PiperOrigin-RevId: 782947706
1 parent 68d7481 commit b9d6e86

File tree

12 files changed

+103
-324
lines changed

12 files changed

+103
-324
lines changed

.github/workflows/linting.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: ['3.10']
17+
python-version: ['3.11']
1818
steps:
1919
- uses: actions/checkout@v4
2020
- name: Set up Python ${{ matrix.python-version }}

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
# Can't reference env variables in matrix
5353
num-shards: ${{ fromJson(needs.shards-job.outputs.num-shards) }}
5454
shard-id: ${{ fromJson(needs.shards-job.outputs.shard-ids) }}
55-
python-version: ['3.10']
55+
python-version: ['3.11']
5656
os-version: [ubuntu-latest]
5757

5858
steps:

.readthedocs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ version: 2
1414
build:
1515
os: ubuntu-22.04
1616
tools:
17-
python: "3.10"
17+
python: "3.11"
1818
commands:
1919
- pip install -r docs/requirements.txt
2020
- cd docs && make html

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ is outlined in our readthedocs pages.
7070

7171
### Requirements
7272

73-
Install Python 3.10 or greater.
73+
Install Python 3.11 or greater.
7474

7575
Make sure that tkinter is installed:
7676

docs/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Installation Guide
88
Requirements
99
============
1010

11-
Install Python 3.10 or greater.
11+
Install Python 3.11 or greater.
1212

1313
Make sure that tkinter is installed:
1414

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
66
name = "torax"
77
description = "Differentiable 1D tokamak plasma transport simulator in JAX."
88
readme = "README.md"
9-
requires-python = ">=3.10"
9+
requires-python = ">=3.11"
1010
license = "Apache-2.0"
1111
license-files = ["LICENSE"]
1212
authors = [{name = "TORAX authors", email="torax-dev@google.com"}]
@@ -18,14 +18,14 @@ dependencies = [
1818
"absl-py>=2.0.0",
1919
"typing_extensions>=4.2.0",
2020
"immutabledict>=1.0.0",
21-
"jax>=0.4.32",
22-
"jaxlib>=0.4.32",
21+
"jax>=0.7.0",
22+
"jaxlib>=0.7.0",
2323
"jaxopt>=0.8.2",
2424
"flax>=0.10.0",
2525
"fusion_surrogates==0.1.0",
2626
"matplotlib>=3.3.0",
2727
"numpy>2",
28-
"setuptools;python_version>='3.10'",
28+
"setuptools;python_version>='3.11'",
2929
"chex>=0.1.88",
3030
"equinox>=0.11.3",
3131
"PyYAML>=6.0.1",

torax/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535

3636
# pylint: enable=g-importing-member
3737

38+
39+
os.environ['XLA_FLAGS'] = (
40+
os.environ.get('XLA_FLAGS', '')
41+
+ ' --xla_backend_extra_options=xla_cpu_flatten_after_fusion'
42+
)
43+
3844
__version__ = version.TORAX_VERSION
3945
__version_info__ = version.TORAX_VERSION_INFO
4046

torax/_src/fvm/jax_root_finding.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def root_newton_raphson(
4747
delta_reduction_factor: float = 0.5,
4848
tau_min: float = 0.01,
4949
log_iterations: bool = False,
50+
use_jax_custom_root: bool = True,
5051
) -> tuple[jax.Array, RootMetadata]:
5152
"""A differentiable Newton-Raphson root finder.
5253
@@ -66,6 +67,9 @@ def root_newton_raphson(
6667
routine resets at a lower timestep.
6768
log_iterations: If true, output diagnostic information from within iteration
6869
loop.
70+
use_jax_custom_root: If true, use jax.lax.custom_root to allow for
71+
differentiable solving. This can increase compile times even when no
72+
derivatives are requested.
6973
7074
Returns:
7175
A tuple `(x_root, RootMetadata(...))`.
@@ -113,13 +117,16 @@ def _newton_raphson(f, x):
113117
def back(g, y):
114118
return jnp.linalg.solve(jax.jacfwd(g)(y), y)
115119

116-
x_out, metadata = jax.lax.custom_root(
117-
f=fun,
118-
initial_guess=x0,
119-
solve=_newton_raphson,
120-
tangent_solve=back,
121-
has_aux=True,
122-
)
120+
if use_jax_custom_root:
121+
x_out, metadata = jax.lax.custom_root(
122+
f=fun,
123+
initial_guess=x0,
124+
solve=_newton_raphson,
125+
tangent_solve=back,
126+
has_aux=True,
127+
)
128+
else:
129+
x_out, metadata = _newton_raphson(fun, x0)
123130

124131
# Tell the caller whether or not x_new successfully reduces the residual below
125132
# the tolerance by providing an extra output, error.

0 commit comments

Comments
 (0)