1616from collections .abc import Sequence
1717import dataclasses
1818import enum
19-
2019import chex
2120import jax
2221import jax .numpy as jnp
2322import numpy as np
23+ from torax ._src import array_typing
2424from torax ._src .torax_pydantic import torax_pydantic
2525
2626
27- def face_to_cell (face : chex .Array ) -> chex .Array :
27+ def face_to_cell (
28+ face : array_typing .FloatVectorFace ,
29+ ) -> array_typing .FloatVectorCell :
2830 """Infers cell values corresponding to a vector of face values.
2931
3032 Simply a linear interpolation between face values.
@@ -170,45 +172,45 @@ class Geometry:
170172
171173 geometry_type : GeometryType
172174 torax_mesh : torax_pydantic .Grid1D
173- Phi : chex .Array
174- Phi_face : chex .Array
175- R_major : chex . Array
176- a_minor : chex . Array
177- B_0 : chex . Array
178- volume : chex .Array
179- volume_face : chex .Array
180- area : chex .Array
181- area_face : chex .Array
182- vpr : chex .Array
183- vpr_face : chex .Array
184- spr : chex .Array
185- spr_face : chex .Array
186- delta_face : chex .Array
187- elongation : chex .Array
188- elongation_face : chex .Array
189- g0 : chex .Array
190- g0_face : chex .Array
191- g1 : chex .Array
192- g1_face : chex .Array
193- g2 : chex .Array
194- g2_face : chex .Array
195- g3 : chex .Array
196- g3_face : chex .Array
197- g2g3_over_rhon : chex .Array
198- g2g3_over_rhon_face : chex .Array
199- g2g3_over_rhon_hires : chex .Array
200- F : chex .Array
201- F_face : chex .Array
202- F_hires : chex .Array
203- R_in : chex .Array
204- R_in_face : chex .Array
205- R_out : chex .Array
206- R_out_face : chex .Array
207- spr_hires : chex .Array
208- rho_hires_norm : chex .Array
209- rho_hires : chex .Array
210- Phi_b_dot : chex . Array
211- _z_magnetic_axis : chex . Array | None
175+ Phi : array_typing .Array
176+ Phi_face : array_typing .Array
177+ R_major : array_typing . FloatScalar
178+ a_minor : array_typing . FloatScalar
179+ B_0 : array_typing . FloatScalar
180+ volume : array_typing .Array
181+ volume_face : array_typing .Array
182+ area : array_typing .Array
183+ area_face : array_typing .Array
184+ vpr : array_typing .Array
185+ vpr_face : array_typing .Array
186+ spr : array_typing .Array
187+ spr_face : array_typing .Array
188+ delta_face : array_typing .Array
189+ elongation : array_typing .Array
190+ elongation_face : array_typing .Array
191+ g0 : array_typing .Array
192+ g0_face : array_typing .Array
193+ g1 : array_typing .Array
194+ g1_face : array_typing .Array
195+ g2 : array_typing .Array
196+ g2_face : array_typing .Array
197+ g3 : array_typing .Array
198+ g3_face : array_typing .Array
199+ g2g3_over_rhon : array_typing .Array
200+ g2g3_over_rhon_face : array_typing .Array
201+ g2g3_over_rhon_hires : array_typing .Array
202+ F : array_typing .Array
203+ F_face : array_typing .Array
204+ F_hires : array_typing .Array
205+ R_in : array_typing .Array
206+ R_in_face : array_typing .Array
207+ R_out : array_typing .Array
208+ R_out_face : array_typing .Array
209+ spr_hires : array_typing .Array
210+ rho_hires_norm : array_typing .Array
211+ rho_hires : array_typing .Array
212+ Phi_b_dot : array_typing . FloatScalar
213+ _z_magnetic_axis : array_typing . FloatScalar | None
212214
213215 def __eq__ (self , other : 'Geometry' ) -> bool :
214216 try :
@@ -230,27 +232,27 @@ def q_correction_factor(self) -> chex.Numeric:
230232 )
231233
232234 @property
233- def rho_norm (self ) -> chex .Array :
235+ def rho_norm (self ) -> array_typing .Array :
234236 r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
235237 return self .torax_mesh .cell_centers
236238
237239 @property
238- def rho_face_norm (self ) -> chex .Array :
240+ def rho_face_norm (self ) -> array_typing .Array :
239241 r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
240242 return self .torax_mesh .face_centers
241243
242244 @property
243- def drho_norm (self ) -> chex .Array :
245+ def drho_norm (self ) -> array_typing .Array :
244246 r"""Grid size for rho_norm [dimensionless]."""
245247 return jnp .array (self .torax_mesh .dx )
246248
247249 @property
248- def rho_face (self ) -> chex .Array :
250+ def rho_face (self ) -> array_typing .Array :
249251 r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
250252 return self .rho_face_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
251253
252254 @property
253- def rho (self ) -> chex .Array :
255+ def rho (self ) -> array_typing .Array :
254256 r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
255257
256258 The toroidal flux coordinate is defined as
@@ -261,49 +263,49 @@ def rho(self) -> chex.Array:
261263 return self .rho_norm * jnp .expand_dims (self .rho_b , axis = - 1 )
262264
263265 @property
264- def r_mid (self ) -> chex .Array :
266+ def r_mid (self ) -> array_typing .Array :
265267 """Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
266268 return (self .R_out - self .R_in ) / 2
267269
268270 @property
269- def r_mid_face (self ) -> chex .Array :
271+ def r_mid_face (self ) -> array_typing .Array :
270272 """Midplane radius of the plasma on the face grid [m]."""
271273 return (self .R_out_face - self .R_in_face ) / 2
272274
273275 @property
274- def epsilon (self ) -> chex .Array :
276+ def epsilon (self ) -> array_typing .Array :
275277 """Local midplane inverse aspect ratio [dimensionless]."""
276278 return (self .R_out - self .R_in ) / (self .R_out + self .R_in )
277279
278280 @property
279- def epsilon_face (self ) -> chex .Array :
281+ def epsilon_face (self ) -> array_typing .Array :
280282 """Local midplane inverse aspect ratio on the face grid [dimensionless]."""
281283 return (self .R_out_face - self .R_in_face ) / (
282284 self .R_out_face + self .R_in_face
283285 )
284286
285287 @property
286- def drho (self ) -> chex .Array :
288+ def drho (self ) -> array_typing .Array :
287289 """Grid size for rho [m]."""
288290 return self .drho_norm * self .rho_b
289291
290292 @property
291- def rho_b (self ) -> chex . Array :
293+ def rho_b (self ) -> array_typing . FloatScalar :
292294 """Toroidal flux coordinate [m] at boundary (LCFS)."""
293295 return jnp .sqrt (self .Phi_b / np .pi / self .B_0 )
294296
295297 @property
296- def Phi_b (self ) -> chex . Array :
298+ def Phi_b (self ) -> array_typing . FloatScalar :
297299 r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
298300 return self .Phi_face [..., - 1 ]
299301
300302 @property
301- def g1_over_vpr (self ) -> chex .Array :
303+ def g1_over_vpr (self ) -> array_typing .Array :
302304 r"""g1/vpr [:math:`\mathrm{m}`]."""
303305 return self .g1 / self .vpr
304306
305307 @property
306- def g1_over_vpr2 (self ) -> chex .Array :
308+ def g1_over_vpr2 (self ) -> array_typing .Array :
307309 r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
308310 return self .g1 / self .vpr ** 2
309311
@@ -389,7 +391,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
389391 field_name = field .name
390392 field_value = getattr (first_geo , field_name )
391393 # Stack stackable fields. Save first geo's value for non-stackable fields.
392- if isinstance (field_value , chex .Array ) or isinstance ( field_value , float ):
394+ if isinstance (field_value , ( array_typing .Array , array_typing . FloatScalar ) ):
393395 field_values = [getattr (geo , field_name ) for geo in geometries ]
394396 stacked_data [field_name ] = np .stack (field_values )
395397 else :
0 commit comments