Skip to content

Commit f0a5f58

Browse files
tamaranorman and Torax team
authored and committed
Exit fixed_step_fn if min_dt is reached during adaptive stepping, to avoid hanging
If a min_dt was reached, computations could hang, as too many steps were performed and the end state was never reached. Other errors, such as NaN etc., generally cause a MIN_DT on the next step, so this is a bit of a catch-all for all the errors. Check for the other errors first, as MIN_DT is the least informative and can often be caused by other errors. PiperOrigin-RevId: 852834152
1 parent f3a791e commit f0a5f58

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

torax/_src/orchestration/step_function.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,6 @@ def check_for_errors(
142142
post_processed_outputs: post_processing.PostProcessedOutputs,
143143
) -> state.SimError:
144144
"""Checks for errors in the simulation state."""
145-
if self._runtime_params_provider.numerics.adaptive_dt:
146-
if output_state.solver_numeric_outputs.solver_error_state == 1:
147-
# Only check for min dt if the solver did not converge. Else we may have
148-
# converged at a dt > min_dt just before we reach min_dt.
149-
if (
150-
output_state.dt
151-
/ self._runtime_params_provider.numerics.dt_reduction_factor
152-
< self._runtime_params_provider.numerics.min_dt
153-
):
154-
return state.SimError.REACHED_MIN_DT
155145

156146
# Low-temperature collapse check
157147
if output_state.core_profiles.below_minimum_temperature(
@@ -162,8 +152,20 @@ def check_for_errors(
162152
state_error = output_state.check_for_errors()
163153
if state_error != state.SimError.NO_ERROR:
164154
return state_error
165-
else:
166-
return post_processed_outputs.check_for_errors()
155+
156+
post_processed_error = post_processed_outputs.check_for_errors()
157+
if post_processed_error != state.SimError.NO_ERROR:
158+
return post_processed_error
159+
160+
# Check if reached the minimum time step last - this is often caused by
161+
# other errors so check those first to give more informative error messages.
162+
if self._runtime_params_provider.numerics.adaptive_dt:
163+
if output_state.solver_numeric_outputs.solver_error_state == 1:
164+
# If using adaptive stepping and the solver did not converge we must
165+
# have reached the minimum time step, so we can exit the simulation.
166+
return state.SimError.REACHED_MIN_DT
167+
168+
return state.SimError.NO_ERROR
167169

168170
@jax.jit
169171
def __call__(
@@ -298,8 +300,13 @@ def fixed_time_step(
298300
remaining_dt = dt
299301

300302
def cond(args):
301-
remaining_dt, _, _ = args
302-
return remaining_dt > constants.CONSTANTS.eps
303+
remaining_dt, prev_state, _ = args
304+
if self.runtime_params_provider.numerics.adaptive_dt:
305+
exit_min_dt = prev_state.solver_numeric_outputs.solver_error_state == 1
306+
else:
307+
exit_min_dt = False
308+
return jnp.logical_and(
309+
remaining_dt > constants.CONSTANTS.eps, ~exit_min_dt)
303310

304311
def body(args):
305312
remaining_dt, prev_state, prev_post_processed = args

torax/_src/orchestration/tests/step_function_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import dataclasses
15+
import os
1516
from typing import Any
1617

1718
from absl.testing import absltest
@@ -28,6 +29,7 @@
2829
from torax._src.orchestration import step_function
2930
from torax._src.output_tools import post_processing
3031
from torax._src.test_utils import default_configs
32+
from torax._src.test_utils import paths
3133
from torax._src.torax_pydantic import interpolated_param_1d
3234
from torax._src.torax_pydantic import model_config
3335

@@ -266,6 +268,20 @@ def test_adaptive_step_with_smaller_passed_max_dt(self):
266268
)
267269
self.assertTrue(np.less_equal(output_state.dt, passed_max_dt))
268270

271+
def test_fixed_step_with_high_density_errors_and_does_not_hang(self):
272+
test_data_dir = paths.test_data_dir()
273+
config_dict = config_loader.build_torax_config_from_file(
274+
os.path.join(test_data_dir, 'test_iterhybrid_radiation_collapse.py')
275+
)
276+
sim_state, post_processed_outputs, step_fn = (
277+
run_simulation.prepare_simulation(config_dict)
278+
)
279+
sim_state, post_processed_outputs = step_fn.fixed_time_step(
280+
np.array(1.), sim_state, post_processed_outputs)
281+
282+
sim_error = step_fn.check_for_errors(sim_state, post_processed_outputs)
283+
self.assertEqual(sim_error, state.SimError.NAN_DETECTED)
284+
269285
def test_call_with_sawtooth_solver_smoke_test(self):
270286
"""Smoke test for the boolean logic around the sawtooth solver.
271287

0 commit comments

Comments
 (0)