1717A class for combining transport models.
1818"""
1919import dataclasses
20- from typing import Sequence
20+ from typing import Callable , Sequence
2121
2222import jax
23+ import jax .numpy as jnp
2324from torax ._src import state
2425from torax ._src .config import runtime_params_slice
2526from torax ._src .geometry import geometry
3435@dataclasses .dataclass (frozen = True )
3536class DynamicRuntimeParams (runtime_params_lib .DynamicRuntimeParams ):
3637 transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
38+ pedestal_transport_model_params : Sequence [
39+ runtime_params_lib .DynamicRuntimeParams
40+ ]
3741
3842
3943class CombinedTransportModel (transport_model_lib .TransportModel ):
4044 """Combines coefficients from a list of transport models."""
4145
4246 def __init__ (
43- self , transport_models : Sequence [transport_model_lib .TransportModel ]
47+ self ,
48+ transport_models : Sequence [transport_model_lib .TransportModel ],
49+ pedestal_transport_models : Sequence [transport_model_lib .TransportModel ],
4450 ):
4551 super ().__init__ ()
4652 self .transport_models = transport_models
53+ self .pedestal_transport_models = pedestal_transport_models
4754 self ._frozen = True
4855
56+ def __call__ (
57+ self ,
58+ dynamic_runtime_params_slice : runtime_params_slice .DynamicRuntimeParamsSlice ,
59+ geo : geometry .Geometry ,
60+ core_profiles : state .CoreProfiles ,
61+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
62+ ) -> transport_model_lib .TurbulentTransport :
63+ if not getattr (self , "_frozen" , False ):
64+ raise RuntimeError (
65+ f"Subclass implementation { type (self )} forgot to "
66+ "freeze at the end of __init__."
67+ )
68+
69+ transport_runtime_params = dynamic_runtime_params_slice .transport
70+
71+ # Calculate the transport coefficients - includes contribution from pedestal
72+ # and core transport models.
73+ transport_coeffs = self ._call_implementation (
74+ transport_runtime_params ,
75+ dynamic_runtime_params_slice ,
76+ geo ,
77+ core_profiles ,
78+ pedestal_model_output ,
79+ )
80+
81+ # In contrast to the base TransportModel, we do not apply domain restriction
82+ # as this is handled at the component model level
83+
84+ # Apply min/max clipping
85+ transport_coeffs = self ._apply_clipping (
86+ transport_runtime_params ,
87+ transport_coeffs ,
88+ )
89+
90+ # In contrast to the base TransportModel, we do not apply patches, as these
91+ # should be handled by instantiating constant component models instead.
92+ # However, the rho_inner and rho_outer arguments are currently required
93+ # in the case where the inner/outer region are to be excluded from
94+ # smoothing. Smoothing is applied to
95+ # rho_inner < rho_norm < min(rho_ped_top, rho_outer) unless
96+ # smooth_everywhere is True.
97+ return self ._smooth_coeffs (
98+ transport_runtime_params ,
99+ dynamic_runtime_params_slice ,
100+ geo ,
101+ transport_coeffs ,
102+ pedestal_model_output ,
103+ )
104+
49105 def _call_implementation (
50106 self ,
51107 transport_dynamic_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
@@ -72,48 +128,96 @@ def _call_implementation(
72128 # Required for pytype
73129 assert isinstance (transport_dynamic_runtime_params , DynamicRuntimeParams )
74130
75- component_transport_coeffs_list = []
76-
77- for component_model , component_params in zip (
78- self .transport_models ,
79- transport_dynamic_runtime_params .transport_model_params ,
80- ):
81- # Use the component model's _call_implementation, rather than __call__
82- # directly. This ensures postprocessing (clipping, smoothing, patches) are
83- # performed on the output of CombinedTransportModel rather than its
84- # component models.
131+ def apply_and_restrict (
132+ component_model : transport_model_lib .TransportModel ,
133+ component_params : runtime_params_lib .DynamicRuntimeParams ,
134+ restriction_fn : Callable [
135+ [
136+ runtime_params_lib .DynamicRuntimeParams ,
137+ geometry .Geometry ,
138+ transport_model_lib .TurbulentTransport ,
139+ pedestal_model_lib .PedestalModelOutput ,
140+ ],
141+ transport_model_lib .TurbulentTransport ,
142+ ],
143+ ) -> transport_model_lib .TurbulentTransport :
144+ # TODO(b/434175682): Consider only computing transport coefficients for
145+ # the active domain, rather than masking them out later. This could be
146+ # significantly more efficient especially for pedestal models, as these
147+ # are only active in a small region of the domain.
85148 component_transport_coeffs = component_model ._call_implementation (
86149 component_params ,
87150 dynamic_runtime_params_slice ,
88151 geo ,
89152 core_profiles ,
90153 pedestal_model_output ,
91154 )
92-
93- # Apply domain restriction
94- # This is a property of each component_model, so needs to be applied
95- # at the component model level rather than the global level
96- component_transport_coeffs = component_model ._apply_domain_restriction (
155+ component_transport_coeffs = restriction_fn (
97156 component_params ,
98157 geo ,
99158 component_transport_coeffs ,
100159 pedestal_model_output ,
101160 )
102-
103- component_transport_coeffs_list .append (component_transport_coeffs )
104-
161+ return component_transport_coeffs
162+
163+ pedestal_coeffs = [
164+ apply_and_restrict (
165+ model , params , self ._apply_pedestal_domain_restriction
166+ )
167+ for model , params in zip (
168+ self .pedestal_transport_models ,
169+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
170+ )
171+ ]
172+
173+ core_coeffs = [
174+ apply_and_restrict (model , params , model ._apply_domain_restriction )
175+ for model , params in zip (
176+ self .transport_models ,
177+ transport_dynamic_runtime_params .transport_model_params ,
178+ )
179+ ]
180+
181+ # Combine the transport coefficients from core and pedestal models.
105182 combined_transport_coeffs = jax .tree .map (
106183 lambda * leaves : sum (leaves ),
107- * component_transport_coeffs_list ,
184+ * pedestal_coeffs ,
185+ * core_coeffs ,
108186 )
109187
110188 return combined_transport_coeffs
111189
190+ def _apply_pedestal_domain_restriction (
191+ self ,
192+ unused_transport_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
193+ geo : geometry .Geometry ,
194+ transport_coeffs : transport_model_lib .TurbulentTransport ,
195+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
196+ ) -> transport_model_lib .TurbulentTransport :
197+ del unused_transport_runtime_params
198+ active_mask = geo .rho_face_norm > pedestal_model_output .rho_norm_ped_top
199+
200+ chi_face_ion = jnp .where (active_mask , transport_coeffs .chi_face_ion , 0.0 )
201+ chi_face_el = jnp .where (active_mask , transport_coeffs .chi_face_el , 0.0 )
202+ d_face_el = jnp .where (active_mask , transport_coeffs .d_face_el , 0.0 )
203+ v_face_el = jnp .where (active_mask , transport_coeffs .v_face_el , 0.0 )
204+
205+ return dataclasses .replace (
206+ transport_coeffs ,
207+ chi_face_ion = chi_face_ion ,
208+ chi_face_el = chi_face_el ,
209+ d_face_el = d_face_el ,
210+ v_face_el = v_face_el ,
211+ )
212+
112213 def __hash__ (self ):
113- return hash (tuple (self .transport_models ))
214+ return hash (
215+ tuple (self .transport_models ) + tuple (self .pedestal_transport_models )
216+ )
114217
115218 def __eq__ (self , other ):
116219 return (
117220 isinstance (other , CombinedTransportModel )
118221 and self .transport_models == other .transport_models
222+ and self .pedestal_transport_models == other .pedestal_transport_models
119223 )
0 commit comments