1- # Copyright 2024 DeepMind Technologies Limited
1+ 7 # Copyright 2024 DeepMind Technologies Limited
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1818"""
1919
2020from typing import Sequence
21+ import dataclasses
2122
2223import chex
2324import jax
25+ import jax .numpy as jnp
2426from torax ._src import state
2527from torax ._src .config import runtime_params_slice
2628from torax ._src .geometry import geometry
3335
3436@chex .dataclass (frozen = True )
3537class DynamicRuntimeParams (runtime_params_lib .DynamicRuntimeParams ):
36- transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
38+ core_transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
39+ pedestal_transport_model_params : Sequence [
40+ runtime_params_lib .DynamicRuntimeParams
41+ ]
3742
3843
3944class CombinedTransportModel (transport_model_lib .TransportModel ):
4045 """Combines coefficients from a list of transport models."""
4146
4247 def __init__ (
43- self , transport_models : Sequence [transport_model_lib .TransportModel ]
48+ self ,
49+ core_transport_models : Sequence [transport_model_lib .TransportModel ],
50+ pedestal_transport_model : transport_model_lib .TransportModel ,
4451 ):
4552 super ().__init__ ()
46- self .transport_models = transport_models
53+ self .core_transport_models = core_transport_models
54+ self .pedestal_transport_model = pedestal_transport_model
4755 self ._frozen = True
4856
57+ def __call__ (
58+ self ,
59+ dynamic_runtime_params_slice : runtime_params_slice .DynamicRuntimeParamsSlice ,
60+ geo : geometry .Geometry ,
61+ core_profiles : state .CoreProfiles ,
62+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
63+ ) -> transport_model_lib .TurbulentTransport :
64+ if not getattr (self , "_frozen" , False ):
65+ raise RuntimeError (
66+ f"Subclass implementation { type (self )} forgot to "
67+ "freeze at the end of __init__."
68+ )
69+
70+ transport_runtime_params = dynamic_runtime_params_slice .transport
71+
72+ # Calculate the transport coefficients - includes contribution from pedestal
73+ # and core transport models.
74+ transport_coeffs = self ._call_implementation (
75+ transport_runtime_params ,
76+ dynamic_runtime_params_slice ,
77+ geo ,
78+ core_profiles ,
79+ pedestal_model_output ,
80+ )
81+
82+ # In contrast to the base TransportModel, we do not apply domain restriction
83+ # as this is handled at the component model level
84+
85+ # Apply min/max clipping
86+ transport_coeffs = self ._apply_clipping (
87+ transport_runtime_params ,
88+ transport_coeffs ,
89+ )
90+
91+ # In contrast to the base TransportModel, we do not apply patches, as these
92+ # should be handled by instantiating constant component models instead.
93+ # However, the rho_inner and rho_outer arguments are currently required
94+ # in the case where the inner/outer region are to be excluded from smoothing.
95+ # Smoothing is applied to rho_inner < rho_norm < min(rho_ped_top, rho_outer)
96+ # unless smooth_everywhere is True.
97+ return self ._smooth_coeffs (
98+ transport_runtime_params ,
99+ dynamic_runtime_params_slice ,
100+ geo ,
101+ transport_coeffs ,
102+ pedestal_model_output ,
103+ )
104+
49105 def _call_implementation (
50106 self ,
51107 transport_dynamic_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
@@ -72,16 +128,15 @@ def _call_implementation(
72128 # Required for pytype
73129 assert isinstance (transport_dynamic_runtime_params , DynamicRuntimeParams )
74130
75- component_transport_coeffs_list = []
76-
131+ # Core transport
132+ core_transport_coeffs_list = []
77133 for component_model , component_params in zip (
78- self .transport_models ,
79- transport_dynamic_runtime_params .transport_model_params ,
134+ self .core_transport_models ,
135+ transport_dynamic_runtime_params .core_transport_model_params ,
80136 ):
81137 # Use the component model's _call_implementation, rather than __call__
82138 # directly. This ensures postprocessing (clipping, smoothing, patches) are
83- # performed on the output of CombinedTransportModel rather than its
84- # component models.
139+ # performed on the combined output rather than the individual components.
85140 component_transport_coeffs = component_model ._call_implementation (
86141 component_params ,
87142 dynamic_runtime_params_slice ,
@@ -100,20 +155,63 @@ def _call_implementation(
100155 pedestal_model_output ,
101156 )
102157
103- component_transport_coeffs_list .append (component_transport_coeffs )
158+ core_transport_coeffs_list .append (component_transport_coeffs )
159+
160+ # Pedestal transport
161+ pedestal_transport_coeffs = (
162+ self .pedestal_transport_model ._call_implementation (
163+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
164+ dynamic_runtime_params_slice ,
165+ geo ,
166+ core_profiles ,
167+ pedestal_model_output ,
168+ )
169+ )
170+ pedestal_transport_coeffs = self ._apply_pedestal_domain_restriction (
171+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
172+ geo ,
173+ pedestal_transport_coeffs ,
174+ pedestal_model_output ,
175+ )
104176
177+ # Combine the transport coefficients from core and pedestal models.
105178 combined_transport_coeffs = jax .tree .map (
106179 lambda * leaves : sum (leaves ),
107- * component_transport_coeffs_list ,
180+ * ( core_transport_coeffs_list + [ pedestal_transport_coeffs ]) ,
108181 )
109182
110183 return combined_transport_coeffs
111184
112185 def __hash__ (self ):
113- return hash (tuple (self .transport_models ))
186+ return hash (
187+ tuple (self .core_transport_models + [self .pedestal_transport_model ])
188+ )
114189
115190 def __eq__ (self , other ):
116191 return (
117192 isinstance (other , CombinedTransportModel )
118- and self .transport_models == other .transport_models
193+ and self .core_transport_models == other .core_transport_models
194+ and self .pedestal_transport_model == other .pedestal_transport_model
195+ )
196+
197+ def _apply_pedestal_domain_restriction (
198+ self ,
199+ transport_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
200+ geo : geometry .Geometry ,
201+ transport_coeffs : transport_model_lib .TurbulentTransport ,
202+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
203+ ) -> transport_model_lib .TurbulentTransport :
204+ active_mask = geo .rho_face_norm > pedestal_model_output .rho_norm_ped_top
205+
206+ chi_face_ion = jnp .where (active_mask , transport_coeffs .chi_face_ion , 0.0 )
207+ chi_face_el = jnp .where (active_mask , transport_coeffs .chi_face_el , 0.0 )
208+ d_face_el = jnp .where (active_mask , transport_coeffs .d_face_el , 0.0 )
209+ v_face_el = jnp .where (active_mask , transport_coeffs .v_face_el , 0.0 )
210+
211+ return dataclasses .replace (
212+ transport_coeffs ,
213+ chi_face_ion = chi_face_ion ,
214+ chi_face_el = chi_face_el ,
215+ d_face_el = d_face_el ,
216+ v_face_el = v_face_el ,
119217 )
0 commit comments