Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 47 additions & 32 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1090,17 +1090,25 @@ combined

A combined (additive) model, where the total transport coefficients are
calculated by summing contributions from a list of component models. Each
component model is active only within its defined radial domain, set using
``rho_min``` and ``rho_max``. These zones can be overlapping or
non-overlapping; in regions of overlap, the total transport coefficients are
computed by adding the contributions from component models active at those
coordinates. Post-processing (clipping and smoothing) is performed on the
summed value.
component model is active only within its defined radial domain, which can
be overlapping or non-overlapping; in regions of overlap, the total
transport coefficients are computed by adding the contributions from
component models active at those coordinates.
For models defined in ``transport_models``, the active domain is set by
``rho_min``` and ``rho_max``. For models in ``pedestal_transport_models``,
the active domain is set by the ``rho_norm_ped_top`` parameter from the
``pedestal`` section of the config.
Post-processing (clipping and smoothing) is performed on the summed
values from all component models, including in the pedestal.

The runtime parameters are as follows.

``transport_models`` (list[dict])
A list containing config dicts for the component transport models.
A list containing config dicts for the component models for turbulent transport in the core.

``pedestal_transport_models`` (list[dict])
A list containing config dicts for the component models for turbulent transport in the pedestal.


.. warning::
TORAX will throw a ``ValueError`` if any of the component transport
Expand All @@ -1115,29 +1123,36 @@ The runtime parameters are as follows.
Example:

.. code-block:: python

...
'transport': {
'model_name': 'combined',
'transport_models': [
{
'model_name': 'constant',
'chi_i': 1.0,
'rho_max': 0.3,
},
{
'model_name': 'constant',
'chi_i': 2.0,
'rho_min': 0.2
'rho_max': 0.5,
},
{
'model_name': 'constant',
'chi_i': 0.5,
'rho_min': 0.5
'rho_max': 1.0,
},
],
}
'model_name': 'combined',
'transport_models': [
{
'model_name': 'constant',
'chi_i': 1.0,
'rho_max': 0.3,
},
{
'model_name': 'constant',
'chi_i': 2.0,
'rho_min': 0.2,
},
],
'pedestal_transport_models': [
{
'model_name': 'constant',
'chi_i': 0.5,
},
],
},
'pedestal': {
'model_name': 'set_T_ped_n_ped',
'set_pedestal': True,
'rho_norm_ped_top': 0.9,
'n_e_ped': 0.8,
'n_e_ped_is_fGW': True,
},
...

This would produce a ``chi_i`` profile that looks like the following.

Expand All @@ -1147,9 +1162,9 @@ This would produce a ``chi_i`` profile that looks like the following.

Note that in the region :math:`[0, 0.2]`, only the first component is active,
so ``chi_i = 1.0``. In :math:`(0.2, 0.3]` the first two components are both
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.5]`,
only the second model is active (``chi_i = 2.0``), and in :math:`(0.5, 1.0]`
only the fourth model is active (``chi_i = 0.5``).
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.9]`,
only the second model is active (``chi_i = 2.0``), and in :math:`(0.9, 1.0]`
only the pedestal transport model is active (``chi_i = 0.5``).


sources
Expand Down
Binary file modified docs/images/combined_transport_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 10 additions & 4 deletions docs/scripts/combined_transport_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Script for plotting the combined transport model in the docs."""
from typing import Sequence

from absl import app
import matplotlib.pyplot as plt
import torax
Expand All @@ -33,7 +34,13 @@ def main(argv: Sequence[str]) -> None:
'geometry_type': 'circular',
'n_rho': 30, # for higher resolution plotting
},
'pedestal': {},
'pedestal': {
'model_name': 'set_T_ped_n_ped',
'set_pedestal': True,
'rho_norm_ped_top': 0.9,
'n_e_ped': 0.8,
'n_e_ped_is_fGW': True,
},
'neoclassical': {},
'sources': {},
'solver': {},
Expand All @@ -49,13 +56,12 @@ def main(argv: Sequence[str]) -> None:
'model_name': 'constant',
'chi_i': 2.0,
'rho_min': 0.2,
'rho_max': 0.5,
},
],
'pedestal_transport_models': [
{
'model_name': 'constant',
'chi_i': 0.5,
'rho_min': 0.5,
'rho_max': 1.0,
},
],
},
Expand Down
136 changes: 114 additions & 22 deletions torax/_src/transport_model/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
A class for combining transport models.
"""
import dataclasses
from typing import Sequence
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from torax._src import state
from torax._src.config import runtime_params_slice
from torax._src.geometry import geometry
Expand All @@ -34,18 +35,72 @@
@dataclasses.dataclass(frozen=True)
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
pedestal_transport_model_params: Sequence[
runtime_params_lib.DynamicRuntimeParams
]


class CombinedTransportModel(transport_model_lib.TransportModel):
"""Combines coefficients from a list of transport models."""

def __init__(
self, transport_models: Sequence[transport_model_lib.TransportModel]
self,
transport_models: Sequence[transport_model_lib.TransportModel],
pedestal_transport_models: Sequence[transport_model_lib.TransportModel],
):
super().__init__()
self.transport_models = transport_models
self.pedestal_transport_models = pedestal_transport_models
self._frozen = True

def __call__(
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
) -> transport_model_lib.TurbulentTransport:
if not getattr(self, "_frozen", False):
raise RuntimeError(
f"Subclass implementation {type(self)} forgot to "
"freeze at the end of __init__."
)

transport_runtime_params = dynamic_runtime_params_slice.transport

# Calculate the transport coefficients - includes contribution from pedestal
# and core transport models.
transport_coeffs = self._call_implementation(
transport_runtime_params,
dynamic_runtime_params_slice,
geo,
core_profiles,
pedestal_model_output,
)

# In contrast to the base TransportModel, we do not apply domain restriction
# as this is handled at the component model level

# Apply min/max clipping
transport_coeffs = self._apply_clipping(
transport_runtime_params,
transport_coeffs,
)

# In contrast to the base TransportModel, we do not apply patches, as these
# should be handled by instantiating constant component models instead.
# However, the rho_inner and rho_outer arguments are currently required
# in the case where the inner/outer region are to be excluded from smoothing.
# Smoothing is applied to rho_inner < rho_norm < min(rho_ped_top, rho_outer)
# unless smooth_everywhere is True.
return self._smooth_coeffs(
transport_runtime_params,
dynamic_runtime_params_slice,
geo,
transport_coeffs,
pedestal_model_output,
)

def _call_implementation(
self,
transport_dynamic_runtime_params: runtime_params_lib.DynamicRuntimeParams,
Expand All @@ -72,48 +127,85 @@ def _call_implementation(
# Required for pytype
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)

component_transport_coeffs_list = []

for component_model, component_params in zip(
self.transport_models,
transport_dynamic_runtime_params.transport_model_params,
):
# Use the component model's _call_implementation, rather than __call__
# directly. This ensures postprocessing (clipping, smoothing, patches) are
# performed on the output of CombinedTransportModel rather than its
# component models.
def apply_and_restrict(
component_model: transport_model_lib.TransportModel,
component_params: runtime_params_lib.DynamicRuntimeParams,
restriction_fn: Callable,
) -> transport_model_lib.TurbulentTransport:
# TODO: Consider only computing transport coefficients for the active
# domain, rather than masking them out later. This could be significantly
# more efficient especially for pedestal models, as these are only active
# in a small region of the domain.
component_transport_coeffs = component_model._call_implementation(
component_params,
dynamic_runtime_params_slice,
geo,
core_profiles,
pedestal_model_output,
)

# Apply domain restriction
# This is a property of each component_model, so needs to be applied
# at the component model level rather than the global level
component_transport_coeffs = component_model._apply_domain_restriction(
component_transport_coeffs = restriction_fn(
component_params,
geo,
component_transport_coeffs,
pedestal_model_output,
)

component_transport_coeffs_list.append(component_transport_coeffs)

return component_transport_coeffs

pedestal_coeffs = [
apply_and_restrict(
model, params, self._apply_pedestal_domain_restriction
)
for model, params in zip(
self.pedestal_transport_models,
transport_dynamic_runtime_params.pedestal_transport_model_params,
)
]

core_coeffs = [
apply_and_restrict(model, params, model._apply_domain_restriction)
for model, params in zip(
self.transport_models,
transport_dynamic_runtime_params.transport_model_params,
)
]

# Combine the transport coefficients from core and pedestal models.
combined_transport_coeffs = jax.tree.map(
lambda *leaves: sum(leaves),
*component_transport_coeffs_list,
*pedestal_coeffs,
*core_coeffs,
)

return combined_transport_coeffs

def _apply_pedestal_domain_restriction(
self,
transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
transport_coeffs: transport_model_lib.TurbulentTransport,
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
) -> transport_model_lib.TurbulentTransport:
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top

chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)

return dataclasses.replace(
transport_coeffs,
chi_face_ion=chi_face_ion,
chi_face_el=chi_face_el,
d_face_el=d_face_el,
v_face_el=v_face_el,
)

def __hash__(self):
return hash(tuple(self.transport_models))
return hash(tuple(self.transport_models + self.pedestal_transport_models))

def __eq__(self, other):
return (
isinstance(other, CombinedTransportModel)
and self.transport_models == other.transport_models
and self.pedestal_transport_models == other.pedestal_transport_models
)
Loading