1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import dataclasses
15+ import os
1516from typing import Any
1617
1718from absl .testing import absltest
1819from absl .testing import parameterized
1920import chex
20- import jax
2121import jax .numpy as jnp
22- import jax .test_util as jtu
2322import numpy as np
2423from torax ._src import state
2524from torax ._src .config import config_loader
2827from torax ._src .orchestration import step_function
2928from torax ._src .output_tools import post_processing
3029from torax ._src .test_utils import default_configs
30+ from torax ._src .test_utils import paths
3131from torax ._src .torax_pydantic import interpolated_param_1d
3232from torax ._src .torax_pydantic import model_config
3333
@@ -266,6 +266,24 @@ def test_adaptive_step_with_smaller_passed_max_dt(self):
266266 )
267267 self .assertTrue (np .less_equal (output_state .dt , passed_max_dt ))
268268
269+ def test_fixed_step_with_high_density_errors_and_does_not_hang (self ):
270+ # This test enforces that we exit the fixed step function early if we hit
271+ # min_dt. If we don't do this then we risk hanging for a very long time as
272+ # we stay at min_dt and the step never seems to make progress. This test
273+ # ensures that we don't hang and instead fail early.
274+ test_data_dir = paths .test_data_dir ()
275+ torax_config = config_loader .build_torax_config_from_file (
276+ os .path .join (test_data_dir , 'test_iterhybrid_radiation_collapse.py' )
277+ )
278+ sim_state , post_processed_outputs , step_fn = (
279+ run_simulation .prepare_simulation (torax_config )
280+ )
281+ sim_state , post_processed_outputs = step_fn .fixed_time_step (
282+ np .array (1. ), sim_state , post_processed_outputs )
283+
284+ sim_error = step_fn .check_for_errors (sim_state , post_processed_outputs )
285+ self .assertEqual (sim_error , state .SimError .NAN_DETECTED )
286+
269287 def test_call_with_sawtooth_solver_smoke_test (self ):
270288 """Smoke test for the boolean logic around the sawtooth solver.
271289
@@ -338,49 +356,11 @@ def test_fixed_time_step_t_less_than_min_dt(self):
338356 )
339357 np .testing .assert_allclose (output_state .dt , 0.01 , atol = 1e-7 )
340358
341- @parameterized .parameters ([
342- 'basic_config' ,
343- 'iterhybrid_predictor_corrector' ,
344- ])
345- def test_step_function_grad (self , config_name_no_py ):
346- example_config_paths = config_loader .example_config_paths ()
347- example_config_path = example_config_paths [config_name_no_py ]
348- cfg = config_loader .build_torax_config_from_file (example_config_path )
349- (
350- sim_state ,
351- post_processed_outputs ,
352- step_fn ,
353- ) = run_simulation .prepare_simulation (cfg )
354- params_provider = step_fn .runtime_params_provider
355- input_value = params_provider .profile_conditions .Ip .value
356-
357- @jax .jit
358- def f (override_value ):
359- ip_update = interpolated_param_1d .TimeVaryingScalarUpdate (
360- value = override_value
361- )
362- runtime_params_overrides = params_provider .update_provider (
363- lambda x : (x .profile_conditions .Ip ,),
364- (ip_update ,),
365- )
366- _ , new_post_processed_outputs = step_fn (
367- sim_state ,
368- post_processed_outputs ,
369- runtime_params_overrides = runtime_params_overrides ,
370- )
371- return new_post_processed_outputs .Q_fusion
372-
373- jtu .check_grads (f , (input_value ,), order = 1 , modes = ('rev' ,))
374-
375- @parameterized .parameters ([
376- 'iterhybrid_predictor_corrector' ,
377- 'iterhybrid_rampup' ,
378- ])
379- def test_step_function_overrides (self , config_name_no_py ):
380- example_config_paths = config_loader .example_config_paths ()
381- example_config_path = example_config_paths [config_name_no_py ]
382- raw_config = config_loader .import_module (example_config_path )['CONFIG' ]
383- cfg = config_loader .build_torax_config_from_file (example_config_path )
359+ def test_step_function_overrides (self ):
360+ original_ip = 15e6
361+ config_dict = default_configs .get_default_config_dict ()
362+ config_dict ['profile_conditions' ]['Ip' ] = original_ip
363+ cfg = model_config .ToraxConfig .from_dict (config_dict )
384364 (
385365 sim_state ,
386366 post_processed_outputs ,
@@ -403,10 +383,7 @@ def test_step_function_overrides(self, config_name_no_py):
403383 )
404384
405385 # Update the config itself and re-run the step.
406- doubled_ip = jax .tree_util .tree_map (
407- lambda x : x * 2.0 , raw_config ['profile_conditions' ]['Ip' ]
408- )
409- cfg .update_fields ({'profile_conditions.Ip' : doubled_ip })
386+ cfg .update_fields ({'profile_conditions.Ip' : original_ip * 2.0 })
410387 step_fn = run_simulation .make_step_fn (cfg )
411388 ref_state , ref_post_processed_outputs = step_fn (
412389 # Use original state and post-processed outputs as the initial value.
@@ -419,13 +396,9 @@ def test_step_function_overrides(self, config_name_no_py):
419396 override_post_processed_outputs , ref_post_processed_outputs
420397 )
421398
422- @parameterized .parameters ([
423- ('iterhybrid_rampup' ,),
424- ])
425- def test_step_function_geo_overrides (self , config_name_no_py ):
426- example_config_paths = config_loader .example_config_paths ()
427- example_config_path = example_config_paths [config_name_no_py ]
428- cfg = config_loader .build_torax_config_from_file (example_config_path )
399+ def test_step_function_geo_overrides (self ):
400+ config_dict = default_configs .get_default_config_dict ()
401+ cfg = model_config .ToraxConfig .from_dict (config_dict )
429402 (
430403 sim_state ,
431404 post_processed_outputs ,
0 commit comments