Skip to content

Commit a385fd2

Browse files
tamaranormanTorax team
authored andcommitted
Speeding up tests
- Avoid using newton-raphson when not needed - aka in update tests - Move slow gradient tests to their own file - Remove the transport model from the radiation collapse to streamline PiperOrigin-RevId: 853666131
1 parent f3a791e commit a385fd2

File tree

4 files changed

+122
-70
lines changed

4 files changed

+122
-70
lines changed

torax/_src/orchestration/step_function.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,6 @@ def check_for_errors(
142142
post_processed_outputs: post_processing.PostProcessedOutputs,
143143
) -> state.SimError:
144144
"""Checks for errors in the simulation state."""
145-
if self._runtime_params_provider.numerics.adaptive_dt:
146-
if output_state.solver_numeric_outputs.solver_error_state == 1:
147-
# Only check for min dt if the solver did not converge. Else we may have
148-
# converged at a dt > min_dt just before we reach min_dt.
149-
if (
150-
output_state.dt
151-
/ self._runtime_params_provider.numerics.dt_reduction_factor
152-
< self._runtime_params_provider.numerics.min_dt
153-
):
154-
return state.SimError.REACHED_MIN_DT
155145

156146
# Low-temperature collapse check
157147
if output_state.core_profiles.below_minimum_temperature(
@@ -162,8 +152,20 @@ def check_for_errors(
162152
state_error = output_state.check_for_errors()
163153
if state_error != state.SimError.NO_ERROR:
164154
return state_error
165-
else:
166-
return post_processed_outputs.check_for_errors()
155+
156+
post_processed_error = post_processed_outputs.check_for_errors()
157+
if post_processed_error != state.SimError.NO_ERROR:
158+
return post_processed_error
159+
160+
# Check if reached the minimum time step last - this is often caused by
161+
# other errors so check those first to give more informative error messages.
162+
if self._runtime_params_provider.numerics.adaptive_dt:
163+
if output_state.solver_numeric_outputs.solver_error_state == 1:
164+
# If using adaptive stepping and the solver did not converge we must
165+
# have reached the minimum time step, so we can exit the simulation.
166+
return state.SimError.REACHED_MIN_DT
167+
168+
return state.SimError.NO_ERROR
167169

168170
@jax.jit
169171
def __call__(
@@ -298,8 +300,13 @@ def fixed_time_step(
298300
remaining_dt = dt
299301

300302
def cond(args):
301-
remaining_dt, _, _ = args
302-
return remaining_dt > constants.CONSTANTS.eps
303+
remaining_dt, prev_state, _ = args
304+
if self.runtime_params_provider.numerics.adaptive_dt:
305+
exit_min_dt = prev_state.solver_numeric_outputs.solver_error_state == 1
306+
else:
307+
exit_min_dt = False
308+
return jnp.logical_and(
309+
remaining_dt > constants.CONSTANTS.eps, ~exit_min_dt)
303310

304311
def body(args):
305312
remaining_dt, prev_state, prev_post_processed = args
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from absl import logging
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
import jax
18+
import jax.test_util as jtu
19+
from torax._src.config import config_loader
20+
from torax._src.orchestration import run_simulation
21+
from torax._src.torax_pydantic import interpolated_param_1d
22+
23+
24+
class StepFunctionTest(parameterized.TestCase):
25+
26+
@parameterized.parameters([
27+
'basic_config',
28+
'iterhybrid_predictor_corrector',
29+
])
30+
def test_step_function_grad(self, config_name_no_py):
31+
example_config_paths = config_loader.example_config_paths()
32+
example_config_path = example_config_paths[config_name_no_py]
33+
logging.info('a')
34+
cfg = config_loader.build_torax_config_from_file(example_config_path)
35+
logging.info('b')
36+
(
37+
sim_state,
38+
post_processed_outputs,
39+
step_fn,
40+
) = run_simulation.prepare_simulation(cfg)
41+
logging.info('c')
42+
params_provider = step_fn.runtime_params_provider
43+
input_value = params_provider.profile_conditions.Ip.value
44+
logging.info('d')
45+
46+
@jax.jit
47+
def f(override_value):
48+
ip_update = interpolated_param_1d.TimeVaryingScalarUpdate(
49+
value=override_value
50+
)
51+
runtime_params_overrides = params_provider.update_provider(
52+
lambda x: (x.profile_conditions.Ip,),
53+
(ip_update,),
54+
)
55+
_, new_post_processed_outputs = step_fn(
56+
sim_state,
57+
post_processed_outputs,
58+
runtime_params_overrides=runtime_params_overrides,
59+
)
60+
return new_post_processed_outputs.Q_fusion
61+
62+
logging.info('e')
63+
jtu.check_grads(f, (input_value,), order=1, modes=('rev',))
64+
logging.info('f')
65+
66+
67+
if __name__ == '__main__':
68+
absltest.main()

torax/_src/orchestration/tests/step_function_test.py

Lines changed: 29 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import dataclasses
15+
import os
1516
from typing import Any
1617

1718
from absl.testing import absltest
1819
from absl.testing import parameterized
1920
import chex
20-
import jax
2121
import jax.numpy as jnp
22-
import jax.test_util as jtu
2322
import numpy as np
2423
from torax._src import state
2524
from torax._src.config import config_loader
@@ -28,6 +27,7 @@
2827
from torax._src.orchestration import step_function
2928
from torax._src.output_tools import post_processing
3029
from torax._src.test_utils import default_configs
30+
from torax._src.test_utils import paths
3131
from torax._src.torax_pydantic import interpolated_param_1d
3232
from torax._src.torax_pydantic import model_config
3333

@@ -266,6 +266,24 @@ def test_adaptive_step_with_smaller_passed_max_dt(self):
266266
)
267267
self.assertTrue(np.less_equal(output_state.dt, passed_max_dt))
268268

269+
def test_fixed_step_with_high_density_errors_and_does_not_hang(self):
270+
# This test enforces that we exit the fixed step function early if we hit
271+
# min_dt. If we don't do this then we risk hanging for a very long time as
272+
# we stay at min_dt and the step never seems to make progress. This test
273+
# ensures that we don't hang and instead fail early.
274+
test_data_dir = paths.test_data_dir()
275+
torax_config = config_loader.build_torax_config_from_file(
276+
os.path.join(test_data_dir, 'test_iterhybrid_radiation_collapse.py')
277+
)
278+
sim_state, post_processed_outputs, step_fn = (
279+
run_simulation.prepare_simulation(torax_config)
280+
)
281+
sim_state, post_processed_outputs = step_fn.fixed_time_step(
282+
np.array(1.), sim_state, post_processed_outputs)
283+
284+
sim_error = step_fn.check_for_errors(sim_state, post_processed_outputs)
285+
self.assertEqual(sim_error, state.SimError.NAN_DETECTED)
286+
269287
def test_call_with_sawtooth_solver_smoke_test(self):
270288
"""Smoke test for the boolean logic around the sawtooth solver.
271289
@@ -338,49 +356,11 @@ def test_fixed_time_step_t_less_than_min_dt(self):
338356
)
339357
np.testing.assert_allclose(output_state.dt, 0.01, atol=1e-7)
340358

341-
@parameterized.parameters([
342-
'basic_config',
343-
'iterhybrid_predictor_corrector',
344-
])
345-
def test_step_function_grad(self, config_name_no_py):
346-
example_config_paths = config_loader.example_config_paths()
347-
example_config_path = example_config_paths[config_name_no_py]
348-
cfg = config_loader.build_torax_config_from_file(example_config_path)
349-
(
350-
sim_state,
351-
post_processed_outputs,
352-
step_fn,
353-
) = run_simulation.prepare_simulation(cfg)
354-
params_provider = step_fn.runtime_params_provider
355-
input_value = params_provider.profile_conditions.Ip.value
356-
357-
@jax.jit
358-
def f(override_value):
359-
ip_update = interpolated_param_1d.TimeVaryingScalarUpdate(
360-
value=override_value
361-
)
362-
runtime_params_overrides = params_provider.update_provider(
363-
lambda x: (x.profile_conditions.Ip,),
364-
(ip_update,),
365-
)
366-
_, new_post_processed_outputs = step_fn(
367-
sim_state,
368-
post_processed_outputs,
369-
runtime_params_overrides=runtime_params_overrides,
370-
)
371-
return new_post_processed_outputs.Q_fusion
372-
373-
jtu.check_grads(f, (input_value,), order=1, modes=('rev',))
374-
375-
@parameterized.parameters([
376-
'iterhybrid_predictor_corrector',
377-
'iterhybrid_rampup',
378-
])
379-
def test_step_function_overrides(self, config_name_no_py):
380-
example_config_paths = config_loader.example_config_paths()
381-
example_config_path = example_config_paths[config_name_no_py]
382-
raw_config = config_loader.import_module(example_config_path)['CONFIG']
383-
cfg = config_loader.build_torax_config_from_file(example_config_path)
359+
def test_step_function_overrides(self):
360+
original_ip = 15e6
361+
config_dict = default_configs.get_default_config_dict()
362+
config_dict['profile_conditions']['Ip'] = original_ip
363+
cfg = model_config.ToraxConfig.from_dict(config_dict)
384364
(
385365
sim_state,
386366
post_processed_outputs,
@@ -403,10 +383,7 @@ def test_step_function_overrides(self, config_name_no_py):
403383
)
404384

405385
# Update the config itself and re-run the step.
406-
doubled_ip = jax.tree_util.tree_map(
407-
lambda x: x * 2.0, raw_config['profile_conditions']['Ip']
408-
)
409-
cfg.update_fields({'profile_conditions.Ip': doubled_ip})
386+
cfg.update_fields({'profile_conditions.Ip': original_ip * 2.0})
410387
step_fn = run_simulation.make_step_fn(cfg)
411388
ref_state, ref_post_processed_outputs = step_fn(
412389
# Use original state and post-processed outputs as the initial value.
@@ -419,13 +396,9 @@ def test_step_function_overrides(self, config_name_no_py):
419396
override_post_processed_outputs, ref_post_processed_outputs
420397
)
421398

422-
@parameterized.parameters([
423-
('iterhybrid_rampup',),
424-
])
425-
def test_step_function_geo_overrides(self, config_name_no_py):
426-
example_config_paths = config_loader.example_config_paths()
427-
example_config_path = example_config_paths[config_name_no_py]
428-
cfg = config_loader.build_torax_config_from_file(example_config_path)
399+
def test_step_function_geo_overrides(self):
400+
config_dict = default_configs.get_default_config_dict()
401+
cfg = model_config.ToraxConfig.from_dict(config_dict)
429402
(
430403
sim_state,
431404
post_processed_outputs,

torax/tests/test_data/test_iterhybrid_radiation_collapse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
'W': W_frac,
3232
}
3333
CONFIG['plasma_composition']['Z_eff'] = 3.0
34+
35+
# Remove QLKNN transport model to simplify step and avoid QLKNN load.
36+
CONFIG['transport'] = {}
37+
3438
CONFIG['sources']['impurity_radiation'] = {
3539
'model_name': 'mavrin_fit',
3640
}

0 commit comments

Comments
 (0)