Skip to content

Commit 9653014

Browse files
committed
Add pedestal transport model section to CombinedTransportModel
1 parent 4e923e3 commit 9653014

File tree

6 files changed

+192
-58
lines changed

6 files changed

+192
-58
lines changed

docs/scripts/combined_transport_example.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def main(argv: Sequence[str]) -> None:
3939
'solver': {},
4040
'transport': {
4141
'model_name': 'combined',
42-
'transport_models': [
42+
'core_transport_models': [
4343
{
4444
'model_name': 'constant',
4545
'chi_i': 1.0,
@@ -51,13 +51,13 @@ def main(argv: Sequence[str]) -> None:
5151
'rho_min': 0.2,
5252
'rho_max': 0.5,
5353
},
54-
{
55-
'model_name': 'constant',
56-
'chi_i': 0.5,
57-
'rho_min': 0.5,
58-
'rho_max': 1.0,
59-
},
6054
],
55+
'pedestal_transport_model': {
56+
'model_name': 'constant',
57+
'chi_i': 0.5,
58+
'rho_min': 0.5,
59+
'rho_max': 1.0,
60+
},
6161
},
6262
}
6363
torax_config = model_config.ToraxConfig.from_dict(config)

torax/_src/transport_model/combined.py

Lines changed: 111 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 DeepMind Technologies Limited
1+
7 # Copyright 2024 DeepMind Technologies Limited
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -18,9 +18,11 @@
1818
"""
1919

2020
from typing import Sequence
21+
import dataclasses
2122

2223
import chex
2324
import jax
25+
import jax.numpy as jnp
2426
from torax._src import state
2527
from torax._src.config import runtime_params_slice
2628
from torax._src.geometry import geometry
@@ -33,19 +35,72 @@
3335

3436
@chex.dataclass(frozen=True)
3537
class DynamicRuntimeParams(runtime_params_lib.DynamicRuntimeParams):
36-
transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
38+
core_transport_model_params: Sequence[runtime_params_lib.DynamicRuntimeParams]
39+
pedestal_transport_model_params: Sequence[
40+
runtime_params_lib.DynamicRuntimeParams
41+
]
3742

3843

3944
class CombinedTransportModel(transport_model_lib.TransportModel):
4045
"""Combines coefficients from a list of transport models."""
4146

4247
def __init__(
43-
self, transport_models: Sequence[transport_model_lib.TransportModel]
48+
self,
49+
core_transport_models: Sequence[transport_model_lib.TransportModel],
50+
pedestal_transport_model: transport_model_lib.TransportModel,
4451
):
4552
super().__init__()
46-
self.transport_models = transport_models
53+
self.core_transport_models = core_transport_models
54+
self.pedestal_transport_model = pedestal_transport_model
4755
self._frozen = True
4856

57+
def __call__(
58+
self,
59+
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
60+
geo: geometry.Geometry,
61+
core_profiles: state.CoreProfiles,
62+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
63+
) -> transport_model_lib.TurbulentTransport:
64+
if not getattr(self, "_frozen", False):
65+
raise RuntimeError(
66+
f"Subclass implementation {type(self)} forgot to "
67+
"freeze at the end of __init__."
68+
)
69+
70+
transport_runtime_params = dynamic_runtime_params_slice.transport
71+
72+
# Calculate the transport coefficients - includes contribution from pedestal
73+
# and core transport models.
74+
transport_coeffs = self._call_implementation(
75+
transport_runtime_params,
76+
dynamic_runtime_params_slice,
77+
geo,
78+
core_profiles,
79+
pedestal_model_output,
80+
)
81+
82+
# In contrast to the base TransportModel, we do not apply domain restriction
83+
84+
# Apply min/max clipping
85+
transport_coeffs = self._apply_clipping(
86+
transport_runtime_params,
87+
transport_coeffs,
88+
)
89+
90+
# In contrast to the base TransportModel, we do not apply patches
91+
92+
# Return smoothed coefficients if smoothing is enabled
93+
# TODO: what should be done about masking? In base TransportModel, the
94+
# pedestal and any patches are masked out before smoothing. May require a
95+
# custom implementation of _smooth_coeffs for CombinedTransportModel.
96+
return self._smooth_coeffs(
97+
transport_runtime_params,
98+
dynamic_runtime_params_slice,
99+
geo,
100+
transport_coeffs,
101+
pedestal_model_output,
102+
)
103+
49104
def _call_implementation(
50105
self,
51106
transport_dynamic_runtime_params: runtime_params_lib.DynamicRuntimeParams,
@@ -72,16 +127,15 @@ def _call_implementation(
72127
# Required for pytype
73128
assert isinstance(transport_dynamic_runtime_params, DynamicRuntimeParams)
74129

75-
component_transport_coeffs_list = []
76-
130+
# Core transport
131+
core_transport_coeffs_list = []
77132
for component_model, component_params in zip(
78-
self.transport_models,
79-
transport_dynamic_runtime_params.transport_model_params,
133+
self.core_transport_models,
134+
transport_dynamic_runtime_params.core_transport_model_params,
80135
):
81136
# Use the component model's _call_implementation, rather than __call__
82137
# directly. This ensures postprocessing (clipping, smoothing, patches) are
83-
# performed on the output of CombinedTransportModel rather than its
84-
# component models.
138+
# performed on the combined output rather than the individual components.
85139
component_transport_coeffs = component_model._call_implementation(
86140
component_params,
87141
dynamic_runtime_params_slice,
@@ -100,20 +154,63 @@ def _call_implementation(
100154
pedestal_model_output,
101155
)
102156

103-
component_transport_coeffs_list.append(component_transport_coeffs)
157+
core_transport_coeffs_list.append(component_transport_coeffs)
158+
159+
# Pedestal transport
160+
pedestal_transport_coeffs = (
161+
self.pedestal_transport_model._call_implementation(
162+
transport_dynamic_runtime_params.pedestal_transport_model_params,
163+
dynamic_runtime_params_slice,
164+
geo,
165+
core_profiles,
166+
pedestal_model_output,
167+
)
168+
)
169+
pedestal_transport_coeffs = self._apply_pedestal_domain_restriction(
170+
transport_dynamic_runtime_params.pedestal_transport_model_params,
171+
geo,
172+
pedestal_transport_coeffs,
173+
pedestal_model_output,
174+
)
104175

176+
# Combine the transport coefficients from core and pedestal models.
105177
combined_transport_coeffs = jax.tree.map(
106178
lambda *leaves: sum(leaves),
107-
*component_transport_coeffs_list,
179+
*(core_transport_coeffs_list + [pedestal_transport_coeffs]),
108180
)
109181

110182
return combined_transport_coeffs
111183

112184
def __hash__(self):
113-
return hash(tuple(self.transport_models))
185+
return hash(
186+
tuple(self.core_transport_models + [self.pedestal_transport_model])
187+
)
114188

115189
def __eq__(self, other):
116190
return (
117191
isinstance(other, CombinedTransportModel)
118-
and self.transport_models == other.transport_models
192+
and self.core_transport_models == other.core_transport_models
193+
and self.pedestal_transport_model == other.pedestal_transport_model
194+
)
195+
196+
def _apply_pedestal_domain_restriction(
197+
self,
198+
transport_runtime_params: runtime_params_lib.DynamicRuntimeParams,
199+
geo: geometry.Geometry,
200+
transport_coeffs: transport_model_lib.TurbulentTransport,
201+
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
202+
) -> transport_model_lib.TurbulentTransport:
203+
active_mask = geo.rho_face_norm > pedestal_model_output.rho_norm_ped_top
204+
205+
chi_face_ion = jnp.where(active_mask, transport_coeffs.chi_face_ion, 0.0)
206+
chi_face_el = jnp.where(active_mask, transport_coeffs.chi_face_el, 0.0)
207+
d_face_el = jnp.where(active_mask, transport_coeffs.d_face_el, 0.0)
208+
v_face_el = jnp.where(active_mask, transport_coeffs.v_face_el, 0.0)
209+
210+
return dataclasses.replace(
211+
transport_coeffs,
212+
chi_face_ion=chi_face_ion,
213+
chi_face_el=chi_face_el,
214+
d_face_el=d_face_el,
215+
v_face_el=v_face_el,
119216
)

torax/_src/transport_model/pydantic_model.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -368,50 +368,69 @@ class CombinedTransportModel(pydantic_model_base.TransportBase):
368368
369369
Attributes:
370370
model_name: The transport model to use. Hardcoded to 'combined'.
371-
transport_models: A sequence of transport models, whose outputs will be
372-
summed to give the combined transport coefficients.
371+
core_transport_models: A sequence of transport models, whose outputs will be
372+
summed to give the combined core transport coefficients.
373+
pedestal_transport_model: A model for pedestal transport coefficients.
373374
"""
374375

375-
transport_models: Sequence[CombinedCompatibleTransportModel] = pydantic.Field(
376+
core_transport_models: Sequence[
377+
CombinedCompatibleTransportModel
378+
] = pydantic.Field(
376379
default_factory=list
377380
) # pytype: disable=invalid-annotation
381+
pedestal_transport_model: CombinedCompatibleTransportModel
378382
model_name: Literal['combined'] = 'combined'
379383

380384
def build_transport_model(self) -> combined.CombinedTransportModel:
381-
model_list = [
382-
model.build_transport_model() for model in self.transport_models
385+
core_transport_model_list = [
386+
model.build_transport_model() for model in self.core_transport_models
383387
]
384-
return combined.CombinedTransportModel(transport_models=model_list)
388+
pedestal_transport_model = (
389+
self.pedestal_transport_model.build_transport_model()
390+
)
391+
392+
return combined.CombinedTransportModel(
393+
core_transport_models=core_transport_model_list,
394+
pedestal_transport_model=pedestal_transport_model,
395+
)
385396

386397
def build_dynamic_params(
387398
self, t: chex.Numeric
388399
) -> combined.DynamicRuntimeParams:
389400
base_kwargs = dataclasses.asdict(super().build_dynamic_params(t))
390-
model_params_list = [
391-
model.build_dynamic_params(t) for model in self.transport_models
401+
core_transport_model_params = [
402+
model.build_dynamic_params(t) for model in self.core_transport_models
392403
]
404+
pedestal_transport_model_params = (
405+
self.pedestal_transport_model.build_dynamic_params(t)
406+
)
407+
393408
return combined.DynamicRuntimeParams(
394-
transport_model_params=model_params_list,
409+
core_transport_model_params=core_transport_model_params,
410+
pedestal_transport_model_params=pedestal_transport_model_params,
395411
**base_kwargs,
396412
)
397413

398414
@pydantic.model_validator(mode='after')
399415
def _check_fields(self) -> typing_extensions.Self:
400416
super()._check_fields()
401-
if not self.transport_models:
417+
if not self.core_transport_models:
402418
raise ValueError(
403-
'transport_models cannot be empty for CombinedTransportModel. '
419+
'core_transport_models cannot be empty for CombinedTransportModel. '
404420
'Please provide at least one transport model configuration.'
405421
)
406-
if any([
407-
np.any(model.apply_inner_patch.value)
408-
or np.any(model.apply_outer_patch.value)
409-
for model in self.transport_models
410-
]):
422+
if (
423+
any([
424+
np.any(model.apply_inner_patch.value)
425+
or np.any(model.apply_outer_patch.value)
426+
for model in self.core_transport_models
427+
])
428+
or np.any(self.apply_inner_patch.value)
429+
or np.any(self.apply_outer_patch.value)
430+
):
411431
raise ValueError(
412-
'apply_inner_patch and apply_outer_patch and should be set in the'
413-
' config for CombinedTransportModel only, rather than its component'
414-
' models.'
432+
'apply_inner_patch and apply_outer_patch not supported for'
433+
' CombinedTransportModel or its component models.'
415434
)
416435
if np.any(self.rho_min.value != 0.0) or np.any(self.rho_max.value != 1.0):
417436
raise ValueError(
@@ -422,4 +441,3 @@ def _check_fields(self) -> typing_extensions.Self:
422441

423442

424443
TransportConfig = CombinedTransportModel | CombinedCompatibleTransportModel # pytype: disable=invalid-annotation
425-

torax/_src/transport_model/tests/combined_test.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_call_implementation(self):
3131
config = default_configs.get_default_config_dict()
3232
config['transport'] = {
3333
'model_name': 'combined',
34-
'transport_models': [
34+
'core_transport_models': [
3535
{'model_name': 'constant', 'rho_max': 0.2, 'chi_i': 1.0},
3636
{
3737
'model_name': 'constant',
@@ -41,6 +41,7 @@ def test_call_implementation(self):
4141
},
4242
{'model_name': 'constant', 'rho_min': 0.5, 'chi_i': 3.0},
4343
],
44+
'pedestal_transport_model': {'model_name': 'constant', 'chi_i': 0.1},
4445
}
4546
torax_config = model_config.ToraxConfig.from_dict(config)
4647
model = torax_config.transport.build_transport_model()
@@ -80,12 +81,12 @@ def test_call_implementation(self):
8081
mock_pedestal_outputs,
8182
)
8283
# Target:
83-
# - 0 for rho = [rho_ped_top, rho_max]
84+
# - 0.1 for rho = [rho_ped_top, rho_max]
8485
# - 3 for rho = (0.8, rho_ped_top), to check pedestal overrides it
8586
# - 5 for rho = (0.5, 0.8], to check case where models overlap
8687
# - 2 for rho = (0.2, 0.5], to check case rho_min_1 == rho_max_2
8788
# - 1 for rho = [0, 0.2], to check case where rho_min = 0
88-
target = jnp.where(geo.rho_face_norm <= 0.91, 3.0, 0.0)
89+
target = jnp.where(geo.rho_face_norm <= 0.91, 3.0, 0.1)
8990
target = jnp.where(geo.rho_face_norm <= 0.8, 5.0, target)
9091
target = jnp.where(geo.rho_face_norm <= 0.5, 2.0, target)
9192
target = jnp.where(geo.rho_face_norm <= 0.2, 1.0, target)
@@ -95,10 +96,27 @@ def test_error_if_patches_set_on_children(self):
9596
config = default_configs.get_default_config_dict()
9697
config['transport'] = {
9798
'model_name': 'combined',
98-
'transport_models': [
99+
'core_transport_models': [
99100
{'model_name': 'constant', 'apply_inner_patch': True},
100101
{'model_name': 'constant'},
101102
],
103+
'pedestal_transport_model': {'model_name': 'constant'},
104+
}
105+
with self.assertRaisesRegex(
106+
ValueError, '(?=.*patch)(?=.*CombinedTransportModel)'
107+
):
108+
model_config.ToraxConfig.from_dict(config)
109+
110+
def test_error_if_patches_set_on_self(self):
111+
config = default_configs.get_default_config_dict()
112+
config['transport'] = {
113+
'model_name': 'combined',
114+
'core_transport_models': [
115+
{'model_name': 'constant'},
116+
{'model_name': 'constant'},
117+
],
118+
'pedestal_transport_model': {'model_name': 'constant'},
119+
'apply_inner_patch': True,
102120
}
103121
with self.assertRaisesRegex(
104122
ValueError, '(?=.*patch)(?=.*CombinedTransportModel)'
@@ -109,10 +127,11 @@ def test_error_if_rho_min_or_rho_max_set(self):
109127
config = default_configs.get_default_config_dict()
110128
config['transport'] = {
111129
'model_name': 'combined',
112-
'transport_models': [
130+
'core_transport_models': [
113131
{'model_name': 'constant'},
114132
{'model_name': 'constant'},
115133
],
134+
'pedestal_transport_model': {'model_name': 'constant'},
116135
'rho_min': 0.1,
117136
}
118137
with self.assertRaisesRegex(
@@ -122,10 +141,11 @@ def test_error_if_rho_min_or_rho_max_set(self):
122141

123142
config['transport'] = {
124143
'model_name': 'combined',
125-
'transport_models': [
144+
'core_transport_models': [
126145
{'model_name': 'constant'},
127146
{'model_name': 'constant'},
128147
],
148+
'pedestal_transport_model': {'model_name': 'constant'},
129149
'rho_max': 0.9,
130150
}
131151
with self.assertRaises(ValueError):
-31.3 KB
Binary file not shown.

0 commit comments

Comments
 (0)