Skip to content

Commit 4905765

Browse files
theo-brownTorax team
authored andcommitted
Improve numerics of saturation and formation models.
- Switch the formation model to use linear interpolation between base_multiplier and 1.0 based on sigmoid(P_sol - P_LH) - Switch the saturation model to use softplus, giving linear increase in transport: (value_at_ped_top - desired_value_at_ped_top) * base_multiplier - Ensure clipping is only applied to the final transport multipliers, rather than components - Update names of model parameters for improved interpretability These changes made the solver converge noticeably faster. PiperOrigin-RevId: 878555972
1 parent aa3a55c commit 4905765

File tree

9 files changed

+90
-141
lines changed

9 files changed

+90
-141
lines changed

torax/_src/pedestal_model/formation/martin_formation_model.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import dataclasses
1818
import jax
19-
import jax.numpy as jnp
2019
from torax._src import array_typing
2120
from torax._src import math_utils
2221
from torax._src import state
@@ -86,19 +85,13 @@ def __call__(
8685
# If P_SOL > P_LH, multiplier tends to 0.0
8786
# If P_SOL < P_LH, multiplier tends to 1.0
8887
# TODO(b/488393318): Add hysteresis to the LH-HL transition.
89-
width = runtime_params.pedestal.formation.sigmoid_width
90-
exponent = runtime_params.pedestal.formation.sigmoid_exponent
91-
offset = runtime_params.pedestal.formation.sigmoid_offset
92-
normalized_deviation = (
93-
P_SOL_total - rescaled_P_LH
94-
) / rescaled_P_LH - offset
95-
transport_multiplier = 1 - jax.nn.sigmoid(normalized_deviation / width)
96-
transport_multiplier = transport_multiplier**exponent
97-
transport_multiplier = jnp.clip(
98-
transport_multiplier,
99-
min=runtime_params.pedestal.min_transport_multiplier,
100-
max=runtime_params.pedestal.max_transport_multiplier,
101-
)
88+
sharpness = runtime_params.pedestal.formation.sharpness
89+
offset = runtime_params.pedestal.formation.offset
90+
base_multiplier = runtime_params.pedestal.formation.base_multiplier
91+
normalized_deviation = (P_SOL_total - rescaled_P_LH) / rescaled_P_LH
92+
shifted_deviation = normalized_deviation - offset
93+
alpha = jax.nn.sigmoid(shifted_deviation * sharpness)
94+
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier
10295

10396
return pedestal_model_output.TransportMultipliers(
10497
chi_e_multiplier=transport_multiplier,

torax/_src/pedestal_model/formation/tests/martin_formation_model_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,11 @@ def test_calculate_P_SOL_total(self):
6666

6767
@parameterized.named_parameters(
6868
dict(
69-
# If P_sol >> P_LH, we expect the suppression multiplier to be very
70-
# small (significant suppression). However, it's clipped internally to
71-
# be 0.1.
69+
# If P_sol >> P_LH, we expect the suppression multiplier to be
70+
# base_multiplier.
7271
testcase_name='above_threshold',
7372
power=1e6,
74-
expected_multiplier=0.1,
73+
expected_multiplier=1e-6,
7574
),
7675
dict(
7776
# If P_sol << P_LH, we expect the suppression multiplier to be 1.0

torax/_src/pedestal_model/pydantic_model.py

Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,39 @@ class MartinFormation(torax_pydantic.BaseModelFrozen):
3838
"""Configuration for Martin formation model.
3939
4040
This formation model triggers a reduction in pedestal transport when P_SOL >
41-
P_LH, where P_LH is from the Martin scaling. The reduction is a smooth sigmoid
42-
function of the ratio P_SOL / (P_LH * P_LH_prefactor).
41+
P_LH, where P_LH is from the Martin scaling. The reduction is a multiplicative
42+
factor between 1.0 and base_multiplier.
4343
4444
The formula is
45-
rescaled_P_LH = P_LH * P_LH_prefactor
46-
normalized_deviation = (P_SOL - rescaled_P_LH) / rescaled_P_LH - offset
47-
transport_multiplier = 1 - sigmoid(normalized_deviation / width)
48-
transport_multiplier = transport_multiplier**exponent
45+
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier,
46+
where alpha is a smooth sigmoid function of
47+
(P_SOL - P_LH * P_LH_prefactor) / (P_LH * P_LH_prefactor)
48+
with given sharpness and offset, namely:
49+
sigmoid(x) = 1 / (1 + exp(-sharpness * [x - offset])).
4950
50-
The transport multiplier is later clipped to the range
51-
[min_transport_multiplier, max_transport_multiplier].
5251
5352
Attributes:
54-
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
55-
formation window. Increase for a smoother L-H transition, but doing so may
56-
lead to starting the L-H transition at a power below P_LH.
57-
sigmoid_offset: Dimensionless offset of sigmoid function from P_LH / P_SOL =
58-
1. Increase to start the L-H transition at a higher P_SOL / P_LH ratio.
59-
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
60-
faster reduction in transport once the L-H transition starts.
53+
sharpness: Scaling factor applied to the argument of the sigmoid function,
54+
setting the sharpness of the smooth formation window. Decrease for a
55+
smoother formation, which may be more numerically stable but may lead to
56+
starting formation at a temperature or density below the target values.
57+
offset: Bias applied to the argument of the sigmoid function, setting the
58+
dimensionless offset of the formation window. Increase to start formation
59+
at a higher P_SOL.
60+
base_multiplier: The base value of the transport multiplier. Increase for
61+
stronger decreases in transport once formation starts.
6162
P_LH_prefactor: Dimensionless multiplier for P_LH. Increase to scale up
6263
P_LH, and therefore start the L-H transition at a higher P_SOL.
6364
"""
6465

6566
model_name: Annotated[Literal["martin"], torax_pydantic.JAX_STATIC] = "martin"
66-
sigmoid_width: pydantic.PositiveFloat = 1e-3
67-
sigmoid_offset: Annotated[
67+
sharpness: pydantic.PositiveFloat = 100.0
68+
offset: Annotated[
6869
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
6970
] = 0.0
70-
sigmoid_exponent: pydantic.PositiveFloat = 3.0
71+
base_multiplier: Annotated[
72+
array_typing.FloatScalar, pydantic.Field(gt=0.0, le=1.0)
73+
] = 1e-6
7174
P_LH_prefactor: pydantic.PositiveFloat = 1.0
7275

7376
def build_formation_model(
@@ -80,9 +83,9 @@ def build_runtime_params(
8083
) -> martin_formation_model.MartinFormationRuntimeParams:
8184
del t
8285
return martin_formation_model.MartinFormationRuntimeParams(
83-
sigmoid_width=self.sigmoid_width,
84-
sigmoid_offset=self.sigmoid_offset,
85-
sigmoid_exponent=self.sigmoid_exponent,
86+
sharpness=self.sharpness,
87+
offset=self.offset,
88+
base_multiplier=self.base_multiplier,
8689
P_LH_prefactor=self.P_LH_prefactor,
8790
)
8891

@@ -92,36 +95,41 @@ class ProfileValueSaturation(torax_pydantic.BaseModelFrozen):
9295
9396
This saturation model triggers an increase in pedestal transport when the
9497
pedestal temperature and density are above the values requested by the
95-
pedestal model. The increase is a smooth sigmoid function of the ratio of the
98+
pedestal model. The increase is a smooth linear function of the ratio of the
9699
current value to the value requested by the pedestal model.
97100
98101
The formula is
99-
normalized_deviation = (current - target) / target - offset
100-
transport_multiplier = 1 / (1 - sigmoid(normalized_deviation / width))
101-
transport_multiplier = transport_multiplier**exponent
102+
transport_multiplier = 1 + alpha * base_multiplier,
103+
where alpha is a softplus function of the normalized deviation from the target
104+
value, with given steepness and offset:
105+
x = (current - target) / target - offset
106+
alpha = log(1 + exp(steepness * x))
102107
103-
The transport multiplier is then clipped to the range
104-
[min_transport_multiplier, max_transport_multiplier].
105108
106109
Attributes:
107-
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
108-
saturation window. Increase for a smoother saturation, but doing so may
109-
lead to starting saturation at a temperature or density below the target
110-
values.
111-
sigmoid_offset: Dimensionless offset of the saturation window. Increase to
112-
start saturation at a higher temperature or density.
113-
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
114-
faster increase in transport once saturation starts.
110+
steepness: Scaling factor applied to the argument of the softplus function,
111+
setting the steepness of the smooth saturation function. Decrease for a
112+
smoother saturation, which may be more numerically stable but may lead to
113+
starting saturation at a temperature or density below the target values.
114+
offset: Bias applied to the argument of the softplus function, setting the
115+
dimensionless offset of the saturation window. Increase to start
116+
saturation at a higher temperature or density.
117+
base_multiplier: The base value of the transport multiplier. Increase for
118+
stronger increases in transport once saturation starts.
115119
"""
116120

117121
model_name: Annotated[Literal["profile_value"], torax_pydantic.JAX_STATIC] = (
118122
"profile_value"
119123
)
120-
sigmoid_width: pydantic.PositiveFloat = 0.1
121-
sigmoid_offset: Annotated[
124+
steepness: pydantic.PositiveFloat = 100.0
125+
# Default offset is > 0 as otherwise saturation starts too early. This is
126+
# because the softplus function is nonzero before the argument is zero.
127+
offset: Annotated[
122128
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
123-
] = 0.0
124-
sigmoid_exponent: pydantic.PositiveFloat = 1.0
129+
] = 0.1
130+
base_multiplier: Annotated[
131+
array_typing.FloatScalar, pydantic.Field(gt=1.0)
132+
] = 1e6
125133

126134
def build_saturation_model(
127135
self,
@@ -133,9 +141,9 @@ def build_runtime_params(
133141
) -> runtime_params.SaturationRuntimeParams:
134142
del t
135143
return runtime_params.SaturationRuntimeParams(
136-
sigmoid_width=self.sigmoid_width,
137-
sigmoid_offset=self.sigmoid_offset,
138-
sigmoid_exponent=self.sigmoid_exponent,
144+
steepness=self.steepness,
145+
offset=self.offset,
146+
base_multiplier=self.base_multiplier,
139147
)
140148

141149

@@ -169,12 +177,6 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
169177
saturation_model: SaturationConfig = torax_pydantic.ValidatedDefault(
170178
ProfileValueSaturation()
171179
)
172-
max_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
173-
torax_pydantic.ValidatedDefault(10.0)
174-
)
175-
min_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
176-
torax_pydantic.ValidatedDefault(0.1)
177-
)
178180

179181
@pydantic.model_validator(mode="before")
180182
@classmethod
@@ -205,8 +207,6 @@ def build_runtime_params(
205207
mode=self.mode,
206208
formation=self.formation_model.build_runtime_params(t),
207209
saturation=self.saturation_model.build_runtime_params(t),
208-
max_transport_multiplier=self.max_transport_multiplier.get_value(t),
209-
min_transport_multiplier=self.min_transport_multiplier.get_value(t),
210210
)
211211

212212

@@ -262,8 +262,6 @@ def build_runtime_params(
262262
rho_norm_ped_top=self.rho_norm_ped_top.get_value(t),
263263
formation=base_runtime_params.formation,
264264
saturation=base_runtime_params.saturation,
265-
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
266-
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
267265
)
268266

269267

@@ -318,8 +316,6 @@ def build_runtime_params(
318316
rho_norm_ped_top=self.rho_norm_ped_top.get_value(t),
319317
formation=base_runtime_params.formation,
320318
saturation=base_runtime_params.saturation,
321-
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
322-
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
323319
)
324320

325321

@@ -351,8 +347,6 @@ def build_runtime_params(
351347
mode=base_runtime_params.mode,
352348
formation=base_runtime_params.formation,
353349
saturation=base_runtime_params.saturation,
354-
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
355-
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
356350
)
357351

358352

torax/_src/pedestal_model/runtime_params.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,19 @@ class Mode(enum.Enum):
3737
class FormationRuntimeParams:
3838
"""Runtime params for pedestal formation models."""
3939

40-
sigmoid_width: array_typing.FloatScalar = 0.1
41-
sigmoid_offset: array_typing.FloatScalar = 0.0
42-
sigmoid_exponent: array_typing.FloatScalar = 1.0
40+
sharpness: array_typing.FloatScalar
41+
offset: array_typing.FloatScalar
42+
base_multiplier: array_typing.FloatScalar
4343

4444

4545
@jax.tree_util.register_dataclass
4646
@dataclasses.dataclass(frozen=True)
4747
class SaturationRuntimeParams:
4848
"""Runtime params for pedestal saturation models."""
4949

50-
sigmoid_width: array_typing.FloatScalar = 0.1
51-
sigmoid_offset: array_typing.FloatScalar = 0.0
52-
sigmoid_exponent: array_typing.FloatScalar = 1.0
50+
steepness: array_typing.FloatScalar
51+
offset: array_typing.FloatScalar
52+
base_multiplier: array_typing.FloatScalar
5353

5454

5555
@jax.tree_util.register_dataclass
@@ -59,7 +59,5 @@ class RuntimeParams:
5959

6060
set_pedestal: array_typing.BoolScalar
6161
mode: Mode = dataclasses.field(metadata={"static": True})
62-
min_transport_multiplier: array_typing.FloatScalar
63-
max_transport_multiplier: array_typing.FloatScalar
6462
formation: FormationRuntimeParams
6563
saturation: SaturationRuntimeParams

torax/_src/pedestal_model/saturation/profile_value_saturation_model.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,11 @@ def _calculate_multiplier(
9191
Returns:
9292
The transport increase multiplier.
9393
"""
94-
width = pedestal_runtime_params.saturation.sigmoid_width
95-
exponent = pedestal_runtime_params.saturation.sigmoid_exponent
96-
offset = pedestal_runtime_params.saturation.sigmoid_offset
94+
steepness = pedestal_runtime_params.saturation.steepness
95+
offset = pedestal_runtime_params.saturation.offset
96+
base_multiplier = pedestal_runtime_params.saturation.base_multiplier
9797
normalized_deviation = (current - target) / target - offset
98-
transport_multiplier = 1 / (
99-
1 - jax.nn.sigmoid(normalized_deviation / width)
98+
transport_multiplier = 1 + base_multiplier * jax.nn.softplus(
99+
normalized_deviation * steepness
100100
)
101-
transport_multiplier = transport_multiplier**exponent
102-
transport_multiplier = jnp.clip(
103-
transport_multiplier,
104-
min=pedestal_runtime_params.min_transport_multiplier,
105-
max=pedestal_runtime_params.max_transport_multiplier,
106-
)
107-
108101
return transport_multiplier

torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import dataclasses
1514
from absl.testing import absltest
1615
from absl.testing import parameterized
1716
import numpy as np
18-
from torax._src import constants
1917
from torax._src.config import build_runtime_params
2018
from torax._src.core_profiles import initialization
2119
from torax._src.pedestal_model import pedestal_model_output
@@ -49,43 +47,35 @@ def setUp(self):
4947
@parameterized.named_parameters(
5048
dict(
5149
testcase_name='active',
52-
# P_current >> P_target -> saturation is active.
53-
P_current_over_P_target=1e3,
50+
# T_current >> T_target -> saturation is active.
51+
T_target_over_T_current=1e-3,
5452
),
5553
dict(
5654
testcase_name='inactive',
57-
# P_current << P_target -> no saturation.
58-
P_current_over_P_target=1e-3,
55+
# T_current << T_target -> no saturation.
56+
T_target_over_T_current=1e3,
5957
),
6058
)
6159
def test_saturation_multiplier(
6260
self,
63-
P_current_over_P_target,
61+
T_target_over_T_current,
6462
):
6563
saturation_model = (
6664
profile_value_saturation_model.ProfileValueSaturationModel()
6765
)
6866

6967
# For this test, we put the pedestal top at the last grid point.
7068
ped_top_idx = -1
71-
72-
# Get the current pressure at the pedestal top.
73-
current_pressure = self.core_profiles.pressure_thermal_total.value[
74-
ped_top_idx
75-
]
69+
current_T_e_ped = self.core_profiles.T_e.face_value()[ped_top_idx]
7670

7771
# Construct a pedestal output that is asking for a pedestal with
78-
# target_pressure.
79-
target_pressure = current_pressure / P_current_over_P_target
72+
# target temperature.
8073
pedestal_output = pedestal_model_output.PedestalModelOutput(
8174
rho_norm_ped_top=self.geo.rho_face[ped_top_idx],
8275
rho_norm_ped_top_idx=ped_top_idx,
83-
# Set T_i_ped, T_e_ped, and n_e_ped such that target_pressure is
84-
# achieved. This case has n_impurity = 0 and Z_i = 1, so
85-
# P = (T_i + T_e)*keV_to_J*n_e.
86-
T_i_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
87-
T_e_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
88-
n_e_ped=target_pressure,
76+
T_i_ped=1.0,
77+
T_e_ped=current_T_e_ped * T_target_over_T_current,
78+
n_e_ped=1.0,
8979
)
9080

9181
transport_multipliers = saturation_model(
@@ -95,16 +85,14 @@ def test_saturation_multiplier(
9585
pedestal_output,
9686
)
9787

98-
if P_current_over_P_target > 1.0:
99-
# If the current pressure is above the target pressure, we expect the
100-
# multiplier to be greater than 1.0.
101-
for multiplier in dataclasses.asdict(transport_multipliers).values():
102-
self.assertGreater(multiplier, 1.0)
88+
if T_target_over_T_current > 1.0:
89+
# If the target temperature is above the current temperature, we expect
90+
# the multiplier to be equal to 1.0 - the pedestal is not saturated.
91+
np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0)
10392
else:
104-
# If the current pressure is below the target pressure, we expect the
105-
# multiplier to be 1.0.
106-
for multiplier in dataclasses.asdict(transport_multipliers).values():
107-
np.testing.assert_allclose(multiplier, 1.0, rtol=1e-3)
93+
# If the target temperature is below the current temperature, we expect
94+
# the multiplier to be greater than 1.0 - the pedestal is saturated.
95+
self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0)
10896

10997

11098
if __name__ == '__main__':

0 commit comments

Comments
 (0)