Skip to content

Commit f68547e

Browse files
tamaranormanTorax team
authored andcommitted
Improve the interpolated_param logic to be more flexible around tracers versus arrays etc
PiperOrigin-RevId: 853700330
1 parent f3a791e commit f68547e

File tree

6 files changed

+170
-75
lines changed

6 files changed

+170
-75
lines changed

torax/_src/orchestration/step_function.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,6 @@ def check_for_errors(
142142
post_processed_outputs: post_processing.PostProcessedOutputs,
143143
) -> state.SimError:
144144
"""Checks for errors in the simulation state."""
145-
if self._runtime_params_provider.numerics.adaptive_dt:
146-
if output_state.solver_numeric_outputs.solver_error_state == 1:
147-
# Only check for min dt if the solver did not converge. Else we may have
148-
# converged at a dt > min_dt just before we reach min_dt.
149-
if (
150-
output_state.dt
151-
/ self._runtime_params_provider.numerics.dt_reduction_factor
152-
< self._runtime_params_provider.numerics.min_dt
153-
):
154-
return state.SimError.REACHED_MIN_DT
155145

156146
# Low-temperature collapse check
157147
if output_state.core_profiles.below_minimum_temperature(
@@ -162,8 +152,20 @@ def check_for_errors(
162152
state_error = output_state.check_for_errors()
163153
if state_error != state.SimError.NO_ERROR:
164154
return state_error
165-
else:
166-
return post_processed_outputs.check_for_errors()
155+
156+
post_processed_error = post_processed_outputs.check_for_errors()
157+
if post_processed_error != state.SimError.NO_ERROR:
158+
return post_processed_error
159+
160+
# Check if reached the minimum time step last - this is often caused by
161+
# other errors so check those first to give more informative error messages.
162+
if self._runtime_params_provider.numerics.adaptive_dt:
163+
if output_state.solver_numeric_outputs.solver_error_state == 1:
164+
# If using adaptive stepping and the solver did not converge we must
165+
# have reached the minimum time step, so we can exit the simulation.
166+
return state.SimError.REACHED_MIN_DT
167+
168+
return state.SimError.NO_ERROR
167169

168170
@jax.jit
169171
def __call__(
@@ -298,8 +300,13 @@ def fixed_time_step(
298300
remaining_dt = dt
299301

300302
def cond(args):
301-
remaining_dt, _, _ = args
302-
return remaining_dt > constants.CONSTANTS.eps
303+
remaining_dt, prev_state, _ = args
304+
if self.runtime_params_provider.numerics.adaptive_dt:
305+
exit_min_dt = prev_state.solver_numeric_outputs.solver_error_state == 1
306+
else:
307+
exit_min_dt = False
308+
return jnp.logical_and(
309+
remaining_dt > constants.CONSTANTS.eps, ~exit_min_dt)
303310

304311
def body(args):
305312
remaining_dt, prev_state, prev_post_processed = args
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from absl.testing import absltest
15+
from absl.testing import parameterized
16+
import jax
17+
import jax.test_util as jtu
18+
from torax._src.config import config_loader
19+
from torax._src.orchestration import run_simulation
20+
from torax._src.torax_pydantic import interpolated_param_1d
21+
22+
23+
class StepFunctionTest(parameterized.TestCase):
24+
25+
@parameterized.parameters([
26+
'basic_config',
27+
'iterhybrid_predictor_corrector',
28+
])
29+
def test_step_function_grad(self, config_name_no_py):
30+
example_config_paths = config_loader.example_config_paths()
31+
example_config_path = example_config_paths[config_name_no_py]
32+
cfg = config_loader.build_torax_config_from_file(example_config_path)
33+
(
34+
sim_state,
35+
post_processed_outputs,
36+
step_fn,
37+
) = run_simulation.prepare_simulation(cfg)
38+
params_provider = step_fn.runtime_params_provider
39+
input_value = params_provider.profile_conditions.Ip.value
40+
41+
@jax.jit
42+
def f(override_value):
43+
ip_update = interpolated_param_1d.TimeVaryingScalarUpdate(
44+
value=override_value
45+
)
46+
runtime_params_overrides = params_provider.update_provider(
47+
lambda x: (x.profile_conditions.Ip,),
48+
(ip_update,),
49+
)
50+
_, new_post_processed_outputs = step_fn(
51+
sim_state,
52+
post_processed_outputs,
53+
runtime_params_overrides=runtime_params_overrides,
54+
)
55+
return new_post_processed_outputs.Q_fusion
56+
57+
jtu.check_grads(f, (input_value,), order=1, modes=('rev',))
58+
59+
60+
if __name__ == '__main__':
61+
absltest.main()

torax/_src/orchestration/tests/step_function_test.py

Lines changed: 29 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import dataclasses
15+
import os
1516
from typing import Any
1617

1718
from absl.testing import absltest
1819
from absl.testing import parameterized
1920
import chex
20-
import jax
2121
import jax.numpy as jnp
22-
import jax.test_util as jtu
2322
import numpy as np
2423
from torax._src import state
2524
from torax._src.config import config_loader
@@ -28,6 +27,7 @@
2827
from torax._src.orchestration import step_function
2928
from torax._src.output_tools import post_processing
3029
from torax._src.test_utils import default_configs
30+
from torax._src.test_utils import paths
3131
from torax._src.torax_pydantic import interpolated_param_1d
3232
from torax._src.torax_pydantic import model_config
3333

@@ -266,6 +266,24 @@ def test_adaptive_step_with_smaller_passed_max_dt(self):
266266
)
267267
self.assertTrue(np.less_equal(output_state.dt, passed_max_dt))
268268

269+
def test_fixed_step_with_high_density_errors_and_does_not_hang(self):
270+
# This test enforces that we exit the fixed step function early if we hit
271+
# min_dt. If we don't do this then we risk hanging for a very long time as
272+
# we stay at min_dt and the step never seems to make progress. This test
273+
# ensures that we don't hang and instead fail early.
274+
test_data_dir = paths.test_data_dir()
275+
torax_config = config_loader.build_torax_config_from_file(
276+
os.path.join(test_data_dir, 'test_iterhybrid_radiation_collapse.py')
277+
)
278+
sim_state, post_processed_outputs, step_fn = (
279+
run_simulation.prepare_simulation(torax_config)
280+
)
281+
sim_state, post_processed_outputs = step_fn.fixed_time_step(
282+
np.array(1.), sim_state, post_processed_outputs)
283+
284+
sim_error = step_fn.check_for_errors(sim_state, post_processed_outputs)
285+
self.assertEqual(sim_error, state.SimError.NAN_DETECTED)
286+
269287
def test_call_with_sawtooth_solver_smoke_test(self):
270288
"""Smoke test for the boolean logic around the sawtooth solver.
271289
@@ -338,49 +356,11 @@ def test_fixed_time_step_t_less_than_min_dt(self):
338356
)
339357
np.testing.assert_allclose(output_state.dt, 0.01, atol=1e-7)
340358

341-
@parameterized.parameters([
342-
'basic_config',
343-
'iterhybrid_predictor_corrector',
344-
])
345-
def test_step_function_grad(self, config_name_no_py):
346-
example_config_paths = config_loader.example_config_paths()
347-
example_config_path = example_config_paths[config_name_no_py]
348-
cfg = config_loader.build_torax_config_from_file(example_config_path)
349-
(
350-
sim_state,
351-
post_processed_outputs,
352-
step_fn,
353-
) = run_simulation.prepare_simulation(cfg)
354-
params_provider = step_fn.runtime_params_provider
355-
input_value = params_provider.profile_conditions.Ip.value
356-
357-
@jax.jit
358-
def f(override_value):
359-
ip_update = interpolated_param_1d.TimeVaryingScalarUpdate(
360-
value=override_value
361-
)
362-
runtime_params_overrides = params_provider.update_provider(
363-
lambda x: (x.profile_conditions.Ip,),
364-
(ip_update,),
365-
)
366-
_, new_post_processed_outputs = step_fn(
367-
sim_state,
368-
post_processed_outputs,
369-
runtime_params_overrides=runtime_params_overrides,
370-
)
371-
return new_post_processed_outputs.Q_fusion
372-
373-
jtu.check_grads(f, (input_value,), order=1, modes=('rev',))
374-
375-
@parameterized.parameters([
376-
'iterhybrid_predictor_corrector',
377-
'iterhybrid_rampup',
378-
])
379-
def test_step_function_overrides(self, config_name_no_py):
380-
example_config_paths = config_loader.example_config_paths()
381-
example_config_path = example_config_paths[config_name_no_py]
382-
raw_config = config_loader.import_module(example_config_path)['CONFIG']
383-
cfg = config_loader.build_torax_config_from_file(example_config_path)
359+
def test_step_function_overrides(self):
360+
original_ip = 15e6
361+
config_dict = default_configs.get_default_config_dict()
362+
config_dict['profile_conditions']['Ip'] = original_ip
363+
cfg = model_config.ToraxConfig.from_dict(config_dict)
384364
(
385365
sim_state,
386366
post_processed_outputs,
@@ -403,10 +383,7 @@ def test_step_function_overrides(self, config_name_no_py):
403383
)
404384

405385
# Update the config itself and re-run the step.
406-
doubled_ip = jax.tree_util.tree_map(
407-
lambda x: x * 2.0, raw_config['profile_conditions']['Ip']
408-
)
409-
cfg.update_fields({'profile_conditions.Ip': doubled_ip})
386+
cfg.update_fields({'profile_conditions.Ip': original_ip * 2.0})
410387
step_fn = run_simulation.make_step_fn(cfg)
411388
ref_state, ref_post_processed_outputs = step_fn(
412389
# Use original state and post-processed outputs as the initial value.
@@ -419,13 +396,9 @@ def test_step_function_overrides(self, config_name_no_py):
419396
override_post_processed_outputs, ref_post_processed_outputs
420397
)
421398

422-
@parameterized.parameters([
423-
('iterhybrid_rampup',),
424-
])
425-
def test_step_function_geo_overrides(self, config_name_no_py):
426-
example_config_paths = config_loader.example_config_paths()
427-
example_config_path = example_config_paths[config_name_no_py]
428-
cfg = config_loader.build_torax_config_from_file(example_config_path)
399+
def test_step_function_geo_overrides(self):
400+
config_dict = default_configs.get_default_config_dict()
401+
cfg = model_config.ToraxConfig.from_dict(config_dict)
429402
(
430403
sim_state,
431404
post_processed_outputs,

torax/_src/torax_pydantic/interpolated_param_2d.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,22 @@ class TimeVaryingArrayUpdate:
8383

8484
def __post_init__(self):
8585
"""Consistency checks for the provided values."""
86-
if not isinstance(self.value, type(self.rho_norm)):
86+
if (self.rho_norm is None and self.value is not None) or (
87+
self.rho_norm is not None and self.value is None
88+
):
8789
raise ValueError(
88-
'If rho_norm is provided, value must also be provided. Got value:'
89-
f' {type(self.value)}, rho_norm: {type(self.rho_norm)}'
90+
'Either both or neither of rho_norm and value must be provided.'
9091
)
92+
9193
if self.rho_norm is not None and self.value is not None:
9294
rho_norm_shape = self.rho_norm.shape
9395
if rho_norm_shape[0] != self.value.shape[1]:
9496
raise ValueError(
95-
'rho_norm and value must have the same shape. Got rho_norm shape:'
96-
f' {rho_norm_shape} and value shape: {self.value.shape}'
97+
'rho_norm and value must have the same trailing dimension. '
98+
f'Got rho_norm shape: {rho_norm_shape} and value shape: '
99+
f'{self.value.shape}'
97100
)
101+
98102
if self.value is not None and self.time is not None:
99103
if self.value.shape[0] != self.time.shape[0]:
100104
raise ValueError(
@@ -588,6 +592,7 @@ def _get_face_centers(nx: int, dx: float) -> np.ndarray:
588592
def _get_cell_centers(nx: int, dx: float) -> np.ndarray:
589593
return np.linspace(dx * 0.5, (nx - 0.5) * dx, nx)
590594

595+
591596
NonNegativeTimeVaryingArray: TypeAlias = typing_extensions.Annotated[
592597
TimeVaryingArray, pydantic.AfterValidator(_is_non_negative)
593598
]

torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,51 @@ def f(
470470
interpolated.get_value(t=0.0, grid_type='face_right'),
471471
)
472472

473+
def test_time_varying_array_update_validations_value_only(self):
474+
with self.assertRaisesRegex(
475+
ValueError,
476+
'Either both or neither of rho_norm and value must be provided.',
477+
):
478+
interpolated_param_2d.TimeVaryingArrayUpdate(
479+
value=np.array([[1.0]]), rho_norm=None
480+
)
481+
482+
def test_time_varying_array_update_validations_rhonorm_only(self):
483+
with self.assertRaisesRegex(
484+
ValueError,
485+
'Either both or neither of rho_norm and value must be provided.',
486+
):
487+
interpolated_param_2d.TimeVaryingArrayUpdate(
488+
value=None, rho_norm=np.array([1.0])
489+
)
490+
491+
def test_time_varying_array_update_validations_shape_mismatch(self):
492+
with self.assertRaisesRegex(
493+
ValueError,
494+
'rho_norm and value must have the same trailing dimension.',
495+
):
496+
interpolated_param_2d.TimeVaryingArrayUpdate(
497+
value=np.array([[1.0, 2.0], [3.0, 4.0]]), rho_norm=np.array([1.0])
498+
)
499+
500+
def test_time_varying_array_update_validations_time_dimension_mismatch(self):
501+
with self.assertRaisesRegex(
502+
ValueError,
503+
'value and time arrays must have same leading dimension.',
504+
):
505+
interpolated_param_2d.TimeVaryingArrayUpdate(
506+
value=np.array([[1.0, 2.0], [3.0, 4.0]]),
507+
rho_norm=np.array([0.0, 1.0]),
508+
time=np.array([0.0]),
509+
)
510+
511+
def test_allowed_mix_of_numpy_and_jax_arrays_for_update(self):
512+
interpolated_param_2d.TimeVaryingArrayUpdate(
513+
value=jnp.array([[1.0, 2.0], [3.0, 4.0]]),
514+
rho_norm=np.array([0.0, 1.0]),
515+
time=np.array([0.0, 1.0]),
516+
)
517+
473518
@parameterized.named_parameters(
474519
dict(
475520
testcase_name='update_values',

torax/tests/test_data/test_iterhybrid_radiation_collapse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
'W': W_frac,
3232
}
3333
CONFIG['plasma_composition']['Z_eff'] = 3.0
34+
35+
# Remove QLKNN transport model to simplify step and avoid QLKNN load.
36+
CONFIG['transport'] = {}
37+
3438
CONFIG['sources']['impurity_radiation'] = {
3539
'model_name': 'mavrin_fit',
3640
}

0 commit comments

Comments
 (0)