@@ -38,36 +38,39 @@ class MartinFormation(torax_pydantic.BaseModelFrozen):
3838 """Configuration for Martin formation model.
3939
4040 This formation model triggers a reduction in pedestal transport when P_SOL >
41- P_LH, where P_LH is from the Martin scaling. The reduction is a smooth sigmoid
42- function of the ratio P_SOL / (P_LH * P_LH_prefactor) .
41+ P_LH, where P_LH is from the Martin scaling. The reduction is a multiplicative
42+ factor between 1.0 and base_multiplier .
4343
4444 The formula is
45- rescaled_P_LH = P_LH * P_LH_prefactor
46- normalized_deviation = (P_SOL - rescaled_P_LH) / rescaled_P_LH - offset
47- transport_multiplier = 1 - sigmoid(normalized_deviation / width)
48- transport_multiplier = transport_multiplier**exponent
45+ transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier,
46+ where alpha is a smooth sigmoid function of
47+ (P_SOL - P_LH * P_LH_prefactor) / (P_LH * P_LH_prefactor)
48+ with given sharpness and offset, namely:
49+ sigmoid(x) = 1 / (1 + exp(-sharpness * [x - offset])).
4950
50- The transport multiplier is later clipped to the range
51- [min_transport_multiplier, max_transport_multiplier].
5251
5352 Attributes:
54- sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
55- formation window. Increase for a smoother L-H transition, but doing so may
56- lead to starting the L-H transition at a power below P_LH.
57- sigmoid_offset: Dimensionless offset of sigmoid function from P_LH / P_SOL =
58- 1. Increase to start the L-H transition at a higher P_SOL / P_LH ratio.
59- sigmoid_exponent: The exponent of the transport multiplier. Increase for a
60- faster reduction in transport once the L-H transition starts.
53+ sharpness: Scaling factor applied to the argument of the sigmoid function,
54+ setting the sharpness of the smooth formation window. Decrease for a
55+ smoother formation, which may be more numerically stable but may lead to
56+ starting formation at a temperature or density below the target values.
57+ offset: Bias applied to the argument of the sigmoid function, setting the
58+ dimensionless offset of the formation window. Increase to start formation
59+ at a higher P_SOL.
60+ base_multiplier: The base value of the transport multiplier. Increase for
61+ stronger decreases in transport once formation starts.
6162 P_LH_prefactor: Dimensionless multiplier for P_LH. Increase to scale up
6263 P_LH, and therefore start the L-H transition at a higher P_SOL.
6364 """
6465
6566 model_name : Annotated [Literal ["martin" ], torax_pydantic .JAX_STATIC ] = "martin"
66- sigmoid_width : pydantic .PositiveFloat = 1e-3
67- sigmoid_offset : Annotated [
67+ sharpness : pydantic .PositiveFloat = 100.0
68+ offset : Annotated [
6869 array_typing .FloatScalar , pydantic .Field (ge = - 10.0 , le = 10.0 )
6970 ] = 0.0
70- sigmoid_exponent : pydantic .PositiveFloat = 3.0
71+ base_multiplier : Annotated [
72+ array_typing .FloatScalar , pydantic .Field (gt = 0.0 , le = 1.0 )
73+ ] = 1e-6
7174 P_LH_prefactor : pydantic .PositiveFloat = 1.0
7275
7376 def build_formation_model (
@@ -80,9 +83,9 @@ def build_runtime_params(
8083 ) -> martin_formation_model .MartinFormationRuntimeParams :
8184 del t
8285 return martin_formation_model .MartinFormationRuntimeParams (
83- sigmoid_width = self .sigmoid_width ,
84- sigmoid_offset = self .sigmoid_offset ,
85- sigmoid_exponent = self .sigmoid_exponent ,
86+ sharpness = self .sharpness ,
87+ offset = self .offset ,
88+ base_multiplier = self .base_multiplier ,
8689 P_LH_prefactor = self .P_LH_prefactor ,
8790 )
8891
@@ -92,36 +95,41 @@ class ProfileValueSaturation(torax_pydantic.BaseModelFrozen):
9295
9396 This saturation model triggers an increase in pedestal transport when the
9497 pedestal temperature and density are above the values requested by the
95- pedestal model. The increase is a smooth sigmoid function of the ratio of the
98+ pedestal model. The increase is a smooth linear function of the ratio of the
9699 current value to the value requested by the pedestal model.
97100
98101 The formula is
99- normalized_deviation = (current - target) / target - offset
100- transport_multiplier = 1 / (1 - sigmoid(normalized_deviation / width))
101- transport_multiplier = transport_multiplier**exponent
102+ transport_multiplier = 1 + alpha * base_multiplier,
103+ where alpha is a softplus function of the normalized deviation from the target
104+ value, with given steepness and offset:
105+ x = (current - target) / target - offset
106+ alpha = log(1 + exp(steepness * x))
102107
103- The transport multiplier is then clipped to the range
104- [min_transport_multiplier, max_transport_multiplier].
105108
106109 Attributes:
107- sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
108- saturation window. Increase for a smoother saturation, but doing so may
109- lead to starting saturation at a temperature or density below the target
110- values.
111- sigmoid_offset: Dimensionless offset of the saturation window. Increase to
112- start saturation at a higher temperature or density.
113- sigmoid_exponent: The exponent of the transport multiplier. Increase for a
114- faster increase in transport once saturation starts.
110+ steepness: Scaling factor applied to the argument of the softplus function,
111+ setting the steepness of the smooth saturation function. Decrease for a
112+ smoother saturation, which may be more numerically stable but may lead to
113+ starting saturation at a temperature or density below the target values.
114+ offset: Bias applied to the argument of the softplus function, setting the
115+ dimensionless offset of the saturation window. Increase to start
116+ saturation at a higher temperature or density.
117+ base_multiplier: The base value of the transport multiplier. Increase for
118+ stronger increases in transport once saturation starts.
115119 """
116120
117121 model_name : Annotated [Literal ["profile_value" ], torax_pydantic .JAX_STATIC ] = (
118122 "profile_value"
119123 )
120- sigmoid_width : pydantic .PositiveFloat = 0.1
121- sigmoid_offset : Annotated [
124+ steepness : pydantic .PositiveFloat = 100.0
125+ # Default offset is > 0 as otherwise saturation starts too early. This is
126+ # because the softplus function is nonzero before the argument is zero.
127+ offset : Annotated [
122128 array_typing .FloatScalar , pydantic .Field (ge = - 10.0 , le = 10.0 )
123- ] = 0.0
124- sigmoid_exponent : pydantic .PositiveFloat = 1.0
129+ ] = 0.1
130+ base_multiplier : Annotated [
131+ array_typing .FloatScalar , pydantic .Field (gt = 1.0 )
132+ ] = 1e6
125133
126134 def build_saturation_model (
127135 self ,
@@ -133,9 +141,9 @@ def build_runtime_params(
133141 ) -> runtime_params .SaturationRuntimeParams :
134142 del t
135143 return runtime_params .SaturationRuntimeParams (
136- sigmoid_width = self .sigmoid_width ,
137- sigmoid_offset = self .sigmoid_offset ,
138- sigmoid_exponent = self .sigmoid_exponent ,
144+ steepness = self .steepness ,
145+ offset = self .offset ,
146+ base_multiplier = self .base_multiplier ,
139147 )
140148
141149
@@ -169,12 +177,6 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
169177 saturation_model : SaturationConfig = torax_pydantic .ValidatedDefault (
170178 ProfileValueSaturation ()
171179 )
172- max_transport_multiplier : torax_pydantic .TimeVaryingScalar = (
173- torax_pydantic .ValidatedDefault (10.0 )
174- )
175- min_transport_multiplier : torax_pydantic .TimeVaryingScalar = (
176- torax_pydantic .ValidatedDefault (0.1 )
177- )
178180
179181 @pydantic .model_validator (mode = "before" )
180182 @classmethod
@@ -205,8 +207,6 @@ def build_runtime_params(
205207 mode = self .mode ,
206208 formation = self .formation_model .build_runtime_params (t ),
207209 saturation = self .saturation_model .build_runtime_params (t ),
208- max_transport_multiplier = self .max_transport_multiplier .get_value (t ),
209- min_transport_multiplier = self .min_transport_multiplier .get_value (t ),
210210 )
211211
212212
@@ -262,8 +262,6 @@ def build_runtime_params(
262262 rho_norm_ped_top = self .rho_norm_ped_top .get_value (t ),
263263 formation = base_runtime_params .formation ,
264264 saturation = base_runtime_params .saturation ,
265- max_transport_multiplier = base_runtime_params .max_transport_multiplier ,
266- min_transport_multiplier = base_runtime_params .min_transport_multiplier ,
267265 )
268266
269267
@@ -318,8 +316,6 @@ def build_runtime_params(
318316 rho_norm_ped_top = self .rho_norm_ped_top .get_value (t ),
319317 formation = base_runtime_params .formation ,
320318 saturation = base_runtime_params .saturation ,
321- max_transport_multiplier = base_runtime_params .max_transport_multiplier ,
322- min_transport_multiplier = base_runtime_params .min_transport_multiplier ,
323319 )
324320
325321
@@ -351,8 +347,6 @@ def build_runtime_params(
351347 mode = base_runtime_params .mode ,
352348 formation = base_runtime_params .formation ,
353349 saturation = base_runtime_params .saturation ,
354- max_transport_multiplier = base_runtime_params .max_transport_multiplier ,
355- min_transport_multiplier = base_runtime_params .min_transport_multiplier ,
356350 )
357351
358352
0 commit comments