Skip to content

Commit 1ab5d23

Browse files
Nush395Torax team
authored andcommitted
Mark n_corrector_steps as static for allowing for backwards mode differentiation of predictor corrector.
jax.lax.fori_loop requires a static upper and lower to be backwards mode differentiable. PiperOrigin-RevId: 794067640
1 parent 68d7481 commit 1ab5d23

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

torax/_src/solver/pydantic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC):
5555
torax_pydantic.UnitInterval, torax_pydantic.JAX_STATIC
5656
] = 1.0
5757
use_predictor_corrector: Annotated[bool, torax_pydantic.JAX_STATIC] = False
58-
n_corrector_steps: pydantic.PositiveInt = 10
58+
n_corrector_steps: Annotated[
59+
pydantic.PositiveInt, torax_pydantic.JAX_STATIC
60+
] = 10
5961
convection_dirichlet_mode: Annotated[
6062
Literal['ghost', 'direct', 'semi-implicit'], torax_pydantic.JAX_STATIC
6163
] = 'ghost'

torax/_src/solver/runtime_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ class DynamicRuntimeParams:
3636

3737
chi_pereverzev: float
3838
D_pereverzev: float # pylint: disable=invalid-name
39-
n_corrector_steps: int
39+
n_corrector_steps: int = dataclasses.field(metadata={'static': True})

0 commit comments

Comments
 (0)