Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions torax/_src/pedestal_model/formation/martin_formation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import dataclasses
import jax
import jax.numpy as jnp
from torax._src import array_typing
from torax._src import math_utils
from torax._src import state
Expand Down Expand Up @@ -86,19 +85,13 @@ def __call__(
# If P_SOL > P_LH, multiplier tends to 0.0
# If P_SOL < P_LH, multiplier tends to 1.0
# TODO(b/488393318): Add hysteresis to the LH-HL transition.
width = runtime_params.pedestal.formation.sigmoid_width
exponent = runtime_params.pedestal.formation.sigmoid_exponent
offset = runtime_params.pedestal.formation.sigmoid_offset
normalized_deviation = (
P_SOL_total - rescaled_P_LH
) / rescaled_P_LH - offset
transport_multiplier = 1 - jax.nn.sigmoid(normalized_deviation / width)
transport_multiplier = transport_multiplier**exponent
transport_multiplier = jnp.clip(
transport_multiplier,
min=runtime_params.pedestal.min_transport_multiplier,
max=runtime_params.pedestal.max_transport_multiplier,
)
sharpness = runtime_params.pedestal.formation.sharpness
offset = runtime_params.pedestal.formation.offset
base_multiplier = runtime_params.pedestal.formation.base_multiplier
normalized_deviation = (P_SOL_total - rescaled_P_LH) / rescaled_P_LH
shifted_deviation = normalized_deviation - offset
alpha = jax.nn.sigmoid(shifted_deviation * sharpness)
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier

return pedestal_model_output.TransportMultipliers(
chi_e_multiplier=transport_multiplier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ def test_calculate_P_SOL_total(self):

@parameterized.named_parameters(
dict(
# If P_sol >> P_LH, we expect the suppression multiplier to be very
# small (significant suppression). However, it's clipped internally to
# be 0.1.
# If P_sol >> P_LH, we expect the suppression multiplier to be
# base_multiplier.
testcase_name='above_threshold',
power=1e6,
expected_multiplier=0.1,
expected_multiplier=1e-6,
),
dict(
# If P_sol << P_LH, we expect the suppression multiplier to be 1.0
Expand Down
106 changes: 50 additions & 56 deletions torax/_src/pedestal_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,39 @@ class MartinFormation(torax_pydantic.BaseModelFrozen):
"""Configuration for Martin formation model.

This formation model triggers a reduction in pedestal transport when P_SOL >
P_LH, where P_LH is from the Martin scaling. The reduction is a smooth sigmoid
function of the ratio P_SOL / (P_LH * P_LH_prefactor).
P_LH, where P_LH is from the Martin scaling. The reduction is a multiplicative
factor between 1.0 and base_multiplier.

The formula is
rescaled_P_LH = P_LH * P_LH_prefactor
normalized_deviation = (P_SOL - rescaled_P_LH) / rescaled_P_LH - offset
transport_multiplier = 1 - sigmoid(normalized_deviation / width)
transport_multiplier = transport_multiplier**exponent
transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier,
where alpha is a smooth sigmoid function of
(P_SOL - P_LH * P_LH_prefactor) / (P_LH * P_LH_prefactor)
with given sharpness and offset, namely:
sigmoid(x) = 1 / (1 + exp(-sharpness * [x - offset])).

The transport multiplier is later clipped to the range
[min_transport_multiplier, max_transport_multiplier].

Attributes:
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
formation window. Increase for a smoother L-H transition, but doing so may
lead to starting the L-H transition at a power below P_LH.
sigmoid_offset: Dimensionless offset of sigmoid function from P_LH / P_SOL =
1. Increase to start the L-H transition at a higher P_SOL / P_LH ratio.
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
faster reduction in transport once the L-H transition starts.
sharpness: Scaling factor applied to the argument of the sigmoid function,
setting the sharpness of the smooth formation window. Decrease for a
smoother formation, which may be more numerically stable but may lead to
starting formation at a temperature or density below the target values.
offset: Bias applied to the argument of the sigmoid function, setting the
dimensionless offset of the formation window. Increase to start formation
at a higher P_SOL.
base_multiplier: The base value of the transport multiplier. Increase for
stronger decreases in transport once formation starts.
P_LH_prefactor: Dimensionless multiplier for P_LH. Increase to scale up
P_LH, and therefore start the L-H transition at a higher P_SOL.
"""

model_name: Annotated[Literal["martin"], torax_pydantic.JAX_STATIC] = "martin"
sigmoid_width: pydantic.PositiveFloat = 1e-3
sigmoid_offset: Annotated[
sharpness: pydantic.PositiveFloat = 100.0
offset: Annotated[
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
] = 0.0
sigmoid_exponent: pydantic.PositiveFloat = 3.0
base_multiplier: Annotated[
array_typing.FloatScalar, pydantic.Field(gt=0.0, le=1.0)
] = 1e-6
P_LH_prefactor: pydantic.PositiveFloat = 1.0

def build_formation_model(
Expand All @@ -80,9 +83,9 @@ def build_runtime_params(
) -> martin_formation_model.MartinFormationRuntimeParams:
del t
return martin_formation_model.MartinFormationRuntimeParams(
sigmoid_width=self.sigmoid_width,
sigmoid_offset=self.sigmoid_offset,
sigmoid_exponent=self.sigmoid_exponent,
sharpness=self.sharpness,
offset=self.offset,
base_multiplier=self.base_multiplier,
P_LH_prefactor=self.P_LH_prefactor,
)

Expand All @@ -92,36 +95,41 @@ class ProfileValueSaturation(torax_pydantic.BaseModelFrozen):

This saturation model triggers an increase in pedestal transport when the
pedestal temperature and density are above the values requested by the
pedestal model. The increase is a smooth sigmoid function of the ratio of the
pedestal model. The increase is a smooth linear function of the ratio of the
current value to the value requested by the pedestal model.

The formula is
normalized_deviation = (current - target) / target - offset
transport_multiplier = 1 / (1 - sigmoid(normalized_deviation / width))
transport_multiplier = transport_multiplier**exponent
transport_multiplier = 1 + alpha * base_multiplier,
where alpha is a softplus function of the normalized deviation from the target
value, with given steepness and offset:
x = (current - target) / target - offset
alpha = log(1 + exp(steepness * x))

The transport multiplier is then clipped to the range
[min_transport_multiplier, max_transport_multiplier].

Attributes:
sigmoid_width: Dimensionless width of the sigmoid function for smoothing the
saturation window. Increase for a smoother saturation, but doing so may
lead to starting saturation at a temperature or density below the target
values.
sigmoid_offset: Dimensionless offset of the saturation window. Increase to
start saturation at a higher temperature or density.
sigmoid_exponent: The exponent of the transport multiplier. Increase for a
faster increase in transport once saturation starts.
steepness: Scaling factor applied to the argument of the softplus function,
setting the steepness of the smooth saturation function. Decrease for a
smoother saturation, which may be more numerically stable but may lead to
starting saturation at a temperature or density below the target values.
offset: Bias applied to the argument of the softplus function, setting the
dimensionless offset of the saturation window. Increase to start
saturation at a higher temperature or density.
base_multiplier: The base value of the transport multiplier. Increase for
stronger increases in transport once saturation starts.
"""

model_name: Annotated[Literal["profile_value"], torax_pydantic.JAX_STATIC] = (
"profile_value"
)
sigmoid_width: pydantic.PositiveFloat = 0.1
sigmoid_offset: Annotated[
steepness: pydantic.PositiveFloat = 100.0
# Default offset is > 0 as otherwise saturation starts too early. This is
# because the softplus function is nonzero before the argument is zero.
offset: Annotated[
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
] = 0.0
sigmoid_exponent: pydantic.PositiveFloat = 1.0
] = 0.1
base_multiplier: Annotated[
array_typing.FloatScalar, pydantic.Field(gt=1.0)
] = 1e6

def build_saturation_model(
self,
Expand All @@ -133,9 +141,9 @@ def build_runtime_params(
) -> runtime_params.SaturationRuntimeParams:
del t
return runtime_params.SaturationRuntimeParams(
sigmoid_width=self.sigmoid_width,
sigmoid_offset=self.sigmoid_offset,
sigmoid_exponent=self.sigmoid_exponent,
steepness=self.steepness,
offset=self.offset,
base_multiplier=self.base_multiplier,
)


Expand Down Expand Up @@ -169,12 +177,6 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
saturation_model: SaturationConfig = torax_pydantic.ValidatedDefault(
ProfileValueSaturation()
)
max_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(10.0)
)
min_transport_multiplier: torax_pydantic.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.1)
)

@pydantic.model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -205,8 +207,6 @@ def build_runtime_params(
mode=self.mode,
formation=self.formation_model.build_runtime_params(t),
saturation=self.saturation_model.build_runtime_params(t),
max_transport_multiplier=self.max_transport_multiplier.get_value(t),
min_transport_multiplier=self.min_transport_multiplier.get_value(t),
)


Expand Down Expand Up @@ -262,8 +262,6 @@ def build_runtime_params(
rho_norm_ped_top=self.rho_norm_ped_top.get_value(t),
formation=base_runtime_params.formation,
saturation=base_runtime_params.saturation,
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
)


Expand Down Expand Up @@ -318,8 +316,6 @@ def build_runtime_params(
rho_norm_ped_top=self.rho_norm_ped_top.get_value(t),
formation=base_runtime_params.formation,
saturation=base_runtime_params.saturation,
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
)


Expand Down Expand Up @@ -351,8 +347,6 @@ def build_runtime_params(
mode=base_runtime_params.mode,
formation=base_runtime_params.formation,
saturation=base_runtime_params.saturation,
max_transport_multiplier=base_runtime_params.max_transport_multiplier,
min_transport_multiplier=base_runtime_params.min_transport_multiplier,
)


Expand Down
14 changes: 6 additions & 8 deletions torax/_src/pedestal_model/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ class Mode(enum.Enum):
class FormationRuntimeParams:
"""Runtime params for pedestal formation models."""

sigmoid_width: array_typing.FloatScalar = 0.1
sigmoid_offset: array_typing.FloatScalar = 0.0
sigmoid_exponent: array_typing.FloatScalar = 1.0
sharpness: array_typing.FloatScalar
offset: array_typing.FloatScalar
base_multiplier: array_typing.FloatScalar


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class SaturationRuntimeParams:
"""Runtime params for pedestal saturation models."""

sigmoid_width: array_typing.FloatScalar = 0.1
sigmoid_offset: array_typing.FloatScalar = 0.0
sigmoid_exponent: array_typing.FloatScalar = 1.0
steepness: array_typing.FloatScalar
offset: array_typing.FloatScalar
base_multiplier: array_typing.FloatScalar


@jax.tree_util.register_dataclass
Expand All @@ -59,7 +59,5 @@ class RuntimeParams:

set_pedestal: array_typing.BoolScalar
mode: Mode = dataclasses.field(metadata={"static": True})
min_transport_multiplier: array_typing.FloatScalar
max_transport_multiplier: array_typing.FloatScalar
formation: FormationRuntimeParams
saturation: SaturationRuntimeParams
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,11 @@ def _calculate_multiplier(
Returns:
The transport increase multiplier.
"""
width = pedestal_runtime_params.saturation.sigmoid_width
exponent = pedestal_runtime_params.saturation.sigmoid_exponent
offset = pedestal_runtime_params.saturation.sigmoid_offset
steepness = pedestal_runtime_params.saturation.steepness
offset = pedestal_runtime_params.saturation.offset
base_multiplier = pedestal_runtime_params.saturation.base_multiplier
normalized_deviation = (current - target) / target - offset
transport_multiplier = 1 / (
1 - jax.nn.sigmoid(normalized_deviation / width)
transport_multiplier = 1 + base_multiplier * jax.nn.softplus(
normalized_deviation * steepness
)
transport_multiplier = transport_multiplier**exponent
transport_multiplier = jnp.clip(
transport_multiplier,
min=pedestal_runtime_params.min_transport_multiplier,
max=pedestal_runtime_params.max_transport_multiplier,
)

return transport_multiplier
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax._src import constants
from torax._src.config import build_runtime_params
from torax._src.core_profiles import initialization
from torax._src.pedestal_model import pedestal_model_output
Expand Down Expand Up @@ -49,43 +47,35 @@ def setUp(self):
@parameterized.named_parameters(
dict(
testcase_name='active',
# P_current >> P_target -> saturation is active.
P_current_over_P_target=1e3,
# T_current >> T_target -> saturation is active.
T_target_over_T_current=1e-3,
),
dict(
testcase_name='inactive',
# P_current << P_target -> no saturation.
P_current_over_P_target=1e-3,
# T_current << T_target -> no saturation.
T_target_over_T_current=1e3,
),
)
def test_saturation_multiplier(
self,
P_current_over_P_target,
T_target_over_T_current,
):
saturation_model = (
profile_value_saturation_model.ProfileValueSaturationModel()
)

# For this test, we put the pedestal top at the last grid point.
ped_top_idx = -1

# Get the current pressure at the pedestal top.
current_pressure = self.core_profiles.pressure_thermal_total.value[
ped_top_idx
]
current_T_e_ped = self.core_profiles.T_e.face_value()[ped_top_idx]

# Construct a pedestal output that is asking for a pedestal with
# target_pressure.
target_pressure = current_pressure / P_current_over_P_target
# target temperature.
pedestal_output = pedestal_model_output.PedestalModelOutput(
rho_norm_ped_top=self.geo.rho_face[ped_top_idx],
rho_norm_ped_top_idx=ped_top_idx,
# Set T_i_ped, T_e_ped, and n_e_ped such that target_pressure is
# achieved. This case has n_impurity = 0 and Z_i = 1, so
# P = (T_i + T_e)*keV_to_J*n_e.
T_i_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
T_e_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2,
n_e_ped=target_pressure,
T_i_ped=1.0,
T_e_ped=current_T_e_ped * T_target_over_T_current,
n_e_ped=1.0,
)

transport_multipliers = saturation_model(
Expand All @@ -95,16 +85,14 @@ def test_saturation_multiplier(
pedestal_output,
)

if P_current_over_P_target > 1.0:
# If the current pressure is above the target pressure, we expect the
# multiplier to be greater than 1.0.
for multiplier in dataclasses.asdict(transport_multipliers).values():
self.assertGreater(multiplier, 1.0)
if T_target_over_T_current > 1.0:
# If the target temperature is above the current temperature, we expect
# the multiplier to be equal to 1.0 - the pedestal is not saturated.
np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0)
else:
# If the current pressure is below the target pressure, we expect the
# multiplier to be 1.0.
for multiplier in dataclasses.asdict(transport_multipliers).values():
np.testing.assert_allclose(multiplier, 1.0, rtol=1e-3)
# If the target temperature is below the current temperature, we expect
# the multiplier to be greater than 1.0 - the pedestal is saturated.
self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0)


if __name__ == '__main__':
Expand Down
Loading
Loading