Skip to content

Commit cc9ccb7

Browse files
theo-brownTorax team
authored andcommitted
Improve numerics of LH transition test scenario.
Main changes: - Switch from chi to fixed timestep calculator - Add smoothing to chi profiles - Remove resistivity multiplier, which was causing numerical instability - Switch to COMBINED transport model for clarity PiperOrigin-RevId: 878555971
1 parent 23d0c01 commit cc9ccb7

14 files changed

+270
-215
lines changed

torax/_src/pedestal_model/formation/martin_formation_model.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import dataclasses
1818
import jax
19-
import jax.numpy as jnp
2019
from torax._src import array_typing
2120
from torax._src import math_utils
2221
from torax._src import state
@@ -86,19 +85,13 @@ def __call__(
8685
# If P_SOL > P_LH, multiplier tends to 0.0
8786
# If P_SOL < P_LH, multiplier tends to 1.0
8887
# TODO(b/488393318): Add hysteresis to the LH-HL transition.
89-
width = runtime_params.pedestal.formation.sigmoid_width
90-
exponent = runtime_params.pedestal.formation.sigmoid_exponent
91-
offset = runtime_params.pedestal.formation.sigmoid_offset
92-
normalized_deviation = (
93-
P_SOL_total - rescaled_P_LH
94-
) / rescaled_P_LH - offset
95-
transport_multiplier = 1 - jax.nn.sigmoid(normalized_deviation / width)
96-
transport_multiplier = transport_multiplier**exponent
97-
transport_multiplier = jnp.clip(
98-
transport_multiplier,
99-
min=runtime_params.pedestal.min_transport_multiplier,
100-
max=runtime_params.pedestal.max_transport_multiplier,
101-
)
88+
sharpness = runtime_params.pedestal.formation.sharpness
89+
offset = runtime_params.pedestal.formation.offset
90+
base_multiplier = runtime_params.pedestal.formation.base_multiplier
91+
normalized_deviation = (P_SOL_total - rescaled_P_LH) / rescaled_P_LH
92+
shifted_deviation = normalized_deviation - offset
93+
alpha = jax.nn.sigmoid(shifted_deviation * sharpness)
94+
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier
10295

10396
return pedestal_model_output.TransportMultipliers(
10497
chi_e_multiplier=transport_multiplier,

torax/_src/pedestal_model/pydantic_model.py

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,35 @@ class MartinFormation(torax_pydantic.BaseModelFrozen):
3838
"""Configuration for Martin formation model.
3939
4040
This formation model triggers a reduction in pedestal transport when P_SOL >
41-
P_LH, where P_LH is from the Martin scaling. The reduction is a smooth sigmoid
42-
function of the ratio P_SOL / (P_LH * P_LH_prefactor).
41+
P_LH, where P_LH is from the Martin scaling. The reduction is a multiplicative
42+
factors between 1.0 and base_multiplier.
4343
4444
The formula is
45-
rescaled_P_LH = P_LH * P_LH_prefactor
46-
normalized_deviation = (P_SOL - rescaled_P_LH) / rescaled_P_LH - offset
47-
transport_multiplier = 1 - sigmoid(normalized_deviation / width)
48-
transport_multiplier = transport_multiplier**exponent
49-
50-
The transport multiplier is later clipped to the range
51-
[min_transport_multiplier, max_transport_multiplier].
45+
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier
46+
where alpha is a smooth sigmoid function of
47+
(P_SOL - P_LH * P_LH_prefactor) / (P_LH * P_LH_prefactor)
48+
with given sharpness and offset.
5249
5350
Attributes:
54-
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
55-
formation window. Increase for a smoother L-H transition, but doing so may
56-
lead to starting the L-H transition at a power below P_LH.
57-
sigmoid_offset: Dimensionless offset of sigmoid function from P_LH / P_SOL =
58-
1. Increase to start the L-H transition at a higher P_SOL / P_LH ratio.
59-
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
60-
faster reduction in transport once the L-H transition starts.
51+
sharpness: The steepness of the smooth formation window. Decrease for a
52+
smoother formation, which may be more numerically stable but may lead to
53+
starting formation at a temperature or density below the target values.
54+
offset: Dimensionless offset of the formation window. Increase to start
55+
formation at a higher P_SOL.
56+
base_multiplier: The base value of the transport multiplier. Increase for
57+
stronger decreases in transport once formation starts.
6158
P_LH_prefactor: Dimensionless multiplier for P_LH. Increase to scale up
6259
P_LH, and therefore start the L-H transition at a higher P_SOL.
6360
"""
6461

6562
model_name: Annotated[Literal["martin"], torax_pydantic.JAX_STATIC] = "martin"
66-
sigmoid_width: pydantic.PositiveFloat = 1e-3
67-
sigmoid_offset: Annotated[
63+
sharpness: pydantic.PositiveFloat = 1e3
64+
offset: Annotated[
6865
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
6966
] = 0.0
70-
sigmoid_exponent: pydantic.PositiveFloat = 3.0
67+
base_multiplier: Annotated[
68+
array_typing.FloatScalar, pydantic.Field(gt=0.0, le=1.0)
69+
] = 0.1
7170
P_LH_prefactor: pydantic.PositiveFloat = 1.0
7271

7372
def build_formation_model(
@@ -80,9 +79,9 @@ def build_runtime_params(
8079
) -> martin_formation_model.MartinFormationRuntimeParams:
8180
del t
8281
return martin_formation_model.MartinFormationRuntimeParams(
83-
sigmoid_width=self.sigmoid_width,
84-
sigmoid_offset=self.sigmoid_offset,
85-
sigmoid_exponent=self.sigmoid_exponent,
82+
sharpness=self.sharpness,
83+
offset=self.offset,
84+
base_multiplier=self.base_multiplier,
8685
P_LH_prefactor=self.P_LH_prefactor,
8786
)
8887

@@ -92,36 +91,34 @@ class ProfileValueSaturation(torax_pydantic.BaseModelFrozen):
9291
9392
This saturation model triggers an increase in pedestal transport when the
9493
pedestal temperature and density are above the values requested by the
95-
pedestal model. The increase is a smooth sigmoid function of the ratio of the
94+
pedestal model. The increase is a smooth linear function of the ratio of the
9695
current value to the value requested by the pedestal model.
9796
9897
The formula is
99-
normalized_deviation = (current - target) / target - offset
100-
transport_multiplier = 1 / (1 - sigmoid(normalized_deviation / width))
101-
transport_multiplier = transport_multiplier**exponent
102-
103-
The transport multiplier is then clipped to the range
104-
[min_transport_multiplier, max_transport_multiplier].
98+
transport_multiplier = 1 + base_multiplier * softplus(actual, target)
99+
where softplus(actual, target) is a smooth ReLU function of the deviation
100+
from the target value, with given sharpness and offset.
105101
106102
Attributes:
107-
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
108-
saturation window. Increase for a smoother saturation, but doing so may
109-
lead to starting saturation at a temperature or density below the target
110-
values.
111-
sigmoid_offset: Dimensionless offset of the saturation window. Increase to
112-
start saturation at a higher temperature or density.
113-
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
114-
faster increase in transport once saturation starts.
103+
sharpness: The steepness of the smooth saturation function. Decrease for a a
104+
smoother saturation, which may be more numerically stable but may lead to
105+
starting saturation at a temperature or density below the target values.
106+
offset: Dimensionless offset of the saturation window. Increase to start
107+
saturation at a higher temperature or density.
108+
base_multiplier: The base value of the transport multiplier. Increase for
109+
stronger increases in transport once saturation starts.
115110
"""
116111

117112
model_name: Annotated[Literal["profile_value"], torax_pydantic.JAX_STATIC] = (
118113
"profile_value"
119114
)
120-
sigmoid_width: pydantic.PositiveFloat = 0.1
121-
sigmoid_offset: Annotated[
115+
sharpness: pydantic.PositiveFloat = 1e3
116+
offset: Annotated[
122117
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
123118
] = 0.0
124-
sigmoid_exponent: pydantic.PositiveFloat = 1.0
119+
base_multiplier: Annotated[
120+
array_typing.FloatScalar, pydantic.Field(gt=1.0)
121+
] = 10.0
125122

126123
def build_saturation_model(
127124
self,
@@ -133,9 +130,9 @@ def build_runtime_params(
133130
) -> runtime_params.SaturationRuntimeParams:
134131
del t
135132
return runtime_params.SaturationRuntimeParams(
136-
sigmoid_width=self.sigmoid_width,
137-
sigmoid_offset=self.sigmoid_offset,
138-
sigmoid_exponent=self.sigmoid_exponent,
133+
sharpness=self.sharpness,
134+
offset=self.offset,
135+
base_multiplier=self.base_multiplier,
139136
)
140137

141138

@@ -170,10 +167,10 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
170167
ProfileValueSaturation()
171168
)
172169
max_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
173-
torax_pydantic.ValidatedDefault(10.0)
170+
torax_pydantic.ValidatedDefault(1e3)
174171
)
175172
min_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
176-
torax_pydantic.ValidatedDefault(0.1)
173+
torax_pydantic.ValidatedDefault(1e-3)
177174
)
178175

179176
@pydantic.model_validator(mode="before")

torax/_src/pedestal_model/runtime_params.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,19 @@ class Mode(enum.Enum):
3737
class FormationRuntimeParams:
3838
"""Runtime params for pedestal formation models."""
3939

40-
sigmoid_width: array_typing.FloatScalar = 0.1
41-
sigmoid_offset: array_typing.FloatScalar = 0.0
42-
sigmoid_exponent: array_typing.FloatScalar = 1.0
40+
sharpness: array_typing.FloatScalar
41+
offset: array_typing.FloatScalar
42+
base_multiplier: array_typing.FloatScalar
4343

4444

4545
@jax.tree_util.register_dataclass
4646
@dataclasses.dataclass(frozen=True)
4747
class SaturationRuntimeParams:
4848
"""Runtime params for pedestal saturation models."""
4949

50-
sigmoid_width: array_typing.FloatScalar = 0.1
51-
sigmoid_offset: array_typing.FloatScalar = 0.0
52-
sigmoid_exponent: array_typing.FloatScalar = 1.0
50+
sharpness: array_typing.FloatScalar
51+
offset: array_typing.FloatScalar
52+
base_multiplier: array_typing.FloatScalar
5353

5454

5555
@jax.tree_util.register_dataclass

torax/_src/pedestal_model/saturation/profile_value_saturation_model.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,11 @@ def _calculate_multiplier(
9191
Returns:
9292
The transport increase multiplier.
9393
"""
94-
width = pedestal_runtime_params.saturation.sigmoid_width
95-
exponent = pedestal_runtime_params.saturation.sigmoid_exponent
96-
offset = pedestal_runtime_params.saturation.sigmoid_offset
94+
sharpness = pedestal_runtime_params.saturation.sharpness
95+
offset = pedestal_runtime_params.saturation.offset
96+
base_multiplier = pedestal_runtime_params.saturation.base_multiplier
9797
normalized_deviation = (current - target) / target - offset
98-
transport_multiplier = 1 / (
99-
1 - jax.nn.sigmoid(normalized_deviation / width)
98+
transport_multiplier = 1 + base_multiplier * jax.nn.softplus(
99+
normalized_deviation * sharpness
100100
)
101-
transport_multiplier = transport_multiplier**exponent
102-
transport_multiplier = jnp.clip(
103-
transport_multiplier,
104-
min=pedestal_runtime_params.min_transport_multiplier,
105-
max=pedestal_runtime_params.max_transport_multiplier,
106-
)
107-
108101
return transport_multiplier

torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import dataclasses
1514
from absl.testing import absltest
1615
from absl.testing import parameterized
1716
import numpy as np
18-
from torax._src import constants
1917
from torax._src.config import build_runtime_params
2018
from torax._src.core_profiles import initialization
2119
from torax._src.pedestal_model import pedestal_model_output
@@ -49,43 +47,35 @@ def setUp(self):
4947
@parameterized.named_parameters(
5048
dict(
5149
testcase_name='active',
52-
# P_current >> P_target -> saturation is active.
53-
P_current_over_P_target=1e3,
50+
# T_current >> T_target -> saturation is active.
51+
T_target_over_T_current=1e-3,
5452
),
5553
dict(
5654
testcase_name='inactive',
57-
# P_current << P_target -> no saturation.
58-
P_current_over_P_target=1e-3,
55+
# T_current << T_target -> no saturation.
56+
T_target_over_T_current=1e3,
5957
),
6058
)
6159
def test_saturation_multiplier(
6260
self,
63-
P_current_over_P_target,
61+
T_target_over_T_current,
6462
):
6563
saturation_model = (
6664
profile_value_saturation_model.ProfileValueSaturationModel()
6765
)
6866

6967
# For this test, we put the pedestal top at the last grid point.
7068
ped_top_idx = -1
71-
72-
# Get the current pressure at the pedestal top.
73-
current_pressure = self.core_profiles.pressure_thermal_total.value[
74-
ped_top_idx
75-
]
69+
current_T_e_ped = self.core_profiles.T_e.face_value()[ped_top_idx]
7670

7771
# Construct a pedestal output that is asking for a pedestal with
78-
# target_pressure.
79-
target_pressure = current_pressure / P_current_over_P_target
72+
# target temperature.
8073
pedestal_output = pedestal_model_output.PedestalModelOutput(
8174
rho_norm_ped_top=self.geo.rho_face[ped_top_idx],
8275
rho_norm_ped_top_idx=ped_top_idx,
83-
# Set T_i_ped, T_e_ped, and n_e_ped such that target_pressure is
84-
# achieved. This case has n_impurity = 0 and Z_i = 1, so
85-
# P = (T_i + T_e)*keV_to_J*n_e.
86-
T_i_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
87-
T_e_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
88-
n_e_ped=target_pressure,
76+
T_i_ped=1.0,
77+
T_e_ped=current_T_e_ped * T_target_over_T_current,
78+
n_e_ped=1.0,
8979
)
9080

9181
transport_multipliers = saturation_model(
@@ -95,16 +85,14 @@ def test_saturation_multiplier(
9585
pedestal_output,
9686
)
9787

98-
if P_current_over_P_target > 1.0:
99-
# If the current pressure is above the target pressure, we expect the
100-
# multiplier to be greater than 1.0.
101-
for multiplier in dataclasses.asdict(transport_multipliers).values():
102-
self.assertGreater(multiplier, 1.0)
88+
if T_target_over_T_current > 1.0:
89+
# If the target temperature is above the current temperature, we expect
90+
# the multiplier to be equal to 1.0 - the pedestal is not saturated.
91+
np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0)
10392
else:
104-
# If the current pressure is below the target pressure, we expect the
105-
# multiplier to be 1.0.
106-
for multiplier in dataclasses.asdict(transport_multipliers).values():
107-
np.testing.assert_allclose(multiplier, 1.0, rtol=1e-3)
93+
# If the target temperature is below the current temperature, we expect
94+
# the multiplier to be greater than 1.0 - the pedestal is saturated.
95+
self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0)
10896

10997

11098
if __name__ == '__main__':

torax/_src/transport_model/combined.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,8 @@ def __call__(
8484
# should be handled by instantiating constant component models instead.
8585
# However, the rho_inner and rho_outer arguments are currently required
8686
# in the case where the inner/outer region are to be excluded from
87-
# smoothing. Smoothing is applied to
88-
# rho_inner < rho_norm < min(rho_ped_top, rho_outer) unless
89-
# smooth_everywhere is True.
90-
return self._smooth_coeffs(
91-
runtime_params,
92-
geo,
93-
transport_coeffs,
94-
pedestal_model_output,
95-
)
87+
# smoothing.
88+
return transport_coeffs
9689

9790
def call_implementation(
9891
self,

torax/_src/transport_model/tests/transport_model_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,13 @@ def test_smoothing_everywhere(self):
309309
core_profiles,
310310
pedestal_model_outputs,
311311
)
312+
313+
# Apply the smoothing
314+
transport_coeffs = transport_model.smooth_coeffs(
315+
runtime_params, geo, transport_coeffs, pedestal_model_outputs
316+
)
317+
318+
# Set up original transport coefficients for comparison
312319
inner_patch_idx = np.searchsorted(
313320
geo.rho_face_norm, runtime_params.transport.rho_inner
314321
)

torax/_src/transport_model/transport_coefficients_builder.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def calculate_all_transport_coeffs(
8585

8686
# Modify the turbulent + Pereverzev transport coefficients if the pedestal
8787
# model is in ADAPTIVE_TRANSPORT mode.
88-
# TODO(b/488980968): Identify speed issue with ADAPTIVE_TRANSPORT mode.
8988
if (
9089
runtime_params.pedestal.mode
9190
== pedestal_runtime_params_lib.Mode.ADAPTIVE_TRANSPORT
@@ -111,4 +110,35 @@ def calculate_all_transport_coeffs(
111110
**dataclasses.asdict(pereverzev_transport_coeffs),
112111
)
113112

113+
# TODO(b/485147781) Clean up this post-processing.
114+
# Apply smoothing to the final turbulent transport coefficients
115+
turbulent_transport_coeffs = transport_model_lib.TurbulentTransport(
116+
chi_face_ion=core_transport.chi_face_ion,
117+
chi_face_el=core_transport.chi_face_el,
118+
d_face_el=core_transport.d_face_el,
119+
v_face_el=core_transport.v_face_el,
120+
chi_face_ion_bohm=core_transport.chi_face_ion_bohm,
121+
chi_face_ion_gyrobohm=core_transport.chi_face_ion_gyrobohm,
122+
chi_face_ion_itg=core_transport.chi_face_ion_itg,
123+
chi_face_ion_tem=core_transport.chi_face_ion_tem,
124+
chi_face_el_bohm=core_transport.chi_face_el_bohm,
125+
chi_face_el_gyrobohm=core_transport.chi_face_el_gyrobohm,
126+
chi_face_el_itg=core_transport.chi_face_el_itg,
127+
chi_face_el_tem=core_transport.chi_face_el_tem,
128+
chi_face_el_etg=core_transport.chi_face_el_etg,
129+
d_face_el_itg=core_transport.d_face_el_itg,
130+
d_face_el_tem=core_transport.d_face_el_tem,
131+
v_face_el_itg=core_transport.v_face_el_itg,
132+
v_face_el_tem=core_transport.v_face_el_tem,
133+
)
134+
turbulent_transport_coeffs = transport_model.smooth_coeffs(
135+
runtime_params=runtime_params,
136+
geo=geo,
137+
pedestal_model_output=pedestal_model_output,
138+
transport_coeffs=turbulent_transport_coeffs,
139+
)
140+
core_transport = dataclasses.replace(
141+
core_transport,
142+
**dataclasses.asdict(turbulent_transport_coeffs),
143+
)
114144
return core_transport

0 commit comments

Comments
 (0)