Skip to content

Commit e73a1e9

Browse files
committed
Add pedestal transport model section to CombinedTransportModel
1 parent 4d18121 commit e73a1e9

File tree

8 files changed

+307
-89
lines changed

8 files changed

+307
-89
lines changed

docs/configuration.rst

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,17 +1090,25 @@ combined
10901090

10911091
A combined (additive) model, where the total transport coefficients are
10921092
calculated by summing contributions from a list of component models. Each
1093-
component model is active only within its defined radial domain, set using
1094-
``rho_min``` and ``rho_max``. These zones can be overlapping or
1095-
non-overlapping; in regions of overlap, the total transport coefficients are
1096-
computed by adding the contributions from component models active at those
1097-
coordinates. Post-processing (clipping and smoothing) is performed on the
1098-
summed value.
1093+
component model is active only within its defined radial domain, which can
1094+
be overlapping or non-overlapping; in regions of overlap, the total
1095+
transport coefficients are computed by adding the contributions from
1096+
component models active at those coordinates.
1097+
For models defined in ``transport_models``, the active domain is set by
1098+
``rho_min``` and ``rho_max``. For models in ``pedestal_transport_models``,
1099+
the active domain is set by the ``rho_norm_ped_top`` parameter from the
1100+
``pedestal`` section of the config.
1101+
Post-processing (clipping and smoothing) is performed on the summed
1102+
values from all component models, including in the pedestal.
10991103

11001104
The runtime parameters are as follows.
11011105

11021106
``transport_models`` (list[dict])
1103-
A list containing config dicts for the component transport models.
1107+
A list containing config dicts for the component models for turbulent transport in the core.
1108+
1109+
``pedestal_transport_models`` (list[dict])
1110+
A list containing config dicts for the component models for turbulent transport in the pedestal.
1111+
11041112

11051113
.. warning::
11061114
TORAX will throw a ``ValueError`` if any of the component transport
@@ -1115,29 +1123,36 @@ The runtime parameters are as follows.
11151123
Example:
11161124

11171125
.. code-block:: python
1118-
1126+
...
11191127
'transport': {
1120-
'model_name': 'combined',
1121-
'transport_models': [
1122-
{
1123-
'model_name': 'constant',
1124-
'chi_i': 1.0,
1125-
'rho_max': 0.3,
1126-
},
1127-
{
1128-
'model_name': 'constant',
1129-
'chi_i': 2.0,
1130-
'rho_min': 0.2
1131-
'rho_max': 0.5,
1132-
},
1133-
{
1134-
'model_name': 'constant',
1135-
'chi_i': 0.5,
1136-
'rho_min': 0.5
1137-
'rho_max': 1.0,
1138-
},
1139-
],
1140-
}
1128+
'model_name': 'combined',
1129+
'transport_models': [
1130+
{
1131+
'model_name': 'constant',
1132+
'chi_i': 1.0,
1133+
'rho_max': 0.3,
1134+
},
1135+
{
1136+
'model_name': 'constant',
1137+
'chi_i': 2.0,
1138+
'rho_min': 0.2,
1139+
},
1140+
],
1141+
'pedestal_transport_models': [
1142+
{
1143+
'model_name': 'constant',
1144+
'chi_i': 0.5,
1145+
},
1146+
],
1147+
},
1148+
'pedestal': {
1149+
'model_name': 'set_T_ped_n_ped',
1150+
'set_pedestal': True,
1151+
'rho_norm_ped_top': 0.9,
1152+
'n_e_ped': 0.8,
1153+
'n_e_ped_is_fGW': True,
1154+
},
1155+
...
11411156
11421157
This would produce a ``chi_i`` profile that looks like the following.
11431158

@@ -1147,9 +1162,9 @@ This would produce a ``chi_i`` profile that looks like the following.
11471162

11481163
Note that in the region :math:`[0, 0.2]`, only the first component is active,
11491164
so ``chi_i = 1.0``. In :math:`(0.2, 0.3]` the first two components are both
1150-
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.5]`,
1151-
only the second model is active (``chi_i = 2.0``), and in :math:`(0.5, 1.0]`
1152-
only the fourth model is active (``chi_i = 0.5``).
1165+
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.9]`,
1166+
only the second model is active (``chi_i = 2.0``), and in :math:`(0.9, 1.0]`
1167+
only the pedestal transport model is active (``chi_i = 0.5``).
11531168

11541169

11551170
sources
509 Bytes
Loading

docs/scripts/combined_transport_example.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Script for plotting the combined transport model in the docs."""
1616
from typing import Sequence
17+
1718
from absl import app
1819
import matplotlib.pyplot as plt
1920
import torax
@@ -33,7 +34,13 @@ def main(argv: Sequence[str]) -> None:
3334
'geometry_type': 'circular',
3435
'n_rho': 30, # for higher resolution plotting
3536
},
36-
'pedestal': {},
37+
'pedestal': {
38+
'model_name': 'set_T_ped_n_ped',
39+
'set_pedestal': True,
40+
'rho_norm_ped_top': 0.9,
41+
'n_e_ped': 0.8,
42+
'n_e_ped_is_fGW': True,
43+
},
3744
'neoclassical': {},
3845
'sources': {},
3946
'solver': {},
@@ -49,13 +56,12 @@ def main(argv: Sequence[str]) -> None:
4956
'model_name': 'constant',
5057
'chi_i': 2.0,
5158
'rho_min': 0.2,
52-
'rho_max': 0.5,
5359
},
60+
],
61+
'pedestal_transport_models': [
5462
{
5563
'model_name': 'constant',
5664
'chi_i': 0.5,
57-
'rho_min': 0.5,
58-
'rho_max': 1.0,
5965
},
6066
],
6167
},

torax/_src/transport_model/combined.py

Lines changed: 114 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
A class for combining transport models.
1818
"""
1919
import dataclasses
20-
from typing import Sequence
20+
from typing import Callable, Sequence
2121

2222
import jax
23+
import jax.numpy as jnp
2324
from torax._src import state
2425
from torax._src.config import runtime_params_slice
2526
from torax._src.geometry import geometry
@@ -34,18 +35,72 @@
3435
@dataclasses.dataclass(frozen=True)
3536
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
3637
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
38+
pedestal_transport_model_params: Sequence[
39+
runtime_params_lib.DynamicRuntimeParams
40+
]
3741

3842

3943
class CombinedTransportModel(transport_model_lib.TransportModel):
4044
"""Combines coefficients from a list of transport models."""
4145

4246
def __init__(
43-
self, transport_models: Sequence[transport_model_lib.TransportModel]
47+
self,
48+
transport_models: Sequence[transport_model_lib.TransportModel],
49+
pedestal_transport_models: Sequence[transport_model_lib.TransportModel],
4450
):
4551
super().__init__()
4652
self.transport_models = transport_models
53+
self.pedestal_transport_models = pedestal_transport_models
4754
self._frozen = True
4855

56+
def __call__(
57+
self,
58+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
59+
geo: geometry.Geometry,
60+
core_profiles: state.CoreProfiles,
61+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
62+
) -> transport_model_lib.TurbulentTransport:
63+
if not getattr(self, "_frozen", False):
64+
raise RuntimeError(
65+
f"Subclass implementation {type(self)} forgot to "
66+
"freeze at the end of __init__."
67+
)
68+
69+
transport_runtime_params = dynamic_runtime_params_slice.transport
70+
71+
# Calculate the transport coefficients - includes contribution from pedestal
72+
# and core transport models.
73+
transport_coeffs = self._call_implementation(
74+
transport_runtime_params,
75+
dynamic_runtime_params_slice,
76+
geo,
77+
core_profiles,
78+
pedestal_model_output,
79+
)
80+
81+
# In contrast to the base TransportModel, we do not apply domain restriction
82+
# as this is handled at the component model level
83+
84+
# Apply min/max clipping
85+
transport_coeffs = self._apply_clipping(
86+
transport_runtime_params,
87+
transport_coeffs,
88+
)
89+
90+
# In contrast to the base TransportModel, we do not apply patches, as these
91+
# should be handled by instantiating constant component models instead.
92+
# However, the rho_inner and rho_outer arguments are currently required
93+
# in the case where the inner/outer region are to be excluded from smoothing.
94+
# Smoothing is applied to rho_inner < rho_norm < min(rho_ped_top, rho_outer)
95+
# unless smooth_everywhere is True.
96+
return self._smooth_coeffs(
97+
transport_runtime_params,
98+
dynamic_runtime_params_slice,
99+
geo,
100+
transport_coeffs,
101+
pedestal_model_output,
102+
)
103+
49104
def _call_implementation(
50105
self,
51106
transport_dynamic_runtime_params: runtime_params_lib.DynamicRuntimeParams,
@@ -72,48 +127,85 @@ def _call_implementation(
72127
# Required for pytype
73128
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)
74129

75-
component_transport_coeffs_list = []
76-
77-
for component_model, component_params in zip(
78-
self.transport_models,
79-
transport_dynamic_runtime_params.transport_model_params,
80-
):
81-
# Use the component model's _call_implementation, rather than __call__
82-
# directly. This ensures postprocessing (clipping, smoothing, patches) are
83-
# performed on the output of CombinedTransportModel rather than its
84-
# component models.
130+
def apply_and_restrict(
131+
component_model: transport_model_lib.TransportModel,
132+
component_params: runtime_params_lib.DynamicRuntimeParams,
133+
restriction_fn: Callable,
134+
) -> transport_model_lib.TurbulentTransport:
135+
# TODO: Consider only computing transport coefficients for the active
136+
# domain, rather than masking them out later. This could be significantly
137+
# more efficient especially for pedestal models, as these are only active
138+
# in a small region of the domain.
85139
component_transport_coeffs = component_model._call_implementation(
86140
component_params,
87141
dynamic_runtime_params_slice,
88142
geo,
89143
core_profiles,
90144
pedestal_model_output,
91145
)
92-
93-
# Apply domain restriction
94-
# This is a property of each component_model, so needs to be applied
95-
# at the component model level rather than the global level
96-
component_transport_coeffs = component_model._apply_domain_restriction(
146+
component_transport_coeffs = restriction_fn(
97147
component_params,
98148
geo,
99149
component_transport_coeffs,
100150
pedestal_model_output,
101151
)
102-
103-
component_transport_coeffs_list.append(component_transport_coeffs)
104-
152+
return component_transport_coeffs
153+
154+
pedestal_coeffs = [
155+
apply_and_restrict(
156+
model, params, self._apply_pedestal_domain_restriction
157+
)
158+
for model, params in zip(
159+
self.pedestal_transport_models,
160+
transport_dynamic_runtime_params.pedestal_transport_model_params,
161+
)
162+
]
163+
164+
core_coeffs = [
165+
apply_and_restrict(model, params, model._apply_domain_restriction)
166+
for model, params in zip(
167+
self.transport_models,
168+
transport_dynamic_runtime_params.transport_model_params,
169+
)
170+
]
171+
172+
# Combine the transport coefficients from core and pedestal models.
105173
combined_transport_coeffs = jax.tree.map(
106174
lambda *leaves: sum(leaves),
107-
*component_transport_coeffs_list,
175+
*pedestal_coeffs,
176+
*core_coeffs,
108177
)
109178

110179
return combined_transport_coeffs
111180

181+
def _apply_pedestal_domain_restriction(
182+
self,
183+
transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
184+
geo: geometry.Geometry,
185+
transport_coeffs: transport_model_lib.TurbulentTransport,
186+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
187+
) -> transport_model_lib.TurbulentTransport:
188+
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top
189+
190+
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
191+
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
192+
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
193+
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
194+
195+
return dataclasses.replace(
196+
transport_coeffs,
197+
chi_face_ion=chi_face_ion,
198+
chi_face_el=chi_face_el,
199+
d_face_el=d_face_el,
200+
v_face_el=v_face_el,
201+
)
202+
112203
def __hash__(self):
113-
return hash(tuple(self.transport_models))
204+
return hash(tuple(self.transport_models + self.pedestal_transport_models))
114205

115206
def __eq__(self, other):
116207
return (
117208
isinstance(other, CombinedTransportModel)
118209
and self.transport_models == other.transport_models
210+
and self.pedestal_transport_models == other.pedestal_transport_models
119211
)

0 commit comments

Comments
 (0)