Skip to content

Commit 53be1e0

Browse files
committed
Add pedestal transport model section to CombinedTransportModel
1 parent 4e923e3 commit 53be1e0

File tree

6 files changed

+193
-58
lines changed

6 files changed

+193
-58
lines changed

docs/scripts/combined_transport_example.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def main(argv: Sequence[str]) -> None:
3939
'solver': {},
4040
'transport': {
4141
'model_name': 'combined',
42-
'transport_models': [
42+
'core_transport_models': [
4343
{
4444
'model_name': 'constant',
4545
'chi_i': 1.0,
@@ -51,13 +51,13 @@ def main(argv: Sequence[str]) -> None:
5151
'rho_min': 0.2,
5252
'rho_max': 0.5,
5353
},
54-
{
55-
'model_name': 'constant',
56-
'chi_i': 0.5,
57-
'rho_min': 0.5,
58-
'rho_max': 1.0,
59-
},
6054
],
55+
'pedestal_transport_model': {
56+
'model_name': 'constant',
57+
'chi_i': 0.5,
58+
'rho_min': 0.5,
59+
'rho_max': 1.0,
60+
},
6161
},
6262
}
6363
torax_config = model_config.ToraxConfig.from_dict(config)

torax/_src/transport_model/combined.py

Lines changed: 112 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 DeepMind Technologies Limited
1+
# Copyright 2024 DeepMind Technologies Limited
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -18,9 +18,11 @@
1818
"""
1919

2020
from typing import Sequence
21+
import dataclasses
2122

2223
import chex
2324
import jax
25+
import jax.numpy as jnp
2426
from torax._src import state
2527
from torax._src.config import runtime_params_slice
2628
from torax._src.geometry import geometry
@@ -33,19 +35,73 @@
3335

3436
@chex.dataclass(frozen=True)
3537
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
36-
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
38+
core_transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
39+
pedestal_transport_model_params: Sequence[
40+
runtime_params_lib.DynamicRuntimeParams
41+
]
3742

3843

3944
class CombinedTransportModel(transport_model_lib.TransportModel):
4045
"""Combines coefficients from a list of transport models."""
4146

4247
def __init__(
43-
self, transport_models: Sequence[transport_model_lib.TransportModel]
48+
self,
49+
core_transport_models: Sequence[transport_model_lib.TransportModel],
50+
pedestal_transport_model: transport_model_lib.TransportModel,
4451
):
4552
super().__init__()
46-
self.transport_models = transport_models
53+
self.core_transport_models = core_transport_models
54+
self.pedestal_transport_model = pedestal_transport_model
4755
self._frozen = True
4856

57+
def __call__(
58+
self,
59+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
60+
geo: geometry.Geometry,
61+
core_profiles: state.CoreProfiles,
62+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
63+
) -> transport_model_lib.TurbulentTransport:
64+
if not getattr(self, "_frozen", False):
65+
raise RuntimeError(
66+
f"Subclass implementation {type(self)} forgot to "
67+
"freeze at the end of __init__."
68+
)
69+
70+
transport_runtime_params = dynamic_runtime_params_slice.transport
71+
72+
# Calculate the transport coefficients - includes contribution from pedestal
73+
# and core transport models.
74+
transport_coeffs = self._call_implementation(
75+
transport_runtime_params,
76+
dynamic_runtime_params_slice,
77+
geo,
78+
core_profiles,
79+
pedestal_model_output,
80+
)
81+
82+
# In contrast to the base TransportModel, we do not apply domain restriction
83+
# as this is handled at the component model level
84+
85+
# Apply min/max clipping
86+
transport_coeffs = self._apply_clipping(
87+
transport_runtime_params,
88+
transport_coeffs,
89+
)
90+
91+
# In contrast to the base TransportModel, we do not apply patches, as these
92+
# should be handled by instantiating constant component models instead.
93+
# However, the rho_inner and rho_outer arguments are currently required
94+
# in the case where the inner/outer region are to be excluded from smoothing.
95+
# Smoothing is applied to rho_inner < rho_norm < min(rho_ped_top, rho_outer)
96+
# unless smooth_everywhere is True.
97+
return self._smooth_coeffs(
98+
transport_runtime_params,
99+
dynamic_runtime_params_slice,
100+
geo,
101+
transport_coeffs,
102+
pedestal_model_output,
103+
)
104+
49105
def _call_implementation(
50106
self,
51107
transport_dynamic_runtime_params: runtime_params_lib.DynamicRuntimeParams,
@@ -72,16 +128,15 @@ def _call_implementation(
72128
# Required for pytype
73129
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)
74130

75-
component_transport_coeffs_list = []
76-
131+
# Core transport
132+
core_transport_coeffs_list = []
77133
for component_model, component_params in zip(
78-
self.transport_models,
79-
transport_dynamic_runtime_params.transport_model_params,
134+
self.core_transport_models,
135+
transport_dynamic_runtime_params.core_transport_model_params,
80136
):
81137
# Use the component model's _call_implementation, rather than __call__
82138
# directly. This ensures postprocessing (clipping, smoothing, patches) are
83-
# performed on the output of CombinedTransportModel rather than its
84-
# component models.
139+
# performed on the combined output rather than the individual components.
85140
component_transport_coeffs = component_model._call_implementation(
86141
component_params,
87142
dynamic_runtime_params_slice,
@@ -100,20 +155,63 @@ def _call_implementation(
100155
pedestal_model_output,
101156
)
102157

103-
component_transport_coeffs_list.append(component_transport_coeffs)
158+
core_transport_coeffs_list.append(component_transport_coeffs)
159+
160+
# Pedestal transport
161+
pedestal_transport_coeffs = (
162+
self.pedestal_transport_model._call_implementation(
163+
transport_dynamic_runtime_params.pedestal_transport_model_params,
164+
dynamic_runtime_params_slice,
165+
geo,
166+
core_profiles,
167+
pedestal_model_output,
168+
)
169+
)
170+
pedestal_transport_coeffs = self._apply_pedestal_domain_restriction(
171+
transport_dynamic_runtime_params.pedestal_transport_model_params,
172+
geo,
173+
pedestal_transport_coeffs,
174+
pedestal_model_output,
175+
)
104176

177+
# Combine the transport coefficients from core and pedestal models.
105178
combined_transport_coeffs = jax.tree.map(
106179
lambda *leaves: sum(leaves),
107-
*component_transport_coeffs_list,
180+
*(core_transport_coeffs_list + [pedestal_transport_coeffs]),
108181
)
109182

110183
return combined_transport_coeffs
111184

112185
def __hash__(self):
113-
return hash(tuple(self.transport_models))
186+
return hash(
187+
tuple(self.core_transport_models + [self.pedestal_transport_model])
188+
)
114189

115190
def __eq__(self, other):
116191
return (
117192
isinstance(other, CombinedTransportModel)
118-
and self.transport_models == other.transport_models
193+
and self.core_transport_models == other.core_transport_models
194+
and self.pedestal_transport_model == other.pedestal_transport_model
195+
)
196+
197+
def _apply_pedestal_domain_restriction(
198+
self,
199+
transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
200+
geo: geometry.Geometry,
201+
transport_coeffs: transport_model_lib.TurbulentTransport,
202+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
203+
) -> transport_model_lib.TurbulentTransport:
204+
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top
205+
206+
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
207+
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
208+
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
209+
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
210+
211+
return dataclasses.replace(
212+
transport_coeffs,
213+
chi_face_ion=chi_face_ion,
214+
chi_face_el=chi_face_el,
215+
d_face_el=d_face_el,
216+
v_face_el=v_face_el,
119217
)

torax/_src/transport_model/pydantic_model.py

Lines changed: 38 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -368,50 +368,69 @@ class CombinedTransportModel(pydantic_model_base.TransportBase):
368368
369369
Attributes:
370370
model_name: The transport model to use. Hardcoded to 'combined'.
371-
transport_models: A sequence of transport models, whose outputs will be
372-
summed to give the combined transport coefficients.
371+
core_transport_models: A sequence of transport models, whose outputs will be
372+
summed to give the combined core transport coefficients.
373+
pedestal_transport_model: A model for pedestal transport coefficients.
373374
"""
374375

375-
transport_models: Sequence[CombinedCompatibleTransportModel] = pydantic.Field(
376+
core_transport_models: Sequence[
377+
CombinedCompatibleTransportModel
378+
] = pydantic.Field(
376379
default_factory=list
377380
) # pytype: disable=invalid-annotation
381+
pedestal_transport_model: CombinedCompatibleTransportModel
378382
model_name: Literal['combined'] = 'combined'
379383

380384
def build_transport_model(self) -> combined.CombinedTransportModel:
381-
model_list = [
382-
model.build_transport_model() for model in self.transport_models
385+
core_transport_model_list = [
386+
model.build_transport_model() for model in self.core_transport_models
383387
]
384-
return combined.CombinedTransportModel(transport_models=model_list)
388+
pedestal_transport_model = (
389+
self.pedestal_transport_model.build_transport_model()
390+
)
391+
392+
return combined.CombinedTransportModel(
393+
core_transport_models=core_transport_model_list,
394+
pedestal_transport_model=pedestal_transport_model,
395+
)
385396

386397
def build_dynamic_params(
387398
self, t: chex.Numeric
388399
) -> combined.DynamicRuntimeParams:
389400
base_kwargs = dataclasses.asdict(super().build_dynamic_params(t))
390-
model_params_list = [
391-
model.build_dynamic_params(t) for model in self.transport_models
401+
core_transport_model_params = [
402+
model.build_dynamic_params(t) for model in self.core_transport_models
392403
]
404+
pedestal_transport_model_params = (
405+
self.pedestal_transport_model.build_dynamic_params(t)
406+
)
407+
393408
return combined.DynamicRuntimeParams(
394-
transport_model_params=model_params_list,
409+
core_transport_model_params=core_transport_model_params,
410+
pedestal_transport_model_params=pedestal_transport_model_params,
395411
**base_kwargs,
396412
)
397413

398414
@pydantic.model_validator(mode='after')
399415
def _check_fields(self) -> typing_extensions.Self:
400416
super()._check_fields()
401-
if not self.transport_models:
417+
if not self.core_transport_models:
402418
raise ValueError(
403-
'transport_models cannot be empty for CombinedTransportModel. '
419+
'core_transport_models cannot be empty for CombinedTransportModel. '
404420
'Please provide at least one transport model configuration.'
405421
)
406-
if any([
407-
np.any(model.apply_inner_patch.value)
408-
or np.any(model.apply_outer_patch.value)
409-
for model in self.transport_models
410-
]):
422+
if (
423+
any([
424+
np.any(model.apply_inner_patch.value)
425+
or np.any(model.apply_outer_patch.value)
426+
for model in self.core_transport_models
427+
])
428+
or np.any(self.apply_inner_patch.value)
429+
or np.any(self.apply_outer_patch.value)
430+
):
411431
raise ValueError(
412-
'apply_inner_patch and apply_outer_patch and should be set in the'
413-
' config for CombinedTransportModel only, rather than its component'
414-
' models.'
432+
'apply_inner_patch and apply_outer_patch not supported for'
433+
' CombinedTransportModel or its component models.'
415434
)
416435
if np.any(self.rho_min.value != 0.0) or np.any(self.rho_max.value != 1.0):
417436
raise ValueError(
@@ -422,4 +441,3 @@ def _check_fields(self) -> typing_extensions.Self:
422441

423442

424443
TransportConfig = CombinedTransportModel | CombinedCompatibleTransportModel # pytype: disable=invalid-annotation
425-

torax/_src/transport_model/tests/combined_test.py

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def test_call_implementation(self):
3131
config = default_configs.get_default_config_dict()
3232
config['transport'] = {
3333
'model_name': 'combined',
34-
'transport_models': [
34+
'core_transport_models': [
3535
{'model_name': 'constant', 'rho_max': 0.2, 'chi_i': 1.0},
3636
{
3737
'model_name': 'constant',
@@ -41,6 +41,7 @@ def test_call_implementation(self):
4141
},
4242
{'model_name': 'constant', 'rho_min': 0.5, 'chi_i': 3.0},
4343
],
44+
'pedestal_transport_model': {'model_name': 'constant', 'chi_i': 0.1},
4445
}
4546
torax_config = model_config.ToraxConfig.from_dict(config)
4647
model = torax_config.transport.build_transport_model()
@@ -80,12 +81,12 @@ def test_call_implementation(self):
8081
mock_pedestal_outputs,
8182
)
8283
# Target:
83-
# - 0 for rho = [rho_ped_top, rho_max]
84+
# - 0.1 for rho = [rho_ped_top, rho_max]
8485
# - 3 for rho = (0.8, rho_ped_top), to check pedestal overrides it
8586
# - 5 for rho = (0.5, 0.8], to check case where models overlap
8687
# - 2 for rho = (0.2, 0.5], to check case rho_min_1 == rho_max_2
8788
# - 1 for rho = [0, 0.2], to check case where rho_min = 0
88-
target = jnp.where(geo.rho_face_norm <= 0.91, 3.0, 0.0)
89+
target = jnp.where(geo.rho_face_norm <= 0.91, 3.0, 0.1)
8990
target = jnp.where(geo.rho_face_norm <= 0.8, 5.0, target)
9091
target = jnp.where(geo.rho_face_norm <= 0.5, 2.0, target)
9192
target = jnp.where(geo.rho_face_norm <= 0.2, 1.0, target)
@@ -95,10 +96,27 @@ def test_error_if_patches_set_on_children(self):
9596
config = default_configs.get_default_config_dict()
9697
config['transport'] = {
9798
'model_name': 'combined',
98-
'transport_models': [
99+
'core_transport_models': [
99100
{'model_name': 'constant', 'apply_inner_patch': True},
100101
{'model_name': 'constant'},
101102
],
103+
'pedestal_transport_model': {'model_name': 'constant'},
104+
}
105+
with self.assertRaisesRegex(
106+
ValueError, '(?=.*patch)(?=.*CombinedTransportModel)'
107+
):
108+
model_config.ToraxConfig.from_dict(config)
109+
110+
def test_error_if_patches_set_on_self(self):
111+
config = default_configs.get_default_config_dict()
112+
config['transport'] = {
113+
'model_name': 'combined',
114+
'core_transport_models': [
115+
{'model_name': 'constant'},
116+
{'model_name': 'constant'},
117+
],
118+
'pedestal_transport_model': {'model_name': 'constant'},
119+
'apply_inner_patch': True,
102120
}
103121
with self.assertRaisesRegex(
104122
ValueError, '(?=.*patch)(?=.*CombinedTransportModel)'
@@ -109,10 +127,11 @@ def test_error_if_rho_min_or_rho_max_set(self):
109127
config = default_configs.get_default_config_dict()
110128
config['transport'] = {
111129
'model_name': 'combined',
112-
'transport_models': [
130+
'core_transport_models': [
113131
{'model_name': 'constant'},
114132
{'model_name': 'constant'},
115133
],
134+
'pedestal_transport_model': {'model_name': 'constant'},
116135
'rho_min': 0.1,
117136
}
118137
with self.assertRaisesRegex(
@@ -122,10 +141,11 @@ def test_error_if_rho_min_or_rho_max_set(self):
122141

123142
config['transport'] = {
124143
'model_name': 'combined',
125-
'transport_models': [
144+
'core_transport_models': [
126145
{'model_name': 'constant'},
127146
{'model_name': 'constant'},
128147
],
148+
'pedestal_transport_model': {'model_name': 'constant'},
129149
'rho_max': 0.9,
130150
}
131151
with self.assertRaises(ValueError):
-31.3 KB
Binary file not shown.

0 commit comments

Comments
 (0)