1- # Copyright 2024 DeepMind Technologies Limited
1+ 7 # Copyright 2024 DeepMind Technologies Limited
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1818"""
1919
2020from typing import Sequence
21+ import dataclasses
2122
2223import chex
2324import jax
25+ import jax .numpy as jnp
2426from torax ._src import state
2527from torax ._src .config import runtime_params_slice
2628from torax ._src .geometry import geometry
3335
3436@chex .dataclass (frozen = True )
3537class DynamicRuntimeParams (runtime_params_lib .DynamicRuntimeParams ):
36- transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
38+ core_transport_model_params : Sequence [runtime_params_lib .DynamicRuntimeParams ]
39+ pedestal_transport_model_params : Sequence [
40+ runtime_params_lib .DynamicRuntimeParams
41+ ]
3742
3843
3944class CombinedTransportModel (transport_model_lib .TransportModel ):
4045 """Combines coefficients from a list of transport models."""
4146
4247 def __init__ (
43- self , transport_models : Sequence [transport_model_lib .TransportModel ]
48+ self ,
49+ core_transport_models : Sequence [transport_model_lib .TransportModel ],
50+ pedestal_transport_model : transport_model_lib .TransportModel ,
4451 ):
4552 super ().__init__ ()
46- self .transport_models = transport_models
53+ self .core_transport_models = core_transport_models
54+ self .pedestal_transport_model = pedestal_transport_model
4755 self ._frozen = True
4856
57+ def __call__ (
58+ self ,
59+ dynamic_runtime_params_slice : runtime_params_slice .DynamicRuntimeParamsSlice ,
60+ geo : geometry .Geometry ,
61+ core_profiles : state .CoreProfiles ,
62+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
63+ ) -> transport_model_lib .TurbulentTransport :
64+ if not getattr (self , "_frozen" , False ):
65+ raise RuntimeError (
66+ f"Subclass implementation { type (self )} forgot to "
67+ "freeze at the end of __init__."
68+ )
69+
70+ transport_runtime_params = dynamic_runtime_params_slice .transport
71+
72+ # Calculate the transport coefficients - includes contribution from pedestal
73+ # and core transport models.
74+ transport_coeffs = self ._call_implementation (
75+ transport_runtime_params ,
76+ dynamic_runtime_params_slice ,
77+ geo ,
78+ core_profiles ,
79+ pedestal_model_output ,
80+ )
81+
82+ # In contrast to the base TransportModel, we do not apply domain restriction
83+
84+ # Apply min/max clipping
85+ transport_coeffs = self ._apply_clipping (
86+ transport_runtime_params ,
87+ transport_coeffs ,
88+ )
89+
90+ # In contrast to the base TransportModel, we do not apply patches
91+
92+ # Return smoothed coefficients if smoothing is enabled
93+ # TODO: what should be done about masking? In base TransportModel, the
94+ # pedestal and any patches are masked out before smoothing. May require a
95+ # custom implementation of _smooth_coeffs for CombinedTransportModel.
96+ return self ._smooth_coeffs (
97+ transport_runtime_params ,
98+ dynamic_runtime_params_slice ,
99+ geo ,
100+ transport_coeffs ,
101+ pedestal_model_output ,
102+ )
103+
49104 def _call_implementation (
50105 self ,
51106 transport_dynamic_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
@@ -72,16 +127,15 @@ def _call_implementation(
72127 # Required for pytype
73128 assert isinstance (transport_dynamic_runtime_params , DynamicRuntimeParams )
74129
75- component_transport_coeffs_list = []
76-
130+ # Core transport
131+ core_transport_coeffs_list = []
77132 for component_model , component_params in zip (
78- self .transport_models ,
79- transport_dynamic_runtime_params .transport_model_params ,
133+ self .core_transport_models ,
134+ transport_dynamic_runtime_params .core_transport_model_params ,
80135 ):
81136 # Use the component model's _call_implementation, rather than __call__
82137 # directly. This ensures postprocessing (clipping, smoothing, patches) are
83- # performed on the output of CombinedTransportModel rather than its
84- # component models.
138+ # performed on the combined output rather than the individual components.
85139 component_transport_coeffs = component_model ._call_implementation (
86140 component_params ,
87141 dynamic_runtime_params_slice ,
@@ -100,20 +154,63 @@ def _call_implementation(
100154 pedestal_model_output ,
101155 )
102156
103- component_transport_coeffs_list .append (component_transport_coeffs )
157+ core_transport_coeffs_list .append (component_transport_coeffs )
158+
159+ # Pedestal transport
160+ pedestal_transport_coeffs = (
161+ self .pedestal_transport_model ._call_implementation (
162+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
163+ dynamic_runtime_params_slice ,
164+ geo ,
165+ core_profiles ,
166+ pedestal_model_output ,
167+ )
168+ )
169+ pedestal_transport_coeffs = self ._apply_pedestal_domain_restriction (
170+ transport_dynamic_runtime_params .pedestal_transport_model_params ,
171+ geo ,
172+ pedestal_transport_coeffs ,
173+ pedestal_model_output ,
174+ )
104175
176+ # Combine the transport coefficients from core and pedestal models.
105177 combined_transport_coeffs = jax .tree .map (
106178 lambda * leaves : sum (leaves ),
107- * component_transport_coeffs_list ,
179+ * ( core_transport_coeffs_list + [ pedestal_transport_coeffs ]) ,
108180 )
109181
110182 return combined_transport_coeffs
111183
112184 def __hash__ (self ):
113- return hash (tuple (self .transport_models ))
185+ return hash (
186+ tuple (self .core_transport_models + [self .pedestal_transport_model ])
187+ )
114188
115189 def __eq__ (self , other ):
116190 return (
117191 isinstance (other , CombinedTransportModel )
118- and self .transport_models == other .transport_models
192+ and self .core_transport_models == other .core_transport_models
193+ and self .pedestal_transport_model == other .pedestal_transport_model
194+ )
195+
196+ def _apply_pedestal_domain_restriction (
197+ self ,
198+ transport_runtime_params : runtime_params_lib .DynamicRuntimeParams ,
199+ geo : geometry .Geometry ,
200+ transport_coeffs : transport_model_lib .TurbulentTransport ,
201+ pedestal_model_output : pedestal_model_lib .PedestalModelOutput ,
202+ ) -> transport_model_lib .TurbulentTransport :
203+ active_mask = geo .rho_face_norm > pedestal_model_output .rho_norm_ped_top
204+
205+ chi_face_ion = jnp .where (active_mask , transport_coeffs .chi_face_ion , 0.0 )
206+ chi_face_el = jnp .where (active_mask , transport_coeffs .chi_face_el , 0.0 )
207+ d_face_el = jnp .where (active_mask , transport_coeffs .d_face_el , 0.0 )
208+ v_face_el = jnp .where (active_mask , transport_coeffs .v_face_el , 0.0 )
209+
210+ return dataclasses .replace (
211+ transport_coeffs ,
212+ chi_face_ion = chi_face_ion ,
213+ chi_face_el = chi_face_el ,
214+ d_face_el = d_face_el ,
215+ v_face_el = v_face_el ,
119216 )
0 commit comments