Skip to content

Commit 6636467

Browse files
committed
Response to review
* Restore 'core_transport_models' -> 'transport_models' * Support multiple pedestal transport models, defaulting to an empty list * Check that rho_min/max and inner/outer patches not set on pedestal models * Update docs and image * Fix a test that was still using AssertRaises rather than AssertRaisesRegex, and was covering the wrong thing * Improve readability/code reuse in summing over transport models; added comment about future vmap * Add test to cover behaviour when no pedestal model supplied
1 parent 9b387ed commit 6636467

File tree

8 files changed

+222
-141
lines changed

8 files changed

+222
-141
lines changed

docs/configuration.rst

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,17 +1069,25 @@ combined
10691069

10701070
A combined (additive) model, where the total transport coefficients are
10711071
calculated by summing contributions from a list of component models. Each
1072-
component model is active only within its defined radial domain, set using
1073-
``rho_min``` and ``rho_max``. These zones can be overlapping or
1074-
non-overlapping; in regions of overlap, the total transport coefficients are
1075-
computed by adding the contributions from component models active at those
1076-
coordinates. Post-processing (clipping and smoothing) is performed on the
1077-
summed value.
1072+
component model is active only within its defined radial domain, which can
1073+
be overlapping or non-overlapping; in regions of overlap, the total
1074+
transport coefficients are computed by adding the contributions from
1075+
component models active at those coordinates.
1076+
For models defined in ``transport_models``, the active domain is set by
1077+
``rho_min``` and ``rho_max``. For models in ``pedestal_transport_models``,
1078+
the active domain is set by the ``rho_norm_ped_top`` parameter from the
1079+
``pedestal`` section of the config.
1080+
Post-processing (clipping and smoothing) is performed on the summed
1081+
values from all component models, including in the pedestal.
10781082

10791083
The runtime parameters are as follows.
10801084

10811085
``transport_models`` (list[dict])
1082-
A list containing config dicts for the component transport models.
1086+
A list containing config dicts for the component models for turbulent transport in the core.
1087+
1088+
``pedestal_transport_models`` (list[dict])
1089+
A list containing config dicts for the component models for turbulent transport in the pedestal.
1090+
10831091

10841092
.. warning::
10851093
TORAX will throw a ``ValueError`` if any of the component transport
@@ -1094,29 +1102,35 @@ The runtime parameters are as follows.
10941102
Example:
10951103

10961104
.. code-block:: python
1097-
1105+
...
10981106
'transport': {
1099-
'model_name': 'combined',
1100-
'transport_models': [
1101-
{
1102-
'model_name': 'constant',
1103-
'chi_i': 1.0,
1104-
'rho_max': 0.3,
1105-
},
1106-
{
1107-
'model_name': 'constant',
1108-
'chi_i': 2.0,
1109-
'rho_min': 0.2
1110-
'rho_max': 0.5,
1111-
},
1112-
{
1113-
'model_name': 'constant',
1114-
'chi_i': 0.5,
1115-
'rho_min': 0.5
1116-
'rho_max': 1.0,
1117-
},
1118-
],
1119-
}
1107+
'model_name': 'combined',
1108+
'transport_models': [
1109+
{
1110+
'model_name': 'constant',
1111+
'chi_i': 1.0,
1112+
'rho_max': 0.3,
1113+
},
1114+
{
1115+
'model_name': 'constant',
1116+
'chi_i': 2.0,
1117+
'rho_min': 0.2,
1118+
},
1119+
],
1120+
'pedestal_transport_models': [
1121+
{
1122+
'model_name': 'constant',
1123+
'chi_i': 0.5,
1124+
},
1125+
],
1126+
},
1127+
'pedestal': {
1128+
'model_name': 'set_T_ped_n_ped',
1129+
'set_pedestal': True,
1130+
'n_e_ped': 0.8,
1131+
'n_e_ped_is_fGW': True,
1132+
},
1133+
...
11201134
11211135
This would produce a ``chi_i`` profile that looks like the following.
11221136

@@ -1127,8 +1141,8 @@ This would produce a ``chi_i`` profile that looks like the following.
11271141
Note that in the region :math:`[0, 0.2]`, only the first component is active,
11281142
so ``chi_i = 1.0``. In :math:`(0.2, 0.3]` the first two components are both
11291143
active, leading to a combined value of ``chi_i = 3.0``. In :math:`(0.3, 0.5]`,
1130-
only the second model is active (``chi_i = 2.0``), and in :math:`(0.5, 1.0]`
1131-
only the fourth model is active (``chi_i = 0.5``).
1144+
only the second model is active (``chi_i = 2.0``), and in :math:`(0.8, 1.0]`
1145+
only the pedestal transport model is active (``chi_i = 0.5``).
11321146

11331147

11341148
sources
509 Bytes
Loading

docs/scripts/combined_transport_example.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""Script for plotting the combined transport model in the docs."""
1616
from typing import Sequence
17-
from absl import app
17+
1818
import matplotlib.pyplot as plt
19+
from absl import app
20+
1921
import torax
2022
from torax._src.torax_pydantic import model_config
2123

@@ -33,13 +35,18 @@ def main(argv: Sequence[str]) -> None:
3335
'geometry_type': 'circular',
3436
'n_rho': 30, # for higher resolution plotting
3537
},
36-
'pedestal': {},
38+
'pedestal': {
39+
'model_name': 'set_T_ped_n_ped',
40+
'set_pedestal': True,
41+
'n_e_ped': 0.8,
42+
'n_e_ped_is_fGW': True,
43+
},
3744
'neoclassical': {},
3845
'sources': {},
3946
'solver': {},
4047
'transport': {
4148
'model_name': 'combined',
42-
'core_transport_models': [
49+
'transport_models': [
4350
{
4451
'model_name': 'constant',
4552
'chi_i': 1.0,
@@ -49,15 +56,14 @@ def main(argv: Sequence[str]) -> None:
4956
'model_name': 'constant',
5057
'chi_i': 2.0,
5158
'rho_min': 0.2,
52-
'rho_max': 0.5,
5359
},
5460
],
55-
'pedestal_transport_model': {
56-
'model_name': 'constant',
57-
'chi_i': 0.5,
58-
'rho_min': 0.5,
59-
'rho_max': 1.0,
60-
},
61+
'pedestal_transport_models': [
62+
{
63+
'model_name': 'constant',
64+
'chi_i': 0.5,
65+
},
66+
],
6167
},
6268
}
6369
torax_config = model_config.ToraxConfig.from_dict(config)

torax/_src/transport_model/combined.py

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
A class for combining transport models.
1818
"""
1919
import dataclasses
20-
from typing import Sequence
21-
import dataclasses
20+
from typing import Callable, Sequence
2221

2322
import jax
2423
import jax.numpy as jnp
@@ -35,7 +34,7 @@
3534
@jax.tree_util.register_dataclass
3635
@dataclasses.dataclass(frozen=True)
3736
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
38-
core_transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
37+
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
3938
pedestal_transport_model_params: Sequence[
4039
runtime_params_lib.DynamicRuntimeParams
4140
]
@@ -46,12 +45,12 @@ class CombinedTransportModel(transport_model_lib.TransportModel):
4645

4746
def __init__(
4847
self,
49-
core_transport_models: Sequence[transport_model_lib.TransportModel],
50-
pedestal_transport_model: transport_model_lib.TransportModel,
48+
transport_models: Sequence[transport_model_lib.TransportModel],
49+
pedestal_transport_models: Sequence[transport_model_lib.TransportModel],
5150
):
5251
super().__init__()
53-
self.core_transport_models = core_transport_models
54-
self.pedestal_transport_model = pedestal_transport_model
52+
self.transport_models = transport_models
53+
self.pedestal_transport_models = pedestal_transport_models
5554
self._frozen = True
5655

5756
def __call__(
@@ -128,72 +127,57 @@ def _call_implementation(
128127
# Required for pytype
129128
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)
130129

131-
# Core transport
132-
core_transport_coeffs_list = []
133-
for component_model, component_params in zip(
134-
self.core_transport_models,
135-
transport_dynamic_runtime_params.core_transport_model_params,
136-
):
137-
# Use the component model's _call_implementation, rather than __call__
138-
# directly. This ensures postprocessing (clipping, smoothing, patches) are
139-
# performed on the combined output rather than the individual components.
130+
def apply_and_restrict(
131+
component_model: transport_model_lib.TransportModel,
132+
component_params: runtime_params_lib.DynamicRuntimeParams,
133+
restriction_fn: Callable,
134+
) -> transport_model_lib.TurbulentTransport:
135+
# TODO: Consider only computing transport coefficients for the active
136+
# domain, rather than masking them out later. This could be significantly
137+
# more efficient especially for pedestal models, as these are only active
138+
# in a small region of the domain.
140139
component_transport_coeffs = component_model._call_implementation(
141140
component_params,
142141
dynamic_runtime_params_slice,
143142
geo,
144143
core_profiles,
145144
pedestal_model_output,
146145
)
147-
148-
# Apply domain restriction
149-
# This is a property of each component_model, so needs to be applied
150-
# at the component model level rather than the global level
151-
component_transport_coeffs = component_model._apply_domain_restriction(
146+
component_transport_coeffs = restriction_fn(
152147
component_params,
153148
geo,
154149
component_transport_coeffs,
155150
pedestal_model_output,
156151
)
152+
return component_transport_coeffs
157153

158-
core_transport_coeffs_list.append(component_transport_coeffs)
159-
160-
# Pedestal transport
161-
pedestal_transport_coeffs = (
162-
self.pedestal_transport_model._call_implementation(
154+
pedestal_coeffs = [
155+
apply_and_restrict(
156+
model, params, self._apply_pedestal_domain_restriction
157+
)
158+
for model, params in zip(
159+
self.pedestal_transport_models,
163160
transport_dynamic_runtime_params.pedestal_transport_model_params,
164-
dynamic_runtime_params_slice,
165-
geo,
166-
core_profiles,
167-
pedestal_model_output,
168161
)
169-
)
170-
pedestal_transport_coeffs = self._apply_pedestal_domain_restriction(
171-
transport_dynamic_runtime_params.pedestal_transport_model_params,
172-
geo,
173-
pedestal_transport_coeffs,
174-
pedestal_model_output,
175-
)
162+
]
163+
164+
core_coeffs = [
165+
apply_and_restrict(model, params, model._apply_domain_restriction)
166+
for model, params in zip(
167+
self.transport_models,
168+
transport_dynamic_runtime_params.transport_model_params,
169+
)
170+
]
176171

177172
# Combine the transport coefficients from core and pedestal models.
178173
combined_transport_coeffs = jax.tree.map(
179174
lambda *leaves: sum(leaves),
180-
*(core_transport_coeffs_list + [pedestal_transport_coeffs]),
175+
*pedestal_coeffs,
176+
*core_coeffs,
181177
)
182178

183179
return combined_transport_coeffs
184180

185-
def __hash__(self):
186-
return hash(
187-
tuple(self.core_transport_models + [self.pedestal_transport_model])
188-
)
189-
190-
def __eq__(self, other):
191-
return (
192-
isinstance(other, CombinedTransportModel)
193-
and self.core_transport_models == other.core_transport_models
194-
and self.pedestal_transport_model == other.pedestal_transport_model
195-
)
196-
197181
def _apply_pedestal_domain_restriction(
198182
self,
199183
transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
@@ -215,3 +199,13 @@ def _apply_pedestal_domain_restriction(
215199
d_face_el=d_face_el,
216200
v_face_el=v_face_el,
217201
)
202+
203+
def __hash__(self):
204+
return hash(tuple(self.transport_models + self.pedestal_transport_models))
205+
206+
def __eq__(self, other):
207+
return (
208+
isinstance(other, CombinedTransportModel)
209+
and self.transport_models == other.transport_models
210+
and self.pedestal_transport_models == other.pedestal_transport_models
211+
)

0 commit comments

Comments
 (0)