Skip to content

Commit c9ab82e

Browse files
author
Torax team
committed
Merge pull request #1330 from google-deepmind:pedestal-transport
PiperOrigin-RevId: 787238780
2 parents 7ec1000 + 7efed82 commit c9ab82e

File tree

8 files changed

+326
-91
lines changed

8 files changed

+326
-91
lines changed

docs/configuration.rst

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,17 +1090,30 @@ combined
10901090

10911091
A combined (additive) model, where the total transport coefficients are
10921092
calculated by summing contributions from a list of component models. Each
1093-
component model is active only within its defined radial domain, set using
1094-
``rho_min``` and ``rho_max``. These zones can be overlapping or
1095-
non-overlapping; in regions of overlap, the total transport coefficients are
1096-
computed by adding the contributions from component models active at those
1097-
coordinates. Post-processing (clipping and smoothing) is performed on the
1098-
summed value.
1093+
component model is active only within its defined radial domain, which can
1094+
be overlapping or non-overlapping; in regions of overlap, the total
1095+
transport coefficients are computed by adding the contributions from
1096+
component models active at those coordinates.
1097+
For individual core transport models defined in ``transport_models``, the active
1098+
domain (where transport coefficients are non-zero) is set by ``rho_min``` and
1099+
``rho_max``. If a pedestal is active, the active domain is then limited by
1100+
``rho_norm_ped_top`` if ``rho_norm_ped_top`` is less than ``rho_max``.
1101+
``rho_norm_ped_top`` is set in the ``pedestal`` section of the config.
1102+
For models in ``pedestal_transport_models``, the active domain is only for
1103+
radii above ``rho_norm_ped_top``.
1104+
Post-processing (clipping and smoothing) is performed on the summed
1105+
values from all component models, including in the pedestal.
10991106

11001107
The runtime parameters are as follows.
11011108

11021109
``transport_models`` (list[dict])
1103-
A list containing config dicts for the component transport models.
1110+
A list containing config dicts for the component models for turbulent
1111+
transport in the core.
1112+
1113+
``pedestal_transport_models`` (list[dict])
1114+
A list containing config dicts for the component models for turbulent
1115+
transport in the pedestal.
1116+
11041117

11051118
.. warning::
11061119
TORAX will throw a ``ValueError`` if any of the component transport
@@ -1115,29 +1128,36 @@ The runtime parameters are as follows.
11151128
Example:
11161129

11171130
.. code-block:: python
1118-
1131+
...
11191132
'transport': {
1120-
'model_name': 'combined',
1121-
'transport_models': [
1122-
{
1123-
'model_name': 'constant',
1124-
'chi_i': 1.0,
1125-
'rho_max': 0.3,
1126-
},
1127-
{
1128-
'model_name': 'constant',
1129-
'chi_i': 2.0,
1130-
'rho_min': 0.2
1131-
'rho_max': 0.5,
1132-
},
1133-
{
1134-
'model_name': 'constant',
1135-
'chi_i': 0.5,
1136-
'rho_min': 0.5
1137-
'rho_max': 1.0,
1138-
},
1139-
],
1140-
}
1133+
'model_name': 'combined',
1134+
'transport_models': [
1135+
{
1136+
'model_name': 'constant',
1137+
'chi_i': 1.0,
1138+
'rho_max': 0.3,
1139+
},
1140+
{
1141+
'model_name': 'constant',
1142+
'chi_i': 2.0,
1143+
'rho_min': 0.2,
1144+
},
1145+
],
1146+
'pedestal_transport_models': [
1147+
{
1148+
'model_name': 'constant',
1149+
'chi_i': 0.5,
1150+
},
1151+
],
1152+
},
1153+
'pedestal': {
1154+
'model_name': 'set_T_ped_n_ped',
1155+
'set_pedestal': True,
1156+
'rho_norm_ped_top': 0.9,
1157+
'n_e_ped': 0.8,
1158+
'n_e_ped_is_fGW': True,
1159+
},
1160+
...
11411161
11421162
This would produce a ``chi_i`` profile that looks like the following.
11431163

@@ -1147,9 +1167,9 @@ This would produce a ``chi_i`` profile that looks like the following.
11471167

11481168
Note that in the region :math:`[0, 0.2]`, only the first component is active,
11491169
so ``chi_i = 1.0``. In :math:`(0.2, 0.3]` the first two components are both
1150-
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.5]`,
1151-
only the second model is active (``chi_i = 2.0``), and in :math:`(0.5, 1.0]`
1152-
only the fourth model is active (``chi_i = 0.5``).
1170+
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.9]`,
1171+
only the second model is active (``chi_i = 2.0``), and in :math:`(0.9, 1.0]`
1172+
only the pedestal transport model is active (``chi_i = 0.5``).
11531173

11541174

11551175
sources
509 Bytes
Loading

docs/scripts/combined_transport_example.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Script for plotting the combined transport model in the docs."""
1616
from typing import Sequence
17+
1718
from absl import app
1819
import matplotlib.pyplot as plt
1920
import torax
@@ -33,7 +34,13 @@ def main(argv: Sequence[str]) -> None:
3334
'geometry_type': 'circular',
3435
'n_rho': 30, # for higher resolution plotting
3536
},
36-
'pedestal': {},
37+
'pedestal': {
38+
'model_name': 'set_T_ped_n_ped',
39+
'set_pedestal': True,
40+
'rho_norm_ped_top': 0.9,
41+
'n_e_ped': 0.8,
42+
'n_e_ped_is_fGW': True,
43+
},
3744
'neoclassical': {},
3845
'sources': {},
3946
'solver': {},
@@ -49,13 +56,12 @@ def main(argv: Sequence[str]) -> None:
4956
'model_name': 'constant',
5057
'chi_i': 2.0,
5158
'rho_min': 0.2,
52-
'rho_max': 0.5,
5359
},
60+
],
61+
'pedestal_transport_models': [
5462
{
5563
'model_name': 'constant',
5664
'chi_i': 0.5,
57-
'rho_min': 0.5,
58-
'rho_max': 1.0,
5965
},
6066
],
6167
},

torax/_src/transport_model/combined.py

Lines changed: 126 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
A class for combining transport models.
1818
"""
1919
import dataclasses
20-
from typing import Sequence
20+
from typing import Callable, Sequence
2121

2222
import jax
23+
import jax.numpy as jnp
2324
from torax._src import state
2425
from torax._src.config import runtime_params_slice
2526
from torax._src.geometry import geometry
@@ -34,18 +35,73 @@
3435
@dataclasses.dataclass(frozen=True)
3536
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
3637
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
38+
pedestal_transport_model_params: Sequence[
39+
runtime_params_lib.DynamicRuntimeParams
40+
]
3741

3842

3943
class CombinedTransportModel(transport_model_lib.TransportModel):
4044
"""Combines coefficients from a list of transport models."""
4145

4246
def __init__(
43-
self, transport_models: Sequence[transport_model_lib.TransportModel]
47+
self,
48+
transport_models: Sequence[transport_model_lib.TransportModel],
49+
pedestal_transport_models: Sequence[transport_model_lib.TransportModel],
4450
):
4551
super().__init__()
4652
self.transport_models = transport_models
53+
self.pedestal_transport_models = pedestal_transport_models
4754
self._frozen = True
4855

56+
def __call__(
57+
self,
58+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
59+
geo: geometry.Geometry,
60+
core_profiles: state.CoreProfiles,
61+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
62+
) -> transport_model_lib.TurbulentTransport:
63+
if not getattr(self, "_frozen", False):
64+
raise RuntimeError(
65+
f"Subclass implementation {type(self)} forgot to "
66+
"freeze at the end of __init__."
67+
)
68+
69+
transport_runtime_params = dynamic_runtime_params_slice.transport
70+
71+
# Calculate the transport coefficients - includes contribution from pedestal
72+
# and core transport models.
73+
transport_coeffs = self._call_implementation(
74+
transport_runtime_params,
75+
dynamic_runtime_params_slice,
76+
geo,
77+
core_profiles,
78+
pedestal_model_output,
79+
)
80+
81+
# In contrast to the base TransportModel, we do not apply domain restriction
82+
# as this is handled at the component model level
83+
84+
# Apply min/max clipping
85+
transport_coeffs = self._apply_clipping(
86+
transport_runtime_params,
87+
transport_coeffs,
88+
)
89+
90+
# In contrast to the base TransportModel, we do not apply patches, as these
91+
# should be handled by instantiating constant component models instead.
92+
# However, the rho_inner and rho_outer arguments are currently required
93+
# in the case where the inner/outer region are to be excluded from
94+
# smoothing. Smoothing is applied to
95+
# rho_inner < rho_norm < min(rho_ped_top, rho_outer) unless
96+
# smooth_everywhere is True.
97+
return self._smooth_coeffs(
98+
transport_runtime_params,
99+
dynamic_runtime_params_slice,
100+
geo,
101+
transport_coeffs,
102+
pedestal_model_output,
103+
)
104+
49105
def _call_implementation(
50106
self,
51107
transport_dynamic_runtime_params: runtime_params_lib.DynamicRuntimeParams,
@@ -72,48 +128,96 @@ def _call_implementation(
72128
# Required for pytype
73129
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)
74130

75-
component_transport_coeffs_list = []
76-
77-
for component_model, component_params in zip(
78-
self.transport_models,
79-
transport_dynamic_runtime_params.transport_model_params,
80-
):
81-
# Use the component model's _call_implementation, rather than __call__
82-
# directly. This ensures postprocessing (clipping, smoothing, patches) are
83-
# performed on the output of CombinedTransportModel rather than its
84-
# component models.
131+
def apply_and_restrict(
132+
component_model: transport_model_lib.TransportModel,
133+
component_params: runtime_params_lib.DynamicRuntimeParams,
134+
restriction_fn: Callable[
135+
[
136+
runtime_params_lib.DynamicRuntimeParams,
137+
geometry.Geometry,
138+
transport_model_lib.TurbulentTransport,
139+
pedestal_model_lib.PedestalModelOutput,
140+
],
141+
transport_model_lib.TurbulentTransport,
142+
],
143+
) -> transport_model_lib.TurbulentTransport:
144+
# TODO(b/434175682): Consider only computing transport coefficients for
145+
# the active domain, rather than masking them out later. This could be
146+
# significantly more efficient especially for pedestal models, as these
147+
# are only active in a small region of the domain.
85148
component_transport_coeffs = component_model._call_implementation(
86149
component_params,
87150
dynamic_runtime_params_slice,
88151
geo,
89152
core_profiles,
90153
pedestal_model_output,
91154
)
92-
93-
# Apply domain restriction
94-
# This is a property of each component_model, so needs to be applied
95-
# at the component model level rather than the global level
96-
component_transport_coeffs = component_model._apply_domain_restriction(
155+
component_transport_coeffs = restriction_fn(
97156
component_params,
98157
geo,
99158
component_transport_coeffs,
100159
pedestal_model_output,
101160
)
102-
103-
component_transport_coeffs_list.append(component_transport_coeffs)
104-
161+
return component_transport_coeffs
162+
163+
pedestal_coeffs = [
164+
apply_and_restrict(
165+
model, params, self._apply_pedestal_domain_restriction
166+
)
167+
for model, params in zip(
168+
self.pedestal_transport_models,
169+
transport_dynamic_runtime_params.pedestal_transport_model_params,
170+
)
171+
]
172+
173+
core_coeffs = [
174+
apply_and_restrict(model, params, model._apply_domain_restriction)
175+
for model, params in zip(
176+
self.transport_models,
177+
transport_dynamic_runtime_params.transport_model_params,
178+
)
179+
]
180+
181+
# Combine the transport coefficients from core and pedestal models.
105182
combined_transport_coeffs = jax.tree.map(
106183
lambda *leaves: sum(leaves),
107-
*component_transport_coeffs_list,
184+
*pedestal_coeffs,
185+
*core_coeffs,
108186
)
109187

110188
return combined_transport_coeffs
111189

190+
def _apply_pedestal_domain_restriction(
191+
self,
192+
unused_transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
193+
geo: geometry.Geometry,
194+
transport_coeffs: transport_model_lib.TurbulentTransport,
195+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
196+
) -> transport_model_lib.TurbulentTransport:
197+
del unused_transport_runtime_params
198+
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top
199+
200+
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
201+
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
202+
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
203+
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
204+
205+
return dataclasses.replace(
206+
transport_coeffs,
207+
chi_face_ion=chi_face_ion,
208+
chi_face_el=chi_face_el,
209+
d_face_el=d_face_el,
210+
v_face_el=v_face_el,
211+
)
212+
112213
def __hash__(self):
113-
return hash(tuple(self.transport_models))
214+
return hash(
215+
tuple(self.transport_models) + tuple(self.pedestal_transport_models)
216+
)
114217

115218
def __eq__(self, other):
116219
return (
117220
isinstance(other, CombinedTransportModel)
118221
and self.transport_models == other.transport_models
222+
and self.pedestal_transport_models == other.pedestal_transport_models
119223
)

0 commit comments

Comments
 (0)