Skip to content

Commit 62bb6f3

Browse files
Nush395Torax team
authored andcommitted
Make InterpolatedVarSingleAxis a PyTree.
This is to prepare for making the `GeometryProvider`s PyTrees. Moved the sorting checks in the two classes to the beginning of the container and make JIT compatible. Speed check on iterhybrid_rampup has no timing change with/without TORAX_ERRORS_ENABLED. PiperOrigin-RevId: 775675478
1 parent d22624c commit 62bb6f3

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ jobs:
7373
run: |
7474
pytest \
7575
-vv -n auto \
76+
--ignore=torax/tests/sim_experimental_compile_test.py \
77+
--ignore=torax/tests/sim_no_compile_test.py \
7678
--shard-id=$((${{ matrix.shard-id }} - 1)) --num-shards=${{ env.PYTEST_NUM_SHARDS }}
7779
# Two test require an extra environment variable, so we run them separately.
7880
- name: Run sim_experimental_compile

torax/_src/interpolated_param.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_value(self, x: chex.Numeric) -> chex.Array:
126126
"""Returns a value for this parameter interpolated at the given input."""
127127

128128

129-
class PiecewiseLinearInterpolatedParam(InterpolatedParamBase):
129+
class _PiecewiseLinearInterpolatedParam(InterpolatedParamBase):
130130
"""Parameter using piecewise-linear interpolation to compute its value."""
131131

132132
def __init__(self, xs: chex.Array, ys: chex.Array):
@@ -149,10 +149,6 @@ def __init__(self, xs: chex.Array, ys: chex.Array):
149149
if ys.ndim not in (1, 2):
150150
raise ValueError(f'ys must be either 1D or 2D. Given: {self.ys.shape}.')
151151

152-
xs_np = np.array(self.xs)
153-
if not np.array_equal(np.sort(xs_np), xs_np):
154-
raise RuntimeError('xs must be sorted.')
155-
156152
@property
157153
def xs(self) -> chex.Array:
158154
return self._xs
@@ -196,7 +192,7 @@ def get_value(
196192
raise ValueError(f'ys must be either 1D or 2D. Given: {self.ys.shape}.')
197193

198194

199-
class StepInterpolatedParam(InterpolatedParamBase):
195+
class _StepInterpolatedParam(InterpolatedParamBase):
200196
"""Parameter using step interpolation to compute its value."""
201197

202198
def __init__(self, xs: chex.Array, ys: chex.Array):
@@ -211,8 +207,6 @@ def __init__(self, xs: chex.Array, ys: chex.Array):
211207
'xs and ys must have the same number of elements in the first '
212208
f'dimension. Given: {self.xs.shape} and {self.ys.shape}.'
213209
)
214-
if not jnp.allclose(jnp.sort(self.xs), self.xs):
215-
raise RuntimeError('xs must be sorted.')
216210

217211
@property
218212
def xs(self) -> chex.Array:
@@ -325,6 +319,7 @@ def convert_input_to_xs_ys(
325319
)
326320

327321

322+
@jax.tree_util.register_pytree_node_class
328323
class InterpolatedVarSingleAxis(InterpolatedParamBase):
329324
"""Parameter that may vary based on an input coordinate.
330325
@@ -364,8 +359,12 @@ def __init__(
364359
interpolation_mode: Defines how to interpolate between values in `value`.
365360
is_bool_param: If True, the input value is assumed to be a bool and is
366361
converted to a float.
362+
Raises:
363+
RuntimeError: If the input xs is not sorted.
367364
"""
365+
self._value = value
368366
xs, ys = value
367+
jax_utils.error_if(xs, jnp.any(jnp.diff(xs) < 0), 'xs must be sorted.')
369368

370369
if not np.issubdtype(xs.dtype, np.floating):
371370
raise ValueError(f'xs must be a float array, but got {xs.dtype}.')
@@ -376,12 +375,23 @@ def __init__(
376375
self._interpolation_mode = interpolation_mode
377376
match interpolation_mode:
378377
case InterpolationMode.PIECEWISE_LINEAR:
379-
self._param = PiecewiseLinearInterpolatedParam(xs=xs, ys=ys)
378+
self._param = _PiecewiseLinearInterpolatedParam(xs=xs, ys=ys)
380379
case InterpolationMode.STEP:
381-
self._param = StepInterpolatedParam(xs=xs, ys=ys)
380+
self._param = _StepInterpolatedParam(xs=xs, ys=ys)
382381
case _:
383382
raise ValueError('Unknown interpolation mode.')
384383

384+
def tree_flatten(self):
385+
static_params = {
386+
'interpolation_mode': self.interpolation_mode,
387+
'is_bool_param': self.is_bool_param,
388+
}
389+
return (self._value, static_params)
390+
391+
@classmethod
392+
def tree_unflatten(cls, aux_data, children):
393+
return cls(children, **aux_data)
394+
385395
@property
386396
def is_bool_param(self) -> bool:
387397
"""Returns whether this param represents a bool."""

torax/_src/tests/interpolated_param_test.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from absl.testing import absltest
1818
from absl.testing import parameterized
19+
import chex
1920
import jax
2021
from jax import numpy as jnp
2122
import numpy as np
@@ -173,8 +174,8 @@ def test_multi_value_range_returns_expected_output(
173174
)
174175

175176
@parameterized.parameters(
176-
(interpolated_param.PiecewiseLinearInterpolatedParam,),
177-
(interpolated_param.StepInterpolatedParam,),
177+
(interpolated_param._PiecewiseLinearInterpolatedParam,),
178+
(interpolated_param._StepInterpolatedParam,),
178179
)
179180
def test_interpolated_param_1d_xs_and_1d_or_2d_ys(self, range_class):
180181
"""Tests that the interpolated_param only take 1D inputs."""
@@ -198,8 +199,8 @@ def test_interpolated_param_1d_xs_and_1d_or_2d_ys(self, range_class):
198199
)
199200

200201
@parameterized.parameters(
201-
(interpolated_param.PiecewiseLinearInterpolatedParam,),
202-
(interpolated_param.StepInterpolatedParam,),
202+
(interpolated_param._PiecewiseLinearInterpolatedParam,),
203+
(interpolated_param._StepInterpolatedParam,),
203204
)
204205
def test_interpolated_param_need_xs_ys_same_shape(self, range_class):
205206
"""Tests the xs and ys inputs have to have the same shape."""
@@ -214,19 +215,20 @@ def test_interpolated_param_need_xs_ys_same_shape(self, range_class):
214215
)
215216

216217
@parameterized.parameters(
217-
(interpolated_param.PiecewiseLinearInterpolatedParam,),
218-
(interpolated_param.StepInterpolatedParam,),
218+
interpolated_param.InterpolationMode.PIECEWISE_LINEAR,
219+
interpolated_param.InterpolationMode.STEP,
219220
)
220-
def test_interpolated_param_need_xs_to_be_sorted(self, range_class):
221-
"""Tests the xs inputs have to be sorted."""
222-
range_class(
223-
xs=jnp.array([1.0, 2.0, 3.0, 4.0]),
224-
ys=jnp.array([1.0, 2.0, 3.0, 4.0]),
225-
)
221+
def test_interpolated_param_is_invariant_to_xs_order(
222+
self,
223+
interpolation_mode: interpolated_param.InterpolationMode,
224+
):
226225
with self.assertRaises(RuntimeError):
227-
range_class(
228-
xs=jnp.array([4.0, 2.0, 1.0, 3.0]),
229-
ys=jnp.array([1.0, 2.0, 3.0, 4.0]),
226+
interpolated_param.InterpolatedVarSingleAxis(
227+
value=(
228+
jnp.array([4.0, 2.0, 1.0, 3.0]),
229+
jnp.array([4.0, 2.0, 1.0, 3.0]),
230+
),
231+
interpolation_mode=interpolation_mode,
230232
)
231233

232234
@parameterized.named_parameters(
@@ -493,19 +495,23 @@ def test_interpolated_param_get_value_is_jittable(
493495
interpolated_param.InterpolationMode.STEP,
494496
],
495497
)
496-
def test_interpolated_var_properties(
498+
def test_interpolated_param_is_usable_under_jit(
497499
self,
498500
is_bool: bool,
499501
interpolation_mode: interpolated_param.InterpolationMode,
500502
):
501-
"""Check the properties of the interpolated var are set correctly."""
502503
var = interpolated_param.InterpolatedVarSingleAxis(
503504
value=(np.array([0.0, 1.0]), np.array([0.0, 1.0])),
504505
is_bool_param=is_bool,
505506
interpolation_mode=interpolation_mode,
506507
)
507-
self.assertEqual(var.is_bool_param, is_bool)
508-
self.assertEqual(var.interpolation_mode, interpolation_mode)
508+
509+
def f(x: interpolated_param.InterpolatedVarSingleAxis, t: chex.Numeric):
510+
return x.get_value(x=t)
511+
512+
interpolated_output_jit = jax.jit(f)(var, 0.5)
513+
interpolated_output = f(var, 0.5)
514+
np.testing.assert_allclose(interpolated_output_jit, interpolated_output)
509515

510516

511517
if __name__ == '__main__':

0 commit comments

Comments
 (0)