Skip to content

Commit dad4d8a

Browse files
sbodensteinTorax team
authored andcommitted
Stricter typing for geometry.
PiperOrigin-RevId: 800410735
1 parent f04884a commit dad4d8a

File tree

2 files changed

+73
-70
lines changed

2 files changed

+73
-70
lines changed

torax/_src/geometry/geometry.py

Lines changed: 68 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
import chex
2020
import jax
2121
import jax.numpy as jnp
22+
import jaxtyping as jt
2223
import numpy as np
23-
from torax._src import array_typing
24+
from torax._src import array_typing as at
2425
from torax._src.torax_pydantic import torax_pydantic
2526

2627

27-
def face_to_cell(
28-
face: array_typing.FloatVectorFace,
29-
) -> array_typing.FloatVectorCell:
28+
def face_to_cell(face: at.FloatVectorFace) -> at.FloatVectorCell:
3029
"""Infers cell values corresponding to a vector of face values.
3130
3231
Simply a linear interpolation between face values.
@@ -59,6 +58,7 @@ class GeometryType(enum.IntEnum):
5958
# pylint: disable=invalid-name
6059

6160

61+
@at.jaxtyped
6262
@jax.tree_util.register_dataclass
6363
@dataclasses.dataclass(frozen=True)
6464
class Geometry:
@@ -170,47 +170,47 @@ class Geometry:
170170
[:math:`\mathrm{m}`].
171171
"""
172172

173-
geometry_type: GeometryType
173+
geometry_type: GeometryType = dataclasses.field(metadata=dict(static=True))
174174
torax_mesh: torax_pydantic.Grid1D
175-
Phi: array_typing.Array
176-
Phi_face: array_typing.Array
177-
R_major: array_typing.FloatScalar
178-
a_minor: array_typing.FloatScalar
179-
B_0: array_typing.FloatScalar
180-
volume: array_typing.Array
181-
volume_face: array_typing.Array
182-
area: array_typing.Array
183-
area_face: array_typing.Array
184-
vpr: array_typing.Array
185-
vpr_face: array_typing.Array
186-
spr: array_typing.Array
187-
spr_face: array_typing.Array
188-
delta_face: array_typing.Array
189-
elongation: array_typing.Array
190-
elongation_face: array_typing.Array
191-
g0: array_typing.Array
192-
g0_face: array_typing.Array
193-
g1: array_typing.Array
194-
g1_face: array_typing.Array
195-
g2: array_typing.Array
196-
g2_face: array_typing.Array
197-
g3: array_typing.Array
198-
g3_face: array_typing.Array
199-
g2g3_over_rhon: array_typing.Array
200-
g2g3_over_rhon_face: array_typing.Array
201-
g2g3_over_rhon_hires: array_typing.Array
202-
F: array_typing.Array
203-
F_face: array_typing.Array
204-
F_hires: array_typing.Array
205-
R_in: array_typing.Array
206-
R_in_face: array_typing.Array
207-
R_out: array_typing.Array
208-
R_out_face: array_typing.Array
209-
spr_hires: array_typing.Array
210-
rho_hires_norm: array_typing.Array
211-
rho_hires: array_typing.Array
212-
Phi_b_dot: array_typing.FloatScalar
213-
_z_magnetic_axis: array_typing.FloatScalar | None
175+
Phi: jt.Float[at.Array, '*stack rhon']
176+
Phi_face: jt.Float[at.Array, '*stack rhon+1']
177+
R_major: jt.Float[at.Array, '*stack']
178+
a_minor: jt.Float[at.Array, '*stack']
179+
B_0: jt.Float[at.Array, '*stack']
180+
volume: jt.Float[at.Array, '*stack rhon']
181+
volume_face: jt.Float[at.Array, '*stack rhon+1']
182+
area: jt.Float[at.Array, '*stack rhon']
183+
area_face: jt.Float[at.Array, '*stack rhon+1']
184+
vpr: jt.Float[at.Array, '*stack rhon']
185+
vpr_face: jt.Float[at.Array, '*stack rhon+1']
186+
spr: jt.Float[at.Array, '*stack rhon']
187+
spr_face: jt.Float[at.Array, '*stack rhon+1']
188+
delta_face: jt.Float[at.Array, '*stack rhon+1']
189+
elongation: jt.Float[at.Array, '*stack rhon']
190+
elongation_face: jt.Float[at.Array, '*stack rhon+1']
191+
g0: jt.Float[at.Array, '*stack rhon']
192+
g0_face: jt.Float[at.Array, '*stack rhon+1']
193+
g1: jt.Float[at.Array, '*stack rhon']
194+
g1_face: jt.Float[at.Array, '*stack rhon+1']
195+
g2: jt.Float[at.Array, '*stack rhon']
196+
g2_face: jt.Float[at.Array, '*stack rhon+1']
197+
g3: jt.Float[at.Array, '*stack rhon']
198+
g3_face: jt.Float[at.Array, '*stack rhon+1']
199+
g2g3_over_rhon: jt.Float[at.Array, '*stack rhon']
200+
g2g3_over_rhon_face: jt.Float[at.Array, '*stack rhon+1']
201+
g2g3_over_rhon_hires: jt.Float[at.Array, '*stack channel']
202+
F: jt.Float[at.Array, '*stack rhon']
203+
F_face: jt.Float[at.Array, '*stack rhon+1']
204+
F_hires: jt.Float[at.Array, '*stack channel']
205+
R_in: jt.Float[at.Array, '*stack rhon']
206+
R_in_face: jt.Float[at.Array, '*stack rhon+1']
207+
R_out: jt.Float[at.Array, '*stack rhon']
208+
R_out_face: jt.Float[at.Array, '*stack rhon+1']
209+
spr_hires: jt.Float[at.Array, '*stack channel']
210+
rho_hires_norm: jt.Float[at.Array, '*stack channel']
211+
rho_hires: jt.Float[at.Array, '*stack channel']
212+
Phi_b_dot: jt.Float[at.Array | float, '*stack']
213+
_z_magnetic_axis: jt.Float[at.Array, '*stack'] | None
214214

215215
def __eq__(self, other: 'Geometry') -> bool:
216216
try:
@@ -220,7 +220,7 @@ def __eq__(self, other: 'Geometry') -> bool:
220220
return True
221221

222222
@property
223-
def q_correction_factor(self) -> chex.Numeric:
223+
def q_correction_factor(self) -> jt.Float[jax.Array, '']:
224224
"""Ad-hoc fix for non-physical circular geometry model.
225225
226226
Set such that q(r=a) = 3 for standard ITER parameters.
@@ -232,27 +232,27 @@ def q_correction_factor(self) -> chex.Numeric:
232232
)
233233

234234
@property
235-
def rho_norm(self) -> array_typing.Array:
235+
def rho_norm(self) -> at.FloatVectorCell:
236236
r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
237237
return self.torax_mesh.cell_centers
238238

239239
@property
240-
def rho_face_norm(self) -> array_typing.Array:
240+
def rho_face_norm(self) -> at.FloatVectorFace:
241241
r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
242242
return self.torax_mesh.face_centers
243243

244244
@property
245-
def drho_norm(self) -> array_typing.Array:
245+
def drho_norm(self) -> at.FloatVectorCell:
246246
r"""Grid size for rho_norm [dimensionless]."""
247247
return jnp.array(self.torax_mesh.dx)
248248

249249
@property
250-
def rho_face(self) -> array_typing.Array:
250+
def rho_face(self) -> at.FloatVectorFace:
251251
r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
252252
return self.rho_face_norm * jnp.expand_dims(self.rho_b, axis=-1)
253253

254254
@property
255-
def rho(self) -> array_typing.Array:
255+
def rho(self) -> jt.Float[jax.Array, 'rhon']:
256256
r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
257257
258258
The toroidal flux coordinate is defined as
@@ -263,54 +263,54 @@ def rho(self) -> array_typing.Array:
263263
return self.rho_norm * jnp.expand_dims(self.rho_b, axis=-1)
264264

265265
@property
266-
def r_mid(self) -> array_typing.Array:
266+
def r_mid(self) -> jt.Float[at.Array, '*stack rhon']:
267267
"""Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
268268
return (self.R_out - self.R_in) / 2
269269

270270
@property
271-
def r_mid_face(self) -> array_typing.Array:
271+
def r_mid_face(self) -> jt.Float[at.Array, '*stack rhon+1']:
272272
"""Midplane radius of the plasma on the face grid [m]."""
273273
return (self.R_out_face - self.R_in_face) / 2
274274

275275
@property
276-
def epsilon(self) -> array_typing.Array:
276+
def epsilon(self) -> jt.Float[at.Array, '*stack rhon']:
277277
"""Local midplane inverse aspect ratio [dimensionless]."""
278278
return (self.R_out - self.R_in) / (self.R_out + self.R_in)
279279

280280
@property
281-
def epsilon_face(self) -> array_typing.Array:
281+
def epsilon_face(self) -> jt.Float[at.Array, '*stack rhon+1']:
282282
"""Local midplane inverse aspect ratio on the face grid [dimensionless]."""
283283
return (self.R_out_face - self.R_in_face) / (
284284
self.R_out_face + self.R_in_face
285285
)
286286

287287
@property
288-
def drho(self) -> array_typing.Array:
288+
def drho(self) -> jt.Float[at.Array, 'rhon']:
289289
"""Grid size for rho [m]."""
290290
return self.drho_norm * self.rho_b
291291

292292
@property
293-
def rho_b(self) -> array_typing.FloatScalar:
293+
def rho_b(self) -> jt.Float[jax.Array, '*stack']:
294294
"""Toroidal flux coordinate [m] at boundary (LCFS)."""
295295
return jnp.sqrt(self.Phi_b / np.pi / self.B_0)
296296

297297
@property
298-
def Phi_b(self) -> array_typing.FloatScalar:
298+
def Phi_b(self) -> jt.Float[at.Array, '*stack']:
299299
r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
300300
return self.Phi_face[..., -1]
301301

302302
@property
303-
def g1_over_vpr(self) -> array_typing.Array:
303+
def g1_over_vpr(self) -> jt.Float[at.Array, '*stack rhon']:
304304
r"""g1/vpr [:math:`\mathrm{m}`]."""
305305
return self.g1 / self.vpr
306306

307307
@property
308-
def g1_over_vpr2(self) -> array_typing.Array:
308+
def g1_over_vpr2(self) -> jt.Float[at.Array, '*stack rhon']:
309309
r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
310310
return self.g1 / self.vpr**2
311311

312312
@property
313-
def g0_over_vpr_face(self) -> jax.Array:
313+
def g0_over_vpr_face(self) -> jt.Float[at.Array, '*stack rhon_1']:
314314
"""g0_face/vpr_face [:math:`m^{-1}`], equal to 1/rho_b on-axis."""
315315
# Calculate the bulk of the array (excluding the first element)
316316
# to avoid division by zero.
@@ -322,7 +322,7 @@ def g0_over_vpr_face(self) -> jax.Array:
322322
)
323323

324324
@property
325-
def g1_over_vpr_face(self) -> jax.Array:
325+
def g1_over_vpr_face(self) -> jt.Float[at.Array, '*stack rhon_1']:
326326
r"""g1_face/vpr_face [:math:`\mathrm{m}`]. Zero on-axis."""
327327
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]
328328
first_element = jnp.zeros_like(self.rho_b)
@@ -331,7 +331,7 @@ def g1_over_vpr_face(self) -> jax.Array:
331331
)
332332

333333
@property
334-
def g1_over_vpr2_face(self) -> jax.Array:
334+
def g1_over_vpr2_face(self) -> jt.Float[at.Array, '*stack rhon_1']:
335335
"""g1_face/vpr_face**2 [:math:`m^{-2}`], equal to 1/rho_b**2 on-axis."""
336336
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2
337337
first_element = jnp.ones_like(self.rho_b) / self.rho_b**2
@@ -340,20 +340,20 @@ def g1_over_vpr2_face(self) -> jax.Array:
340340
)
341341

342342
@property
343-
def gm9(self) -> jax.Array:
343+
def gm9(self) -> jt.Float[at.Array, '*stack rhon']:
344344
r"""<1/R> on cell grid [:math:`\mathrm{m}^{-1}`]."""
345345
return 2 * jnp.pi * self.spr / self.vpr
346346

347347
@property
348-
def gm9_face(self) -> jax.Array:
348+
def gm9_face(self) -> jt.Float[at.Array, '*stack rhon_1']:
349349
r"""<1/R> on face grid [:math:`\mathrm{m}^{-1}`]."""
350350
bulk = 2 * jnp.pi * self.spr_face[..., 1:] / self.vpr_face[..., 1:]
351351
first_element = 1 / self.R_major
352352
return jnp.concatenate(
353353
[jnp.expand_dims(first_element, axis=-1), bulk], axis=-1
354354
)
355355

356-
def z_magnetic_axis(self) -> chex.Numeric:
356+
def z_magnetic_axis(self) -> jt.Float[jax.Array, '*stack']:
357357
"""z position of magnetic axis [m]."""
358358
z_magnetic_axis = self._z_magnetic_axis
359359
if z_magnetic_axis is not None:
@@ -391,7 +391,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
391391
field_name = field.name
392392
field_value = getattr(first_geo, field_name)
393393
# Stack stackable fields. Save first geo's value for non-stackable fields.
394-
if isinstance(field_value, (array_typing.Array, array_typing.FloatScalar)):
394+
if isinstance(field_value, (at.Array, at.FloatScalar)):
395395
field_values = [getattr(geo, field_name) for geo in geometries]
396396
stacked_data[field_name] = np.stack(field_values)
397397
else:
@@ -401,9 +401,10 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
401401
return first_geo.__class__(**stacked_data)
402402

403403

404+
@at.jaxtyped
404405
def update_geometries_with_Phibdot(
405406
*,
406-
dt: chex.Numeric,
407+
dt: at.FloatScalar,
407408
geo_t: Geometry,
408409
geo_t_plus_dt: Geometry,
409410
) -> tuple[Geometry, Geometry]:

torax/_src/geometry/tests/geometry_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def test_stack_geometries_error_handling_different_geometry_types(self):
151151
def test_update_phibdot(self):
152152
"""Test update_phibdot for circular geometries."""
153153
geo = geometry_pydantic_model.CircularConfig().build_geometry()
154-
geo0 = dataclasses.replace(geo, Phi_face=np.array([1.0]))
155-
geo1 = dataclasses.replace(geo, Phi_face=np.array([2.0]))
154+
geo0 = dataclasses.replace(geo, Phi_face=np.ones_like(geo.Phi_face) * 1.0)
155+
geo1 = dataclasses.replace(geo, Phi_face=np.ones_like(geo.Phi_face) * 2.0)
156156
geo0_updated, geo1_updated = geometry.update_geometries_with_Phibdot(
157157
dt=0.1, geo_t=geo0, geo_t_plus_dt=geo1
158158
)
@@ -166,7 +166,9 @@ def test_geometry_eq(self):
166166
self.assertEqual(geo1, geo2)
167167

168168
with self.subTest('different_geometries_are_not_equal'):
169-
geo3 = dataclasses.replace(geo1, Phi_face=np.array([2.0]))
169+
geo3 = dataclasses.replace(
170+
geo1, Phi_face=np.ones_like(geo1.Phi_face) * 2.0
171+
)
170172
self.assertNotEqual(geo1, geo3)
171173

172174

0 commit comments

Comments
 (0)