Skip to content

Commit f04884a

Browse files
sbodensteinTorax team
authored andcommitted
Replace chex.Array with array_typing.Array.
chex.Array includes things like np.bool_ and np.number that are not what functions accepting either NumPy or JAX arrays intend. PiperOrigin-RevId: 802070254
1 parent 5852672 commit f04884a

37 files changed

+298
-268
lines changed

torax/_src/array_typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
FloatVector: TypeAlias = jt.Float[Array, "_"]
3333
BoolVector: TypeAlias = jt.Bool[Array, "_"]
3434
FloatVectorCell: TypeAlias = jt.Float[Array, "rhon"]
35+
FloatVectorCellPlusBoundaries: TypeAlias = jt.Float[Array, "rhon+2"]
36+
FloatMatrixCell: TypeAlias = jt.Float[Array, "rhon rhon"]
3537
FloatVectorFace: TypeAlias = jt.Float[Array, "rhon+1"]
3638

3739

torax/_src/fvm/diffusion_terms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919

2020
import chex
2121
from jax import numpy as jnp
22+
from torax._src import array_typing
2223
from torax._src import math_utils
2324
from torax._src.fvm import cell_variable
2425

2526

2627
def make_diffusion_terms(
27-
d_face: chex.Array, var: cell_variable.CellVariable
28-
) -> tuple[chex.Array, chex.Array]:
28+
d_face: array_typing.FloatVectorFace, var: cell_variable.CellVariable
29+
) -> tuple[array_typing.FloatMatrixCell, array_typing.FloatVectorCell]:
2930
"""Makes the terms of the matrix equation derived from the diffusion term.
3031
3132
The diffusion term is of the form

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import functools
2121
from typing import Final
22-
import chex
22+
from torax._src import array_typing
2323
from torax._src import jax_utils
2424
from torax._src import physics_models as physics_models_lib
2525
from torax._src import state as state_module
@@ -52,7 +52,7 @@
5252
],
5353
)
5454
def newton_raphson_solve_block(
55-
dt: chex.Array,
55+
dt: array_typing.FloatScalar,
5656
dynamic_runtime_params_slice_t: runtime_params_slice.RuntimeParams,
5757
dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.RuntimeParams,
5858
geo_t: geometry.Geometry,

torax/_src/geometry/geometry.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
from collections.abc import Sequence
1717
import dataclasses
1818
import enum
19-
2019
import chex
2120
import jax
2221
import jax.numpy as jnp
2322
import numpy as np
23+
from torax._src import array_typing
2424
from torax._src.torax_pydantic import torax_pydantic
2525

2626

27-
def face_to_cell(face: chex.Array) -> chex.Array:
27+
def face_to_cell(
28+
face: array_typing.FloatVectorFace,
29+
) -> array_typing.FloatVectorCell:
2830
"""Infers cell values corresponding to a vector of face values.
2931
3032
Simply a linear interpolation between face values.
@@ -170,45 +172,45 @@ class Geometry:
170172

171173
geometry_type: GeometryType
172174
torax_mesh: torax_pydantic.Grid1D
173-
Phi: chex.Array
174-
Phi_face: chex.Array
175-
R_major: chex.Array
176-
a_minor: chex.Array
177-
B_0: chex.Array
178-
volume: chex.Array
179-
volume_face: chex.Array
180-
area: chex.Array
181-
area_face: chex.Array
182-
vpr: chex.Array
183-
vpr_face: chex.Array
184-
spr: chex.Array
185-
spr_face: chex.Array
186-
delta_face: chex.Array
187-
elongation: chex.Array
188-
elongation_face: chex.Array
189-
g0: chex.Array
190-
g0_face: chex.Array
191-
g1: chex.Array
192-
g1_face: chex.Array
193-
g2: chex.Array
194-
g2_face: chex.Array
195-
g3: chex.Array
196-
g3_face: chex.Array
197-
g2g3_over_rhon: chex.Array
198-
g2g3_over_rhon_face: chex.Array
199-
g2g3_over_rhon_hires: chex.Array
200-
F: chex.Array
201-
F_face: chex.Array
202-
F_hires: chex.Array
203-
R_in: chex.Array
204-
R_in_face: chex.Array
205-
R_out: chex.Array
206-
R_out_face: chex.Array
207-
spr_hires: chex.Array
208-
rho_hires_norm: chex.Array
209-
rho_hires: chex.Array
210-
Phi_b_dot: chex.Array
211-
_z_magnetic_axis: chex.Array | None
175+
Phi: array_typing.Array
176+
Phi_face: array_typing.Array
177+
R_major: array_typing.FloatScalar
178+
a_minor: array_typing.FloatScalar
179+
B_0: array_typing.FloatScalar
180+
volume: array_typing.Array
181+
volume_face: array_typing.Array
182+
area: array_typing.Array
183+
area_face: array_typing.Array
184+
vpr: array_typing.Array
185+
vpr_face: array_typing.Array
186+
spr: array_typing.Array
187+
spr_face: array_typing.Array
188+
delta_face: array_typing.Array
189+
elongation: array_typing.Array
190+
elongation_face: array_typing.Array
191+
g0: array_typing.Array
192+
g0_face: array_typing.Array
193+
g1: array_typing.Array
194+
g1_face: array_typing.Array
195+
g2: array_typing.Array
196+
g2_face: array_typing.Array
197+
g3: array_typing.Array
198+
g3_face: array_typing.Array
199+
g2g3_over_rhon: array_typing.Array
200+
g2g3_over_rhon_face: array_typing.Array
201+
g2g3_over_rhon_hires: array_typing.Array
202+
F: array_typing.Array
203+
F_face: array_typing.Array
204+
F_hires: array_typing.Array
205+
R_in: array_typing.Array
206+
R_in_face: array_typing.Array
207+
R_out: array_typing.Array
208+
R_out_face: array_typing.Array
209+
spr_hires: array_typing.Array
210+
rho_hires_norm: array_typing.Array
211+
rho_hires: array_typing.Array
212+
Phi_b_dot: array_typing.FloatScalar
213+
_z_magnetic_axis: array_typing.FloatScalar | None
212214

213215
def __eq__(self, other: 'Geometry') -> bool:
214216
try:
@@ -230,27 +232,27 @@ def q_correction_factor(self) -> chex.Numeric:
230232
)
231233

232234
@property
233-
def rho_norm(self) -> chex.Array:
235+
def rho_norm(self) -> array_typing.Array:
234236
r"""Normalized toroidal flux coordinate on cell grid [dimensionless]."""
235237
return self.torax_mesh.cell_centers
236238

237239
@property
238-
def rho_face_norm(self) -> chex.Array:
240+
def rho_face_norm(self) -> array_typing.Array:
239241
r"""Normalized toroidal flux coordinate on face grid [dimensionless]."""
240242
return self.torax_mesh.face_centers
241243

242244
@property
243-
def drho_norm(self) -> chex.Array:
245+
def drho_norm(self) -> array_typing.Array:
244246
r"""Grid size for rho_norm [dimensionless]."""
245247
return jnp.array(self.torax_mesh.dx)
246248

247249
@property
248-
def rho_face(self) -> chex.Array:
250+
def rho_face(self) -> array_typing.Array:
249251
r"""Toroidal flux coordinate on face grid :math:`\mathrm{m}`."""
250252
return self.rho_face_norm * jnp.expand_dims(self.rho_b, axis=-1)
251253

252254
@property
253-
def rho(self) -> chex.Array:
255+
def rho(self) -> array_typing.Array:
254256
r"""Toroidal flux coordinate on cell grid :math:`\mathrm{m}`.
255257
256258
The toroidal flux coordinate is defined as
@@ -261,49 +263,49 @@ def rho(self) -> chex.Array:
261263
return self.rho_norm * jnp.expand_dims(self.rho_b, axis=-1)
262264

263265
@property
264-
def r_mid(self) -> chex.Array:
266+
def r_mid(self) -> array_typing.Array:
265267
"""Midplane radius of the plasma [m], defined as (Rout-Rin)/2."""
266268
return (self.R_out - self.R_in) / 2
267269

268270
@property
269-
def r_mid_face(self) -> chex.Array:
271+
def r_mid_face(self) -> array_typing.Array:
270272
"""Midplane radius of the plasma on the face grid [m]."""
271273
return (self.R_out_face - self.R_in_face) / 2
272274

273275
@property
274-
def epsilon(self) -> chex.Array:
276+
def epsilon(self) -> array_typing.Array:
275277
"""Local midplane inverse aspect ratio [dimensionless]."""
276278
return (self.R_out - self.R_in) / (self.R_out + self.R_in)
277279

278280
@property
279-
def epsilon_face(self) -> chex.Array:
281+
def epsilon_face(self) -> array_typing.Array:
280282
"""Local midplane inverse aspect ratio on the face grid [dimensionless]."""
281283
return (self.R_out_face - self.R_in_face) / (
282284
self.R_out_face + self.R_in_face
283285
)
284286

285287
@property
286-
def drho(self) -> chex.Array:
288+
def drho(self) -> array_typing.Array:
287289
"""Grid size for rho [m]."""
288290
return self.drho_norm * self.rho_b
289291

290292
@property
291-
def rho_b(self) -> chex.Array:
293+
def rho_b(self) -> array_typing.FloatScalar:
292294
"""Toroidal flux coordinate [m] at boundary (LCFS)."""
293295
return jnp.sqrt(self.Phi_b / np.pi / self.B_0)
294296

295297
@property
296-
def Phi_b(self) -> chex.Array:
298+
def Phi_b(self) -> array_typing.FloatScalar:
297299
r"""Toroidal flux at boundary (LCFS) :math:`\mathrm{Wb}`."""
298300
return self.Phi_face[..., -1]
299301

300302
@property
301-
def g1_over_vpr(self) -> chex.Array:
303+
def g1_over_vpr(self) -> array_typing.Array:
302304
r"""g1/vpr [:math:`\mathrm{m}`]."""
303305
return self.g1 / self.vpr
304306

305307
@property
306-
def g1_over_vpr2(self) -> chex.Array:
308+
def g1_over_vpr2(self) -> array_typing.Array:
307309
r"""g1/vpr**2 [:math:`\mathrm{m}^{-2}`]."""
308310
return self.g1 / self.vpr**2
309311

@@ -389,7 +391,7 @@ def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
389391
field_name = field.name
390392
field_value = getattr(first_geo, field_name)
391393
# Stack stackable fields. Save first geo's value for non-stackable fields.
392-
if isinstance(field_value, chex.Array) or isinstance(field_value, float):
394+
if isinstance(field_value, (array_typing.Array, array_typing.FloatScalar)):
393395
field_values = [getattr(geo, field_name) for geo in geometries]
394396
stacked_data[field_name] = np.stack(field_values)
395397
else:

torax/_src/geometry/standard_geometry.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from collections.abc import Mapping
2222
import dataclasses
2323
import logging
24-
2524
import chex
2625
import contourpy
2726
from imas import ids_toplevel
2827
import jax
2928
import numpy as np
3029
import scipy
30+
from torax._src import array_typing
3131
from torax._src import constants
3232
from torax._src import interpolated_param
3333
from torax._src.geometry import geometry
@@ -74,14 +74,14 @@ class StandardGeometry(geometry.Geometry):
7474
"""
7575

7676
Ip_from_parameters: bool = dataclasses.field(metadata=dict(static=True))
77-
Ip_profile_face: chex.Array
78-
psi: chex.Array
79-
psi_from_Ip: chex.Array
80-
psi_from_Ip_face: chex.Array
81-
j_total: chex.Array
82-
j_total_face: chex.Array
83-
delta_upper_face: chex.Array
84-
delta_lower_face: chex.Array
77+
Ip_profile_face: array_typing.FloatVectorFace
78+
psi: array_typing.FloatVectorCell
79+
psi_from_Ip: array_typing.FloatVectorCell
80+
psi_from_Ip_face: array_typing.FloatVectorFace
81+
j_total: array_typing.Array
82+
j_total_face: array_typing.FloatVectorFace
83+
delta_upper_face: array_typing.FloatVectorFace
84+
delta_lower_face: array_typing.FloatVectorFace
8585

8686

8787
@jax.tree_util.register_dataclass
@@ -169,28 +169,28 @@ class StandardGeometryIntermediates:
169169

170170
geometry_type: geometry.GeometryType
171171
Ip_from_parameters: bool
172-
R_major: chex.Numeric
173-
a_minor: chex.Numeric
174-
B_0: chex.Numeric
175-
psi: chex.Array
176-
Ip_profile: chex.Array
177-
Phi: chex.Array
178-
R_in: chex.Array
179-
R_out: chex.Array
180-
F: chex.Array
181-
int_dl_over_Bp: chex.Array
182-
flux_surf_avg_1_over_R: chex.Array
183-
flux_surf_avg_1_over_R2: chex.Array
184-
flux_surf_avg_Bp2: chex.Array
185-
flux_surf_avg_RBp: chex.Array
186-
flux_surf_avg_R2Bp2: chex.Array
187-
delta_upper_face: chex.Array
188-
delta_lower_face: chex.Array
189-
elongation: chex.Array
190-
vpr: chex.Array
172+
R_major: array_typing.FloatScalar
173+
a_minor: array_typing.FloatScalar
174+
B_0: array_typing.FloatScalar
175+
psi: array_typing.Array
176+
Ip_profile: array_typing.Array
177+
Phi: array_typing.Array
178+
R_in: array_typing.Array
179+
R_out: array_typing.Array
180+
F: array_typing.Array
181+
int_dl_over_Bp: array_typing.Array
182+
flux_surf_avg_1_over_R: array_typing.Array
183+
flux_surf_avg_1_over_R2: array_typing.Array
184+
flux_surf_avg_Bp2: array_typing.Array
185+
flux_surf_avg_RBp: array_typing.Array
186+
flux_surf_avg_R2Bp2: array_typing.Array
187+
delta_upper_face: array_typing.Array
188+
delta_lower_face: array_typing.Array
189+
elongation: array_typing.Array
190+
vpr: array_typing.Array
191191
n_rho: int
192192
hires_factor: int
193-
z_magnetic_axis: chex.Numeric | None
193+
z_magnetic_axis: array_typing.FloatScalar | None
194194

195195
def __post_init__(self):
196196
"""Extrapolates edge values and smooths near-axis values.
@@ -1079,7 +1079,7 @@ def build_standard_geometry(
10791079

10801080
# fill geometry structure
10811081
# normalized grid
1082-
mesh = torax_pydantic.Grid1D(nx=intermediate.n_rho,)
1082+
mesh = torax_pydantic.Grid1D(nx=intermediate.n_rho)
10831083
rho_b = rho_intermediate[-1] # radius denormalization constant
10841084
# helper variables for mesh cells and faces
10851085
rho_face_norm = mesh.face_centers

0 commit comments

Comments
 (0)