Skip to content

Commit 42958fb

Browse files
theo-brownTorax team
authored andcommitted
Add predictive L-H transition based on the Martin scaling.
* Add from_pedestal_model saturation model, which increases transport in the pedestal if the requested pedestal top values are exceeded, reducing the pedestal height. * Add Martin scaling formation model, which reduces the transport in the pedestal if the LH threshold is crossed, increasing the pedestal height. * Introduce ADAPTIVE_SOURCE and ADAPTIVE_TRANSPORT modes. If in ADAPTIVE transport mode, the formation and saturation models are used. * Add integration test: an ITER-like case with an L-H transition. PiperOrigin-RevId: 874140482
1 parent 3f83f89 commit 42958fb

22 files changed

+1276
-71
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Base class for pedestal formation models."""
16+
17+
import abc
18+
import dataclasses
19+
from torax._src import state
20+
from torax._src import static_dataclass
21+
from torax._src.config import runtime_params as runtime_params_lib
22+
from torax._src.geometry import geometry
23+
from torax._src.pedestal_model import pedestal_model_output
24+
from torax._src.sources import source_profiles as source_profiles_lib
25+
26+
27+
@dataclasses.dataclass(frozen=True, eq=False)
28+
class FormationModel(static_dataclass.StaticDataclass, abc.ABC):
29+
"""Base class for pedestal formation models."""
30+
31+
@abc.abstractmethod
32+
def __call__(
33+
self,
34+
runtime_params: runtime_params_lib.RuntimeParams,
35+
geo: geometry.Geometry,
36+
core_profiles: state.CoreProfiles,
37+
source_profiles: source_profiles_lib.SourceProfiles,
38+
) -> pedestal_model_output.TransportMultipliers:
39+
"""Calculates the transport decrease multipliers.
40+
41+
Args:
42+
runtime_params: Runtime parameters.
43+
geo: Geometry.
44+
core_profiles: Core profiles.
45+
source_profiles: Source profiles.
46+
47+
Returns:
48+
transport_decrease_multiplier: Factors to multiply transport coefficients
49+
by (<= 1.0).
50+
"""
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Martin scaling pedestal formation model."""
16+
17+
import dataclasses
18+
import jax
19+
import jax.numpy as jnp
20+
from torax._src import array_typing
21+
from torax._src import math_utils
22+
from torax._src import state
23+
from torax._src.config import runtime_params as runtime_params_lib
24+
from torax._src.geometry import geometry
25+
from torax._src.pedestal_model import pedestal_model_output
26+
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
27+
from torax._src.pedestal_model.formation import base
28+
from torax._src.physics import scaling_laws
29+
from torax._src.sources import source_profiles as source_profiles_lib
30+
31+
# pylint: disable=invalid-name
32+
33+
34+
@jax.tree_util.register_dataclass
35+
@dataclasses.dataclass(frozen=True)
36+
class MartinFormationRuntimeParams(
37+
pedestal_runtime_params_lib.FormationRuntimeParams
38+
):
39+
"""Runtime params for pedestal formation models."""
40+
41+
P_LH_prefactor: array_typing.FloatScalar = 1.0
42+
43+
44+
def _calculate_P_SOL_total(
45+
core_sources: source_profiles_lib.SourceProfiles,
46+
geo: geometry.Geometry,
47+
) -> jax.Array:
48+
"""Calculates the total power out of the separatrix."""
49+
P_SOL_e = sum(
50+
math_utils.volume_integration(source, geo)
51+
for source in core_sources.T_e.values()
52+
)
53+
P_SOL_i = sum(
54+
math_utils.volume_integration(source, geo)
55+
for source in core_sources.T_i.values()
56+
)
57+
# TODO(b/488318267): Missing dW/dt terms.
58+
return P_SOL_e + P_SOL_i
59+
60+
61+
@dataclasses.dataclass(frozen=True, eq=False)
62+
class MartinFormationModel(base.FormationModel):
63+
"""Pedestal formation based on P_SOL and P_LH, using Martin scaling."""
64+
65+
def __call__(
66+
self,
67+
runtime_params: runtime_params_lib.RuntimeParams,
68+
geo: geometry.Geometry,
69+
core_profiles: state.CoreProfiles,
70+
core_sources: source_profiles_lib.SourceProfiles,
71+
) -> pedestal_model_output.TransportMultipliers:
72+
"""Calculates the transport decrease multipliers using Martin scaling."""
73+
assert isinstance(
74+
runtime_params.pedestal.formation, MartinFormationRuntimeParams
75+
)
76+
77+
P_SOL_total = _calculate_P_SOL_total(core_sources, geo)
78+
_, _, P_LH, _ = scaling_laws.calculate_plh_scaling_factor(
79+
geo, core_profiles
80+
)
81+
rescaled_P_LH = P_LH * runtime_params.pedestal.formation.P_LH_prefactor
82+
83+
# Calculate transport_multiplier
84+
# If P_SOL > P_LH, multiplier tends to 0.0
85+
# If P_SOL < P_LH, multiplier tends to 1.0
86+
# TODO(b/323504363): Add hysteresis logic here later
87+
width = runtime_params.pedestal.formation.sigmoid_width
88+
exponent = runtime_params.pedestal.formation.sigmoid_exponent
89+
offset = runtime_params.pedestal.formation.sigmoid_offset
90+
normalized_deviation = (
91+
P_SOL_total - rescaled_P_LH
92+
) / rescaled_P_LH - offset
93+
transport_multiplier = 1 - jax.nn.sigmoid(normalized_deviation / width)
94+
transport_multiplier = transport_multiplier**exponent
95+
transport_multiplier = jnp.clip(
96+
transport_multiplier,
97+
min=runtime_params.pedestal.min_transport_multiplier,
98+
max=runtime_params.pedestal.max_transport_multiplier,
99+
)
100+
101+
return pedestal_model_output.TransportMultipliers(
102+
chi_e_multiplier=transport_multiplier,
103+
chi_i_multiplier=transport_multiplier,
104+
D_e_multiplier=transport_multiplier,
105+
v_e_multiplier=transport_multiplier,
106+
)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2026 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import dataclasses
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import jax.numpy as jnp
19+
import numpy as np
20+
from torax._src.orchestration import initial_state
21+
from torax._src.orchestration import run_simulation
22+
from torax._src.pedestal_model.formation import martin_formation_model
23+
from torax._src.test_utils import default_configs
24+
from torax._src.torax_pydantic import model_config
25+
26+
# pylint: disable=invalid-name
27+
28+
29+
class MartinFormationModelTest(parameterized.TestCase):
30+
31+
def setUp(self):
32+
super().setUp()
33+
config = default_configs.get_default_config_dict()
34+
# Switch to use the Martin formation model.
35+
config['pedestal'] = {
36+
'set_pedestal': True,
37+
'mode': 'ADAPTIVE_TRANSPORT',
38+
'formation_model': {'model_name': 'martin'},
39+
}
40+
# Add a source so that P_SOL is non-zero.
41+
config['sources'] = {
42+
'generic_heat': {
43+
'gaussian_location': 0.15,
44+
'gaussian_width': 0.1,
45+
'P_total': 20.0e6,
46+
'electron_heat_fraction': 0.8,
47+
}
48+
}
49+
self.torax_config = model_config.ToraxConfig.from_dict(config)
50+
step_fn = run_simulation.make_step_fn(self.torax_config)
51+
self.initial_state, self.initial_post_processed_outputs = (
52+
initial_state.get_initial_state_and_post_processed_outputs(step_fn)
53+
)
54+
self.runtime_params = step_fn.runtime_params_provider(t=0.0)
55+
56+
def test_calculate_P_SOL_total(self):
57+
P_SOL_total = martin_formation_model._calculate_P_SOL_total(
58+
self.initial_state.core_sources, self.initial_state.geometry
59+
)
60+
61+
np.testing.assert_allclose(
62+
P_SOL_total, self.initial_post_processed_outputs.P_SOL_total
63+
)
64+
65+
@parameterized.named_parameters(
66+
dict(
67+
# If P_sol >> P_LH, we expect the suppression multiplier to be very
68+
# small (significant suppression). However, it's clipped internally to
69+
# be 0.1.
70+
testcase_name='above_threshold',
71+
power=1e6,
72+
expected_multiplier=0.1,
73+
),
74+
dict(
75+
# If P_sol << P_LH, we expect the suppression multiplier to be 1.0
76+
# (no suppression).
77+
testcase_name='below_threshold',
78+
# We set the aux power to be negative (equivalent to a heat sink)
79+
# to make sure that P_sol < P_LH.
80+
power=-1e6,
81+
expected_multiplier=1.0,
82+
),
83+
)
84+
def test_martin_formation_model_suppression(self, power, expected_multiplier):
85+
formation_model = martin_formation_model.MartinFormationModel()
86+
87+
aux_power_profile = power * jnp.ones_like(self.initial_state.geometry.rho)
88+
high_power_profiles = dataclasses.replace(
89+
self.initial_state.core_sources,
90+
T_e={'aux': aux_power_profile},
91+
T_i={'aux': aux_power_profile},
92+
)
93+
94+
transport_multipliers = formation_model(
95+
self.runtime_params,
96+
self.initial_state.geometry,
97+
self.initial_state.core_profiles,
98+
high_power_profiles,
99+
)
100+
for k, multiplier in dataclasses.asdict(transport_multipliers).items():
101+
np.testing.assert_allclose(
102+
multiplier,
103+
expected_multiplier,
104+
atol=1e-3,
105+
err_msg=(
106+
f'{k}={multiplier} is not close to the expected value of'
107+
f' {expected_multiplier}.'
108+
),
109+
)
110+
111+
112+
if __name__ == '__main__':
113+
absltest.main()

torax/_src/pedestal_model/pedestal_model.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
from torax._src.config import runtime_params as runtime_params_lib
2828
from torax._src.geometry import geometry
2929
from torax._src.pedestal_model import pedestal_model_output
30+
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
31+
from torax._src.pedestal_model.formation import base as formation_base
32+
from torax._src.pedestal_model.saturation import base as saturation_base
3033
from torax._src.sources import source_profiles as source_profiles_lib
3134

3235
# pylint: disable=invalid-name
@@ -37,16 +40,48 @@
3740
class PedestalModel(static_dataclass.StaticDataclass, abc.ABC):
3841
"""Calculates temperature and density of the pedestal."""
3942

43+
formation_model: formation_base.FormationModel
44+
saturation_model: saturation_base.SaturationModel
45+
4046
def __call__(
4147
self,
4248
runtime_params: runtime_params_lib.RuntimeParams,
4349
geo: geometry.Geometry,
4450
core_profiles: state.CoreProfiles,
4551
source_profiles: source_profiles_lib.SourceProfiles,
4652
) -> pedestal_model_output.PedestalModelOutput:
53+
pedestal_output = self._call_implementation(
54+
runtime_params, geo, core_profiles
55+
)
56+
57+
# If in ADAPTIVE_TRANSPORT mode, calculate the transport multipliers based
58+
# on the formation and saturation models.
59+
if (
60+
runtime_params.pedestal.mode
61+
== pedestal_runtime_params_lib.Mode.ADAPTIVE_TRANSPORT
62+
):
63+
transport_decrease_multipliers = self.formation_model(
64+
runtime_params, geo, core_profiles, source_profiles
65+
)
66+
transport_increase_multipliers = self.saturation_model(
67+
runtime_params, geo, core_profiles, pedestal_output
68+
)
69+
70+
# Combine via exp(log) for numerical stability, as multipliers can
71+
# be very small or large.
72+
transport_multipliers = jax.tree.map(
73+
lambda x, y: jnp.exp(jnp.log(x) + jnp.log(y)),
74+
transport_decrease_multipliers,
75+
transport_increase_multipliers,
76+
)
77+
78+
pedestal_output = dataclasses.replace(
79+
pedestal_output, transport_multipliers=transport_multipliers
80+
)
81+
4782
return jax.lax.cond(
4883
runtime_params.pedestal.set_pedestal,
49-
lambda: self._call_implementation(runtime_params, geo, core_profiles),
84+
lambda: pedestal_output,
5085
lambda: pedestal_model_output.PedestalModelOutput(
5186
rho_norm_ped_top=jnp.inf,
5287
rho_norm_ped_top_idx=geo.torax_mesh.nx,

0 commit comments

Comments
 (0)