1919import chex
2020import jax
2121import jax .numpy as jnp
22+ import jaxtyping as jt
2223import numpy as np
23- from torax ._src import array_typing
24+ from torax ._src import array_typing as at
2425from torax ._src .torax_pydantic import torax_pydantic
2526
2627
27- def face_to_cell (
28- face : array_typing .FloatVectorFace ,
29- ) -> array_typing .FloatVectorCell :
28+ def face_to_cell (face : at .FloatVectorFace ) -> at .FloatVectorCell :
3029 """Infers cell values corresponding to a vector of face values.
3130
3231 Simply a linear interpolation between face values.
@@ -59,6 +58,7 @@ class GeometryType(enum.IntEnum):
5958# pylint: disable=invalid-name
6059
6160
61+ @at .jaxtyped
6262@jax .tree_util .register_dataclass
6363@dataclasses .dataclass (frozen = True )
6464class Geometry :
@@ -170,47 +170,47 @@ class Geometry:
170170 [:math:`\mathrm{m}`].
171171 """
172172
173- geometry_type : GeometryType
173+ geometry_type : GeometryType = dataclasses . field ( metadata = dict ( static = True ))
174174 torax_mesh : torax_pydantic .Grid1D
175- Phi : array_typing . Array
176- Phi_face : array_typing . Array
177- R_major : array_typing . FloatScalar
178- a_minor : array_typing . FloatScalar
179- B_0 : array_typing . FloatScalar
180- volume : array_typing . Array
181- volume_face : array_typing . Array
182- area : array_typing . Array
183- area_face : array_typing . Array
184- vpr : array_typing . Array
185- vpr_face : array_typing . Array
186- spr : array_typing . Array
187- spr_face : array_typing . Array
188- delta_face : array_typing . Array
189- elongation : array_typing . Array
190- elongation_face : array_typing . Array
191- g0 : array_typing . Array
192- g0_face : array_typing . Array
193- g1 : array_typing . Array
194- g1_face : array_typing . Array
195- g2 : array_typing . Array
196- g2_face : array_typing . Array
197- g3 : array_typing . Array
198- g3_face : array_typing . Array
199- g2g3_over_rhon : array_typing . Array
200- g2g3_over_rhon_face : array_typing . Array
201- g2g3_over_rhon_hires : array_typing . Array
202- F : array_typing . Array
203- F_face : array_typing . Array
204- F_hires : array_typing . Array
205- R_in : array_typing . Array
206- R_in_face : array_typing . Array
207- R_out : array_typing . Array
208- R_out_face : array_typing . Array
209- spr_hires : array_typing . Array
210- rho_hires_norm : array_typing . Array
211- rho_hires : array_typing . Array
212- Phi_b_dot : array_typing . FloatScalar
213- _z_magnetic_axis : array_typing . FloatScalar | None
175+ Phi : jt . Float [ at . Array , '*stack rhon' ]
176+ Phi_face : jt . Float [ at . Array , '*stack rhon+1' ]
177+ R_major : jt . Float [ at . Array , '*stack' ]
178+ a_minor : jt . Float [ at . Array , '*stack' ]
179+ B_0 : jt . Float [ at . Array , '*stack' ]
180+ volume : jt . Float [ at . Array , '*stack rhon' ]
181+ volume_face : jt . Float [ at . Array , '*stack rhon+1' ]
182+ area : jt . Float [ at . Array , '*stack rhon' ]
183+ area_face : jt . Float [ at . Array , '*stack rhon+1' ]
184+ vpr : jt . Float [ at . Array , '*stack rhon' ]
185+ vpr_face : jt . Float [ at . Array , '*stack rhon+1' ]
186+ spr : jt . Float [ at . Array , '*stack rhon' ]
187+ spr_face : jt . Float [ at . Array , '*stack rhon+1' ]
188+ delta_face : jt . Float [ at . Array , '*stack rhon+1' ]
189+ elongation : jt . Float [ at . Array , '*stack rhon' ]
190+ elongation_face : jt . Float [ at . Array , '*stack rhon+1' ]
191+ g0 : jt . Float [ at . Array , '*stack rhon' ]
192+ g0_face : jt . Float [ at . Array , '*stack rhon+1' ]
193+ g1 : jt . Float [ at . Array , '*stack rhon' ]
194+ g1_face : jt . Float [ at . Array , '*stack rhon+1' ]
195+ g2 : jt . Float [ at . Array , '*stack rhon' ]
196+ g2_face : jt . Float [ at . Array , '*stack rhon+1' ]
197+ g3 : jt . Float [ at . Array , '*stack rhon' ]
198+ g3_face : jt . Float [ at . Array , '*stack rhon+1' ]
199+ g2g3_over_rhon : jt . Float [ at . Array , '*stack rhon' ]
200+ g2g3_over_rhon_face : jt . Float [ at . Array , '*stack rhon+1' ]
201+ g2g3_over_rhon_hires : jt . Float [ at . Array , '*stack channel' ]
202+ F : jt . Float [ at . Array , '*stack rhon' ]
203+ F_face : jt . Float [ at . Array , '*stack rhon+1' ]
204+ F_hires : jt . Float [ at . Array , '*stack channel' ]
205+ R_in : jt . Float [ at . Array , '*stack rhon' ]
206+ R_in_face : jt . Float [ at . Array , '*stack rhon+1' ]
207+ R_out : jt . Float [ at . Array , '*stack rhon' ]
208+ R_out_face : jt . Float [ at . Array , '*stack rhon+1' ]
209+ spr_hires : jt . Float [ at . Array , '*stack channel' ]
210+ rho_hires_norm : jt . Float [ at . Array , '*stack channel' ]
211+ rho_hires : jt . Float [ at . Array , '*stack channel' ]
212+ Phi_b_dot : jt . Float [ at . Array | float , '*stack' ]
213+ _z_magnetic_axis : jt . Float [ at . Array , '*stack' ] | None
214214
215215 def __eq__ (self , other : 'Geometry' ) -> bool :
216216 try :
@@ -220,7 +220,7 @@ def __eq__(self, other: 'Geometry') -> bool:
220220 return True
221221
222222 @property
223- def q_correction_factor (self ) -> chex . Numeric :
223+ def q_correction_factor (self ) -> jt . Float [ jax . Array , '' ] :
224224 """Ad-hoc fix for non-physical circular geometry model.
225225
226226 Set such that q(r=a) = 3 for standard ITER parameters.
@@ -232,27 +232,27 @@ def q_correction_factor(self) -> chex.Numeric:
232232 )
233233
234234 @property
235- def rho_norm (self ) -> array_typing . Array :
235+ def rho_norm (self ) -> at . FloatVectorCell :
236236 r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
237237 return self .torax_mesh .cell_centers
238238
239239 @property
240- def rho_face_norm (self ) -> array_typing . Array :
240+ def rho_face_norm (self ) -> at . FloatVectorFace :
241241 r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
242242 return self .torax_mesh .face_centers
243243
244244 @property
245- def drho_norm (self ) -> array_typing . Array :
245+ def drho_norm (self ) -> at . FloatVectorCell :
246246 r"""Grid size for rho_norm [dimensionless]."""
247247 return jnp .array (self .torax_mesh .dx )
248248
249249 @property
250- def rho_face (self ) -> array_typing . Array :
250+ def rho_face (self ) -> at . FloatVectorFace :
251251 r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
252252 return self .rho_face_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
253253
254254 @property
255- def rho (self ) -> array_typing . Array :
255+ def rho (self ) -> jt . Float [ jax . Array , 'rhon' ] :
256256 r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
257257
258258 The toroidal flux coordinate is defined as
@@ -263,54 +263,54 @@ def rho(self) -> array_typing.Array:
263263 return self .rho_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
264264
265265 @property
266- def r_mid (self ) -> array_typing . Array :
266+ def r_mid (self ) -> jt . Float [ at . Array , '*stack rhon' ] :
267267 """Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
268268 return (self .R_out - self .R_in ) / 2
269269
270270 @property
271- def r_mid_face (self ) -> array_typing . Array :
271+ def r_mid_face (self ) -> jt . Float [ at . Array , '*stack rhon+1' ] :
272272 """Midplane radius of the plasma on the face grid [m]."""
273273 return (self .R_out_face - self .R_in_face ) / 2
274274
275275 @property
276- def epsilon (self ) -> array_typing . Array :
276+ def epsilon (self ) -> jt . Float [ at . Array , '*stack rhon' ] :
277277 """Local midplane inverse aspect ratio [dimensionless]."""
278278 return (self .R_out - self .R_in ) / (self .R_out + self .R_in )
279279
280280 @property
281- def epsilon_face (self ) -> array_typing . Array :
281+ def epsilon_face (self ) -> jt . Float [ at . Array , '*stack rhon+1' ] :
282282 """Local midplane inverse aspect ratio on the face grid [dimensionless]."""
283283 return (self .R_out_face - self .R_in_face ) / (
284284 self .R_out_face + self .R_in_face
285285 )
286286
287287 @property
288- def drho (self ) -> array_typing . Array :
288+ def drho (self ) -> jt . Float [ at . Array , 'rhon' ] :
289289 """Grid size for rho [m]."""
290290 return self .drho_norm * self .rho_b
291291
292292 @property
293- def rho_b (self ) -> array_typing . FloatScalar :
293+ def rho_b (self ) -> jt . Float [ jax . Array , '*stack' ] :
294294 """Toroidal flux coordinate [m] at boundary (LCFS)."""
295295 return jnp .sqrt (self .Phi_b / np .pi / self .B_0 )
296296
297297 @property
298- def Phi_b (self ) -> array_typing . FloatScalar :
298+ def Phi_b (self ) -> jt . Float [ at . Array , '*stack' ] :
299299 r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
300300 return self .Phi_face [..., - 1 ]
301301
302302 @property
303- def g1_over_vpr (self ) -> array_typing . Array :
303+ def g1_over_vpr (self ) -> jt . Float [ at . Array , '*stack rhon' ] :
304304 r"""g1/vpr [:math:`\mathrm{m}`]."""
305305 return self .g1 / self .vpr
306306
307307 @property
308- def g1_over_vpr2 (self ) -> array_typing . Array :
308+ def g1_over_vpr2 (self ) -> jt . Float [ at . Array , '*stack rhon' ] :
309309 r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
310310 return self .g1 / self .vpr ** 2
311311
312312 @property
313- def g0_over_vpr_face (self ) -> jax . Array :
313+ def g0_over_vpr_face (self ) -> jt . Float [ at . Array , '*stack rhon_1' ] :
314314 """g0_face/vpr_face [:math:`m^{-1}`], equal to 1/rho_b on-axis."""
315315 # Calculate the bulk of the array (excluding the first element)
316316 # to avoid division by zero.
@@ -322,7 +322,7 @@ def g0_over_vpr_face(self) -> jax.Array:
322322 )
323323
324324 @property
325- def g1_over_vpr_face (self ) -> jax . Array :
325+ def g1_over_vpr_face (self ) -> jt . Float [ at . Array , '*stack rhon_1' ] :
326326 r"""g1_face/vpr_face [:math:`\mathrm{m}`]. Zero on-axis."""
327327 bulk = self .g1_face [..., 1 :] / self .vpr_face [..., 1 :]
328328 first_element = jnp .zeros_like (self .rho_b )
@@ -331,7 +331,7 @@ def g1_over_vpr_face(self) -> jax.Array:
331331 )
332332
333333 @property
334- def g1_over_vpr2_face (self ) -> jax . Array :
334+ def g1_over_vpr2_face (self ) -> jt . Float [ at . Array , '*stack rhon_1' ] :
335335 """g1_face/vpr_face**2 [:math:`m^{-2}`], equal to 1/rho_b**2 on-axis."""
336336 bulk = self .g1_face [..., 1 :] / self .vpr_face [..., 1 :] ** 2
337337 first_element = jnp .ones_like (self .rho_b ) / self .rho_b ** 2
@@ -340,20 +340,20 @@ def g1_over_vpr2_face(self) -> jax.Array:
340340 )
341341
342342 @property
343- def gm9 (self ) -> jax . Array :
343+ def gm9 (self ) -> jt . Float [ at . Array , '*stack rhon' ] :
344344 r"""<1/R> on cell grid [:math:`\mathrm{m}^{-1}`]."""
345345 return 2 * jnp .pi * self .spr / self .vpr
346346
347347 @property
348- def gm9_face (self ) -> jax . Array :
348+ def gm9_face (self ) -> jt . Float [ at . Array , '*stack rhon_1' ] :
349349 r"""<1/R> on face grid [:math:`\mathrm{m}^{-1}`]."""
350350 bulk = 2 * jnp .pi * self .spr_face [..., 1 :] / self .vpr_face [..., 1 :]
351351 first_element = 1 / self .R_major
352352 return jnp .concatenate (
353353 [jnp .expand_dims (first_element , axis = - 1 ), bulk ], axis = - 1
354354 )
355355
356- def z_magnetic_axis (self ) -> chex . Numeric :
356+ def z_magnetic_axis (self ) -> jt . Float [ jax . Array , '*stack' ] :
357357 """z position of magnetic axis [m]."""
358358 z_magnetic_axis = self ._z_magnetic_axis
359359 if z_magnetic_axis is not None :
@@ -391,7 +391,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
391391 field_name = field .name
392392 field_value = getattr (first_geo , field_name )
393393 # Stack stackable fields. Save first geo's value for non-stackable fields.
394- if isinstance (field_value , (array_typing .Array , array_typing .FloatScalar )):
394+ if isinstance (field_value , (at .Array , at .FloatScalar )):
395395 field_values = [getattr (geo , field_name ) for geo in geometries ]
396396 stacked_data [field_name ] = np .stack (field_values )
397397 else :
@@ -401,9 +401,10 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
401401 return first_geo .__class__ (** stacked_data )
402402
403403
404+ @at .jaxtyped
404405def update_geometries_with_Phibdot (
405406 * ,
406- dt : chex . Numeric ,
407+ dt : at . FloatScalar ,
407408 geo_t : Geometry ,
408409 geo_t_plus_dt : Geometry ,
409410) -> tuple [Geometry , Geometry ]:
0 commit comments