Skip to content

Commit 541fc63

Browse files
committed
Add low temperature collapse detection
Refactor numerics: rename minimum_temperature_eV to T_minimum_eV, update checks and tests minor:resolving conflict minor fix: updated LowtemperatureCollapseTest Fixing : low temperature collapse test Fixing : low temperature collapse test Fixing: Code duplication Remove __init__.py from orchestration/tests Add low-temperature collapse detection and error handling - Add T_minimum_eV config parameter (default 50 eV) to numerics - Implement low_temperature_below() method in CoreProfiles - Add LOW_TEMPERATURE_COLLAPSE error state to SimError enum - Refactor check_for_errors() as SimulationStepFn class method - Add mock-based test for low-temperature collapse detection - Remove empty __init__.py from orchestration/tests
1 parent 2aeaa48 commit 541fc63

File tree

5 files changed

+95
-3
lines changed

5 files changed

+95
-3
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,7 @@ _build
2929

3030
# venv
3131
venv
32+
.venv
33+
34+
# Python version file
35+
.python-version

torax/_src/config/numerics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dataclasses
1818
import functools
1919
from typing import Annotated
20+
from typing import Optional
2021

2122
import chex
2223
import jax
@@ -126,6 +127,9 @@ class Numerics(torax_pydantic.BaseModelFrozen):
126127
adaptive_T_source_prefactor: pydantic.PositiveFloat = 2.0e10
127128
adaptive_n_source_prefactor: pydantic.PositiveFloat = 2.0e8
128129

130+
# NEW: Minimum allowed physical temperature (in eV)
131+
T_minimum_eV: pydantic.PositiveFloat = 5.0
132+
129133
@pydantic.model_validator(mode='after')
130134
def model_validation(self) -> Self:
131135
if self.t_initial > self.t_final:

torax/_src/orchestration/step_function.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def solver(self) -> solver_lib.Solver:
127127
@property
128128
def time_step_calculator(self) -> ts.TimeStepCalculator:
129129
return self._time_step_calculator
130-
130+
131131
def is_done(self, t: jax.Array) -> bool | jax.Array:
132132
return self._time_step_calculator.is_done(
133133
t=t,
@@ -142,6 +142,7 @@ def check_for_errors(
142142
) -> state.SimError:
143143
"""Checks for errors in the simulation state."""
144144
if self._runtime_params_provider.numerics.adaptive_dt:
145+
145146
if output_state.solver_numeric_outputs.solver_error_state == 1:
146147
# Only check for min dt if the solver did not converge. Else we may have
147148
# converged at a dt > min_dt just before we reach min_dt.
@@ -151,6 +152,13 @@ def check_for_errors(
151152
< self._runtime_params_provider.numerics.min_dt
152153
):
153154
return state.SimError.REACHED_MIN_DT
155+
156+
# Low-temperature collapse check
157+
if output_state.core_profiles.low_temperature_below(
158+
self._runtime_params_provider.numerics.T_minimum_eV
159+
):
160+
return state.SimError.LOW_TEMPERATURE_COLLAPSE
161+
154162
state_error = output_state.check_for_errors()
155163
if state_error != state.SimError.NO_ERROR:
156164
return state_error
@@ -556,4 +564,4 @@ def _fixed_step(
556564
input_post_processed_outputs=previous_post_processed_outputs,
557565
)
558566
)
559-
return output_state, post_processed_outputs
567+
return output_state, post_processed_outputs

torax/_src/orchestration/tests/initial_state_test.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@
1717
from torax._src.config import build_runtime_params
1818
from torax._src.orchestration import initial_state
1919
from torax._src.orchestration import run_simulation
20+
from torax._src.orchestration import step_function
2021
from torax._src.output_tools import output
2122
from torax._src.test_utils import core_profile_helpers
2223
from torax._src.test_utils import sim_test_case
2324
from torax._src.torax_pydantic import model_config
25+
from torax._src.orchestration import sim_state
26+
from torax._src.output_tools import post_processing
27+
import jax.numpy as jnp
28+
from torax._src import state
29+
from torax._src.config import numerics as numerics_lib
30+
from torax._src.fvm import cell_variable
31+
from torax._src.geometry import circular_geometry
32+
from unittest import mock
2433

2534
# pylint: disable=invalid-name
2635

@@ -127,5 +136,51 @@ def test_core_profile_final_step(self, test_config):
127136
)
128137

129138

139+
class LowTemperatureCollapseTest(sim_test_case.SimTestCase):
140+
"""Tests for low temperature collapse error detection."""
141+
142+
def test_low_temperature_triggers_error(self):
143+
"""Test that temperatures below threshold trigger LOW_TEMPERATURE_COLLAPSE error."""
144+
145+
# Create mock CoreProfiles
146+
mock_core_profiles = mock.MagicMock()
147+
mock_core_profiles.low_temperature_below.return_value = True
148+
149+
# Create mock output_state
150+
mock_output_state = mock.MagicMock()
151+
mock_output_state.core_profiles = mock_core_profiles
152+
mock_output_state.solver_numeric_outputs.solver_error_state = 0
153+
mock_output_state.check_for_errors.return_value = state.SimError.NO_ERROR
154+
155+
# Create mock post_processed_outputs
156+
mock_post_processed_outputs = mock.MagicMock()
157+
mock_post_processed_outputs.check_for_errors.return_value = state.SimError.NO_ERROR
158+
159+
# Create a real step_fn from a test config
160+
torax_config = self._get_torax_config('test_iterhybrid_rampup.py')
161+
real_step_fn = run_simulation.make_step_fn(torax_config)
162+
163+
# Mock the numerics to have our test value
164+
mock_numerics = mock.MagicMock()
165+
mock_numerics.T_minimum_eV = 50.0
166+
mock_numerics.adaptive_dt = False
167+
168+
# Replace the runtime_params_provider's numerics
169+
original_provider = real_step_fn._runtime_params_provider
170+
mock_provider = mock.MagicMock()
171+
mock_provider.numerics = mock_numerics
172+
real_step_fn._runtime_params_provider = mock_provider
173+
174+
# Call check_for_errors
175+
error = real_step_fn.check_for_errors(mock_output_state, mock_post_processed_outputs)
176+
177+
# Restore original provider
178+
real_step_fn._runtime_params_provider = original_provider
179+
180+
# Assert that LOW_TEMPERATURE_COLLAPSE error is detected
181+
self.assertEqual(error, state.SimError.LOW_TEMPERATURE_COLLAPSE)
182+
# Verify that low_temperature_below was called with the right threshold
183+
mock_core_profiles.low_temperature_below.assert_called_once_with(50.0)
184+
130185
if __name__ == '__main__':
131-
absltest.main()
186+
absltest.main()

torax/_src/state.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,19 @@ def negative_temperature_or_density(self) -> jax.Array:
167167
])
168168
)
169169

170+
def low_temperature_below(self, T_minimum_eV: float) -> bool:
171+
"""Return True if T_e or T_i fall below the minimum temperature threshold."""
172+
# Convert eV → keV since internal storage is keV
173+
te_min_keV = T_minimum_eV / 1000.0
174+
175+
return np.any(
176+
np.array([
177+
np.any(np.less(self.T_e.value, te_min_keV)),
178+
np.any(np.less(self.T_i.value, te_min_keV)),
179+
])
180+
)
181+
182+
170183
def __str__(self) -> str:
171184
return f"""
172185
CoreProfiles(
@@ -292,6 +305,8 @@ class SimError(enum.Enum):
292305
QUASINEUTRALITY_BROKEN = 2
293306
NEGATIVE_CORE_PROFILES = 3
294307
REACHED_MIN_DT = 4
308+
LOW_TEMPERATURE_COLLAPSE = 5
309+
295310

296311
def log_error(self):
297312
match self:
@@ -320,6 +335,12 @@ def log_error(self):
320335
quasineutrality. Check the output file for near-zero temperatures or
321336
densities at the last valid step.
322337
""")
338+
case SimError.LOW_TEMPERATURE_COLLAPSE:
339+
logging.error("""
340+
Simulation stopped because electron temperature fell below the configured minimum
341+
threshold (Te_min). This is usually caused by radiative collapse or runaway
342+
cooling. Output file contains all profiles up to the last valid step
343+
""")
323344
case SimError.NO_ERROR:
324345
pass
325346
case _:

0 commit comments

Comments
 (0)