1717A class for combining transport models.
1818"""
1919import dataclasses
20- from typing import Sequence
20+ from typing import Callable , Sequence
2121
2222import jax
23+ import jax .numpy as jnp
2324from torax ._src import state
2425from torax ._src .config import runtime_params_slice
2526from torax ._src .geometry import geometry
3435@dataclasses .dataclass (frozen = True )
3536class DynamicRuntimeParams (runtime_params_lib .DynamicRuntimeParams ):
3637 transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
38+ pedestal_transport_model_params : Sequence [
39+ runtime_params_lib .DynamicRuntimeParams
40+ ]
3741
3842
3943class CombinedTransportModel (transport_model_lib .TransportModel ):
4044 """Combines coefficients from a list of transport models."""
4145
4246 def __init__ (
43- self , transport_models : Sequence [transport_model_lib .TransportModel ]
47+ self ,
48+ transport_models : Sequence [transport_model_lib .TransportModel ],
49+ pedestal_transport_models : Sequence [transport_model_lib .TransportModel ],
4450 ):
4551 super ().__init__ ()
4652 self .transport_models = transport_models
53+ self .pedestal_transport_models = pedestal_transport_models
4754 self ._frozen = True
4855
56+ def __call__ (
57+ self ,
58+ dynamic_runtime_params_slice : runtime_params_slice .DynamicRuntimeParamsSlice ,
59+ geo : geometry .Geometry ,
60+ core_profiles : state .CoreProfiles ,
61+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
62+ ) -> transport_model_lib .TurbulentTransport :
63+ if not getattr (self , "_frozen" , False ):
64+ raise RuntimeError (
65+ f"Subclass implementation { type (self )} forgot to "
66+ "freeze at the end of __init__."
67+ )
68+
69+ transport_runtime_params = dynamic_runtime_params_slice .transport
70+
71+ # Calculate the transport coefficients - includes contribution from pedestal
72+ # and core transport models.
73+ transport_coeffs = self ._call_implementation (
74+ transport_runtime_params ,
75+ dynamic_runtime_params_slice ,
76+ geo ,
77+ core_profiles ,
78+ pedestal_model_output ,
79+ )
80+
81+ # In contrast to the base TransportModel, we do not apply domain restriction
82+ # as this is handled at the component model level
83+
84+ # Apply min/max clipping
85+ transport_coeffs = self ._apply_clipping (
86+ transport_runtime_params ,
87+ transport_coeffs ,
88+ )
89+
90+ # In contrast to the base TransportModel, we do not apply patches, as these
91+ # should be handled by instantiating constant component models instead.
92+ # However, the rho_inner and rho_outer arguments are currently required
93+ # in the case where the inner/outer region are to be excluded from smoothing.
94+ # Smoothing is applied to rho_inner < rho_norm < min(rho_ped_top, rho_outer)
95+ # unless smooth_everywhere is True.
96+ return self ._smooth_coeffs (
97+ transport_runtime_params ,
98+ dynamic_runtime_params_slice ,
99+ geo ,
100+ transport_coeffs ,
101+ pedestal_model_output ,
102+ )
103+
49104 def _call_implementation (
50105 self ,
51106 transport_dynamic_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
@@ -72,48 +127,85 @@ def _call_implementation(
72127 # Required for pytype
73128 assert isinstance (transport_dynamic_runtime_params , DynamicRuntimeParams )
74129
75- component_transport_coeffs_list = []
76-
77- for component_model , component_params in zip (
78- self .transport_models ,
79- transport_dynamic_runtime_params .transport_model_params ,
80- ):
81- # Use the component model's _call_implementation, rather than __call__
82- # directly. This ensures postprocessing (clipping, smoothing, patches) are
83- # performed on the output of CombinedTransportModel rather than its
84- # component models.
130+ def apply_and_restrict (
131+ component_model : transport_model_lib .TransportModel ,
132+ component_params : runtime_params_lib .DynamicRuntimeParams ,
133+ restriction_fn : Callable ,
134+ ) -> transport_model_lib .TurbulentTransport :
135+ # TODO: Consider only computing transport coefficients for the active
136+ # domain, rather than masking them out later. This could be significantly
137+ # more efficient especially for pedestal models, as these are only active
138+ # in a small region of the domain.
85139 component_transport_coeffs = component_model ._call_implementation (
86140 component_params ,
87141 dynamic_runtime_params_slice ,
88142 geo ,
89143 core_profiles ,
90144 pedestal_model_output ,
91145 )
92-
93- # Apply domain restriction
94- # This is a property of each component_model, so needs to be applied
95- # at the component model level rather than the global level
96- component_transport_coeffs = component_model ._apply_domain_restriction (
146+ component_transport_coeffs = restriction_fn (
97147 component_params ,
98148 geo ,
99149 component_transport_coeffs ,
100150 pedestal_model_output ,
101151 )
102-
103- component_transport_coeffs_list .append (component_transport_coeffs )
104-
152+ return component_transport_coeffs
153+
154+ pedestal_coeffs = [
155+ apply_and_restrict (
156+ model , params , self ._apply_pedestal_domain_restriction
157+ )
158+ for model , params in zip (
159+ self .pedestal_transport_models ,
160+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
161+ )
162+ ]
163+
164+ core_coeffs = [
165+ apply_and_restrict (model , params , model ._apply_domain_restriction )
166+ for model , params in zip (
167+ self .transport_models ,
168+ transport_dynamic_runtime_params .transport_model_params ,
169+ )
170+ ]
171+
172+ # Combine the transport coefficients from core and pedestal models.
105173 combined_transport_coeffs = jax .tree .map (
106174 lambda * leaves : sum (leaves ),
107- * component_transport_coeffs_list ,
175+ * pedestal_coeffs ,
176+ * core_coeffs ,
108177 )
109178
110179 return combined_transport_coeffs
111180
181+ def _apply_pedestal_domain_restriction (
182+ self ,
183+ transport_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
184+ geo : geometry .Geometry ,
185+ transport_coeffs : transport_model_lib .TurbulentTransport ,
186+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
187+ ) -> transport_model_lib .TurbulentTransport :
188+ active_mask = geo .rho_face_norm > pedestal_model_output .rho_norm_ped_top
189+
190+ chi_face_ion = jnp .where (active_mask , transport_coeffs .chi_face_ion , 0.0 )
191+ chi_face_el = jnp .where (active_mask , transport_coeffs .chi_face_el , 0.0 )
192+ d_face_el = jnp .where (active_mask , transport_coeffs .d_face_el , 0.0 )
193+ v_face_el = jnp .where (active_mask , transport_coeffs .v_face_el , 0.0 )
194+
195+ return dataclasses .replace (
196+ transport_coeffs ,
197+ chi_face_ion = chi_face_ion ,
198+ chi_face_el = chi_face_el ,
199+ d_face_el = d_face_el ,
200+ v_face_el = v_face_el ,
201+ )
202+
112203 def __hash__ (self ):
113- return hash (tuple (self .transport_models ))
204+ return hash (tuple (self .transport_models + self . pedestal_transport_models ))
114205
115206 def __eq__ (self , other ):
116207 return (
117208 isinstance (other , CombinedTransportModel )
118209 and self .transport_models == other .transport_models
210+ and self .pedestal_transport_models == other .pedestal_transport_models
119211 )
0 commit comments