diff --git a/torax/_src/pedestal_model/formation/martin_formation_model.py b/torax/_src/pedestal_model/formation/martin_formation_model.py index c24b8d5fc..917f8bd8c 100644 --- a/torax/_src/pedestal_model/formation/martin_formation_model.py +++ b/torax/_src/pedestal_model/formation/martin_formation_model.py @@ -16,7 +16,6 @@ import dataclasses import jax -import jax.numpy as jnp from torax._src import array_typing from torax._src import math_utils from torax._src import state @@ -86,19 +85,13 @@ def __call__( # If P_SOL > P_LH, multiplier tends to 0.0 # If P_SOL < P_LH, multiplier tends to 1.0 # TODO(b/488393318): Add hysteresis to the LH-HL transition. - width = runtime_params.pedestal.formation.sigmoid_width - exponent = runtime_params.pedestal.formation.sigmoid_exponent - offset = runtime_params.pedestal.formation.sigmoid_offset - normalized_deviation = ( - P_SOL_total - rescaled_P_LH - ) / rescaled_P_LH - offset - transport_multiplier = 1 - jax.nn.sigmoid(normalized_deviation / width) - transport_multiplier = transport_multiplier**exponent - transport_multiplier = jnp.clip( - transport_multiplier, - min=runtime_params.pedestal.min_transport_multiplier, - max=runtime_params.pedestal.max_transport_multiplier, - ) + sharpness = runtime_params.pedestal.formation.sharpness + offset = runtime_params.pedestal.formation.offset + base_multiplier = runtime_params.pedestal.formation.base_multiplier + normalized_deviation = (P_SOL_total - rescaled_P_LH) / rescaled_P_LH + shifted_deviation = normalized_deviation - offset + alpha = jax.nn.sigmoid(shifted_deviation * sharpness) + transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier return pedestal_model_output.TransportMultipliers( chi_e_multiplier=transport_multiplier, diff --git a/torax/_src/pedestal_model/formation/tests/martin_formation_model_test.py b/torax/_src/pedestal_model/formation/tests/martin_formation_model_test.py index c774878ee..735c5a12e 100644 --- a/torax/_src/pedestal_model/formation/tests/martin_formation_model_test.py +++ b/torax/_src/pedestal_model/formation/tests/martin_formation_model_test.py @@ -66,12 +66,11 @@ def test_calculate_P_SOL_total(self): @parameterized.named_parameters( dict( - # If P_sol >> P_LH, we expect the suppression multiplier to be very - # small (significant suppression). However, it's clipped internally to - # be 0.1. + # If P_sol >> P_LH, we expect the suppression multiplier to be + # base_multiplier. testcase_name='above_threshold', power=1e6, - expected_multiplier=0.1, + expected_multiplier=1e-6, ), dict( # If P_sol << P_LH, we expect the suppression multiplier to be 1.0 diff --git a/torax/_src/pedestal_model/pydantic_model.py b/torax/_src/pedestal_model/pydantic_model.py index 1b6c5dc16..2b5860c77 100644 --- a/torax/_src/pedestal_model/pydantic_model.py +++ b/torax/_src/pedestal_model/pydantic_model.py @@ -38,36 +38,39 @@ class MartinFormation(torax_pydantic.BaseModelFrozen): """Configuration for Martin formation model. This formation model triggers a reduction in pedestal transport when P_SOL > - P_LH, where P_LH is from the Martin scaling. The reduction is a smooth sigmoid - function of the ratio P_SOL / (P_LH * P_LH_prefactor). + P_LH, where P_LH is from the Martin scaling. The reduction is a multiplicative + factor between 1.0 and base_multiplier. The formula is - rescaled_P_LH = P_LH * P_LH_prefactor - normalized_deviation = (P_SOL - rescaled_P_LH) / rescaled_P_LH - offset - transport_multiplier = 1 - sigmoid(normalized_deviation / width) - transport_multiplier = transport_multiplier**exponent + transport_multiplier = (1.0 - alpha) * 1.0 + alpha * base_multiplier, + where alpha is a smooth sigmoid function of + (P_SOL - P_LH * P_LH_prefactor) / (P_LH * P_LH_prefactor) + with given sharpness and offset, namely: + sigmoid(x) = 1 / (1 + exp(-sharpness * [x - offset])). - The transport multiplier is later clipped to the range - [min_transport_multiplier, max_transport_multiplier]. Attributes: - sigmoid_width: Dimensionless width of the sigmoid function for smoothing the - formation window. Increase for a smoother L-H transition, but doing so may - lead to starting the L-H transition at a power below P_LH. - sigmoid_offset: Dimensionless offset of sigmoid function from P_LH / P_SOL = - 1. Increase to start the L-H transition at a higher P_SOL / P_LH ratio. - sigmoid_exponent: The exponent of the transport multiplier. Increase for a - faster reduction in transport once the L-H transition starts. + sharpness: Scaling factor applied to the argument of the sigmoid function, + setting the sharpness of the smooth formation window. Decrease for a + smoother formation, which may be more numerically stable but may lead to + starting formation at a temperature or density below the target values. + offset: Bias applied to the argument of the sigmoid function, setting the + dimensionless offset of the formation window. Increase to start formation + at a higher P_SOL. + base_multiplier: The base value of the transport multiplier. Increase for + stronger decreases in transport once formation starts. P_LH_prefactor: Dimensionless multiplier for P_LH. Increase to scale up P_LH, and therefore start the L-H transition at a higher P_SOL. """ model_name: Annotated[Literal["martin"], torax_pydantic.JAX_STATIC] = "martin" - sigmoid_width: pydantic.PositiveFloat = 1e-3 - sigmoid_offset: Annotated[ + sharpness: pydantic.PositiveFloat = 100.0 + offset: Annotated[ array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0) ] = 0.0 - sigmoid_exponent: pydantic.PositiveFloat = 3.0 + base_multiplier: Annotated[ + array_typing.FloatScalar, pydantic.Field(gt=0.0, le=1.0) + ] = 1e-6 P_LH_prefactor: pydantic.PositiveFloat = 1.0 def build_formation_model( @@ -80,9 +83,9 @@ def build_runtime_params( ) -> martin_formation_model.MartinFormationRuntimeParams: del t return martin_formation_model.MartinFormationRuntimeParams( - sigmoid_width=self.sigmoid_width, - sigmoid_offset=self.sigmoid_offset, - sigmoid_exponent=self.sigmoid_exponent, + sharpness=self.sharpness, + offset=self.offset, + base_multiplier=self.base_multiplier, P_LH_prefactor=self.P_LH_prefactor, ) @@ -92,36 +95,41 @@ class ProfileValueSaturation(torax_pydantic.BaseModelFrozen): This saturation model triggers an increase in pedestal transport when the pedestal temperature and density are above the values requested by the - pedestal model. The increase is a smooth sigmoid function of the ratio of the + pedestal model. The increase is a smooth linear function of the ratio of the current value to the value requested by the pedestal model. The formula is - normalized_deviation = (current - target) / target - offset - transport_multiplier = 1 / (1 - sigmoid(normalized_deviation / width)) - transport_multiplier = transport_multiplier**exponent + transport_multiplier = 1 + alpha * base_multiplier, + where alpha is a softplus function of the normalized deviation from the target + value, with given steepness and offset: + x = (current - target) / target - offset + alpha = log(1 + exp(steepness * x)) - The transport multiplier is then clipped to the range - [min_transport_multiplier, max_transport_multiplier]. Attributes: - sigmoid_width: Dimensionless width of the sigmoid function for smoothing the - saturation window. Increase for a smoother saturation, but doing so may - lead to starting saturation at a temperature or density below the target - values. - sigmoid_offset: Dimensionless offset of the saturation window. Increase to - start saturation at a higher temperature or density. - sigmoid_exponent: The exponent of the transport multiplier. Increase for a - faster increase in transport once saturation starts. + steepness: Scaling factor applied to the argument of the softplus function, + setting the steepness of the smooth saturation function. Decrease for a + smoother saturation, which may be more numerically stable but may lead to + starting saturation at a temperature or density below the target values. + offset: Bias applied to the argument of the softplus function, setting the + dimensionless offset of the saturation window. Increase to start + saturation at a higher temperature or density. + base_multiplier: The base value of the transport multiplier. Increase for + stronger increases in transport once saturation starts. """ model_name: Annotated[Literal["profile_value"], torax_pydantic.JAX_STATIC] = ( "profile_value" ) - sigmoid_width: pydantic.PositiveFloat = 0.1 - sigmoid_offset: Annotated[ + steepness: pydantic.PositiveFloat = 100.0 + # Default offset is > 0 as otherwise saturation starts too early. This is + # because the softplus function is nonzero before the argument is zero. + offset: Annotated[ array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0) - ] = 0.0 - sigmoid_exponent: pydantic.PositiveFloat = 1.0 + ] = 0.1 + base_multiplier: Annotated[ + array_typing.FloatScalar, pydantic.Field(gt=1.0) + ] = 1e6 def build_saturation_model( self, @@ -133,9 +141,9 @@ def build_runtime_params( ) -> runtime_params.SaturationRuntimeParams: del t return runtime_params.SaturationRuntimeParams( - sigmoid_width=self.sigmoid_width, - sigmoid_offset=self.sigmoid_offset, - sigmoid_exponent=self.sigmoid_exponent, + steepness=self.steepness, + offset=self.offset, + base_multiplier=self.base_multiplier, ) @@ -169,12 +177,6 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC): saturation_model: SaturationConfig = torax_pydantic.ValidatedDefault( ProfileValueSaturation() ) - max_transport_multiplier: torax_pydantic.TimeVaryingScalar = ( - torax_pydantic.ValidatedDefault(10.0) - ) - min_transport_multiplier: torax_pydantic.TimeVaryingScalar = ( - torax_pydantic.ValidatedDefault(0.1) - ) @pydantic.model_validator(mode="before") @classmethod @@ -205,8 +207,6 @@ def build_runtime_params( mode=self.mode, formation=self.formation_model.build_runtime_params(t), saturation=self.saturation_model.build_runtime_params(t), - max_transport_multiplier=self.max_transport_multiplier.get_value(t), - min_transport_multiplier=self.min_transport_multiplier.get_value(t), ) @@ -262,8 +262,6 @@ def build_runtime_params( rho_norm_ped_top=self.rho_norm_ped_top.get_value(t), formation=base_runtime_params.formation, saturation=base_runtime_params.saturation, - max_transport_multiplier=base_runtime_params.max_transport_multiplier, - min_transport_multiplier=base_runtime_params.min_transport_multiplier, ) @@ -318,8 +316,6 @@ def build_runtime_params( rho_norm_ped_top=self.rho_norm_ped_top.get_value(t), formation=base_runtime_params.formation, saturation=base_runtime_params.saturation, - max_transport_multiplier=base_runtime_params.max_transport_multiplier, - min_transport_multiplier=base_runtime_params.min_transport_multiplier, ) @@ -351,8 +347,6 @@ def build_runtime_params( mode=base_runtime_params.mode, formation=base_runtime_params.formation, saturation=base_runtime_params.saturation, - max_transport_multiplier=base_runtime_params.max_transport_multiplier, - min_transport_multiplier=base_runtime_params.min_transport_multiplier, ) diff --git a/torax/_src/pedestal_model/runtime_params.py b/torax/_src/pedestal_model/runtime_params.py index 18a040e85..73e24c620 100644 --- a/torax/_src/pedestal_model/runtime_params.py +++ b/torax/_src/pedestal_model/runtime_params.py @@ -37,9 +37,9 @@ class Mode(enum.Enum): class FormationRuntimeParams: """Runtime params for pedestal formation models.""" - sigmoid_width: array_typing.FloatScalar = 0.1 - sigmoid_offset: array_typing.FloatScalar = 0.0 - sigmoid_exponent: array_typing.FloatScalar = 1.0 + sharpness: array_typing.FloatScalar + offset: array_typing.FloatScalar + base_multiplier: array_typing.FloatScalar @jax.tree_util.register_dataclass @@ -47,9 +47,9 @@ class FormationRuntimeParams: class SaturationRuntimeParams: """Runtime params for pedestal saturation models.""" - sigmoid_width: array_typing.FloatScalar = 0.1 - sigmoid_offset: array_typing.FloatScalar = 0.0 - sigmoid_exponent: array_typing.FloatScalar = 1.0 + steepness: array_typing.FloatScalar + offset: array_typing.FloatScalar + base_multiplier: array_typing.FloatScalar @jax.tree_util.register_dataclass @@ -59,7 +59,5 @@ class RuntimeParams: set_pedestal: array_typing.BoolScalar mode: Mode = dataclasses.field(metadata={"static": True}) - min_transport_multiplier: array_typing.FloatScalar - max_transport_multiplier: array_typing.FloatScalar formation: FormationRuntimeParams saturation: SaturationRuntimeParams diff --git a/torax/_src/pedestal_model/saturation/profile_value_saturation_model.py b/torax/_src/pedestal_model/saturation/profile_value_saturation_model.py index f468839f3..0a5a6d502 100644 --- a/torax/_src/pedestal_model/saturation/profile_value_saturation_model.py +++ b/torax/_src/pedestal_model/saturation/profile_value_saturation_model.py @@ -91,18 +91,11 @@ def _calculate_multiplier( Returns: The transport increase multiplier. """ - width = pedestal_runtime_params.saturation.sigmoid_width - exponent = pedestal_runtime_params.saturation.sigmoid_exponent - offset = pedestal_runtime_params.saturation.sigmoid_offset + steepness = pedestal_runtime_params.saturation.steepness + offset = pedestal_runtime_params.saturation.offset + base_multiplier = pedestal_runtime_params.saturation.base_multiplier normalized_deviation = (current - target) / target - offset - transport_multiplier = 1 / ( - 1 - jax.nn.sigmoid(normalized_deviation / width) + transport_multiplier = 1 + base_multiplier * jax.nn.softplus( + normalized_deviation * steepness ) - transport_multiplier = transport_multiplier**exponent - transport_multiplier = jnp.clip( - transport_multiplier, - min=pedestal_runtime_params.min_transport_multiplier, - max=pedestal_runtime_params.max_transport_multiplier, - ) - return transport_multiplier diff --git a/torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py b/torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py index ef570c043..e5bb35544 100644 --- a/torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py +++ b/torax/_src/pedestal_model/saturation/tests/profile_value_saturation_model_test.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import dataclasses from absl.testing import absltest from absl.testing import parameterized import numpy as np -from torax._src import constants from torax._src.config import build_runtime_params from torax._src.core_profiles import initialization from torax._src.pedestal_model import pedestal_model_output @@ -49,18 +47,18 @@ def setUp(self): @parameterized.named_parameters( dict( testcase_name='active', - # P_current >> P_target -> saturation is active. - P_current_over_P_target=1e3, + # T_current >> T_target -> saturation is active. + T_target_over_T_current=1e-3, ), dict( testcase_name='inactive', - # P_current << P_target -> no saturation. - P_current_over_P_target=1e-3, + # T_current << T_target -> no saturation. + T_target_over_T_current=1e3, ), ) def test_saturation_multiplier( self, - P_current_over_P_target, + T_target_over_T_current, ): saturation_model = ( profile_value_saturation_model.ProfileValueSaturationModel() @@ -68,24 +66,16 @@ def test_saturation_multiplier( # For this test, we put the pedestal top at the last grid point. ped_top_idx = -1 - - # Get the current pressure at the pedestal top. - current_pressure = self.core_profiles.pressure_thermal_total.value[ - ped_top_idx - ] + current_T_e_ped = self.core_profiles.T_e.face_value()[ped_top_idx] # Construct a pedestal output that is asking for a pedestal with - # target_pressure. - target_pressure = current_pressure / P_current_over_P_target + # target temperature. pedestal_output = pedestal_model_output.PedestalModelOutput( rho_norm_ped_top=self.geo.rho_face[ped_top_idx], rho_norm_ped_top_idx=ped_top_idx, - # Set T_i_ped, T_e_ped, and n_e_ped such that target_pressure is - # achieved. This case has n_impurity = 0 and Z_i = 1, so - # P = (T_i + T_e)*keV_to_J*n_e. - T_i_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2, - T_e_ped=(1.0 / constants.CONSTANTS.keV_to_J) / 2, - n_e_ped=target_pressure, + T_i_ped=1.0, + T_e_ped=current_T_e_ped * T_target_over_T_current, + n_e_ped=1.0, ) transport_multipliers = saturation_model( @@ -95,16 +85,14 @@ def test_saturation_multiplier( pedestal_output, ) - if P_current_over_P_target > 1.0: - # If the current pressure is above the target pressure, we expect the - # multiplier to be greater than 1.0. - for multiplier in dataclasses.asdict(transport_multipliers).values(): - self.assertGreater(multiplier, 1.0) + if T_target_over_T_current > 1.0: + # If the target temperature is above the current temperature, we expect + # the multiplier to be equal to 1.0 - the pedestal is not saturated. + np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0) else: - # If the current pressure is below the target pressure, we expect the - # multiplier to be 1.0. - for multiplier in dataclasses.asdict(transport_multipliers).values(): - np.testing.assert_allclose(multiplier, 1.0, rtol=1e-3) + # If the target temperature is below the current temperature, we expect + # the multiplier to be greater than 1.0 - the pedestal is saturated. + self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0) if __name__ == '__main__': diff --git a/torax/_src/pedestal_model/tests/register_model_test.py b/torax/_src/pedestal_model/tests/register_model_test.py index 33623fdb6..25a45e8b4 100644 --- a/torax/_src/pedestal_model/tests/register_model_test.py +++ b/torax/_src/pedestal_model/tests/register_model_test.py @@ -90,8 +90,6 @@ def build_runtime_params( mode=self.mode, formation=self.formation_model.build_runtime_params(t), saturation=self.saturation_model.build_runtime_params(t), - max_transport_multiplier=self.max_transport_multiplier.get_value(t), - min_transport_multiplier=self.min_transport_multiplier.get_value(t), ) diff --git a/torax/tests/pedestal_test.py b/torax/tests/pedestal_test.py index 6a28064f5..ee1e2d88e 100644 --- a/torax/tests/pedestal_test.py +++ b/torax/tests/pedestal_test.py @@ -64,8 +64,6 @@ def build_runtime_params( mode=self.mode, formation=self.formation_model.build_runtime_params(t), saturation=self.saturation_model.build_runtime_params(t), - max_transport_multiplier=self.max_transport_multiplier.get_value(t), - min_transport_multiplier=self.min_transport_multiplier.get_value(t), ) diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_lh_transition.py b/torax/tests/test_data/test_iterhybrid_predictor_corrector_lh_transition.py index 974baec72..ecc321796 100644 --- a/torax/tests/test_data/test_iterhybrid_predictor_corrector_lh_transition.py +++ b/torax/tests/test_data/test_iterhybrid_predictor_corrector_lh_transition.py @@ -60,20 +60,8 @@ 'T_e_ped': 4.5, 'n_e_ped': 0.62e20, 'rho_norm_ped_top': 0.9, - 'saturation_model': { - 'model_name': 'profile_value', - 'sigmoid_width': 0.1, - 'sigmoid_exponent': 1.0, - 'sigmoid_offset': 0.0, - }, - 'formation_model': { - 'model_name': 'martin', - 'sigmoid_width': 1e-3, - 'sigmoid_exponent': 3.0, - 'sigmoid_offset': 0.0, - }, - 'max_transport_multiplier': 10, - 'min_transport_multiplier': 0.1, + 'saturation_model': {'model_name': 'profile_value'}, + 'formation_model': {'model_name': 'martin'}, } # Use nonlinear solver, as linear solver struggles with the fast dynamics of the