Skip to content

Commit ff60422

Browse files
jcitrinTorax team
authored andcommitted
Centralize sign convention enforcement in StandardGeometryIntermediates.
This change removes np.abs() calls from geometry loading functions (fbt.py, equilibrium.py) and instead enforces sign conventions for psi and Ip_profile within the StandardGeometryIntermediates.__post_init__ method. This ensures that psi increases with radius and Ip_profile is positive, regardless of the input data's sign conventions. Tests are added to verify the sign flipping behavior. This also fixes a bug where any initial psi from geo that was not strictly positive or strictly negative would get a "kink" due to the abs. We now ensure correct trend of psi, and do not enforce an abs PiperOrigin-RevId: 883157580
1 parent 6620f90 commit ff60422

File tree

5 files changed

+131
-14
lines changed

5 files changed

+131
-14
lines changed

torax/_src/geometry/fbt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -505,22 +505,22 @@ def _get_val(key):
505505
B_0=B_0,
506506
psi=psi,
507507
Phi=Phi,
508-
Ip_profile=np.abs(LY['ItQ']),
508+
Ip_profile=LY['ItQ'],
509509
R_in=LY['rgeom'] - LY['aminor'],
510510
R_out=LY['rgeom'] + LY['aminor'],
511-
F=np.abs(LY['TQ']),
511+
F=LY['TQ'],
512512
int_dl_over_Bp=1 / LY_Q1Q,
513513
flux_surf_avg_1_over_R=LY['Q0Q'],
514514
flux_surf_avg_1_over_R2=LY['Q2Q'],
515-
flux_surf_avg_grad_psi2_over_R2=np.abs(LY['Q3Q']),
516-
flux_surf_avg_grad_psi=np.abs(LY['Q5Q']),
517-
flux_surf_avg_grad_psi2=np.abs(LY['Q4Q']),
515+
flux_surf_avg_grad_psi2_over_R2=LY['Q3Q'],
516+
flux_surf_avg_grad_psi=LY['Q5Q'],
517+
flux_surf_avg_grad_psi2=LY['Q4Q'],
518518
flux_surf_avg_B2=flux_surf_avg_B2,
519519
flux_surf_avg_1_over_B2=flux_surf_avg_1_over_B2,
520520
delta_upper_face=LY['deltau'],
521521
delta_lower_face=LY['deltal'],
522522
elongation=LY['kappa'],
523-
vpr=4 * np.pi * Phi[-1] * rhon / (np.abs(LY['TQ']) * LY['Q2Q']),
523+
vpr=4 * np.pi * Phi[-1] * rhon / (LY['TQ'] * LY['Q2Q']),
524524
face_centers=face_centers,
525525
hires_factor=hires_factor,
526526
diverted=diverted,

torax/_src/geometry/standard_geometry.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,44 @@ class StandardGeometryIntermediates:
251251
B_pol_OMP: array_typing.FloatScalar | None
252252

253253
def __post_init__(self):
254-
"""Extrapolates edge values and smooths near-axis values.
255-
254+
"""Enforces sign conventions, extrapolates edge, and smooths near-axis.
255+
256+
Sign conventions - TODO(b/335204606): Replace with proper COCOS handling:
257+
- psi must grow with radius, i.e. psi(axis) < psi(separatrix).
258+
- Total Ip must be positive.
259+
- Positive-definite quantities are enforced via abs(): Phi, F,
260+
int_dl_over_Bp, vpr, flux_surf_avg_grad_psi, flux_surf_avg_grad_psi2,
261+
flux_surf_avg_grad_psi2_over_R2.
262+
Then:
256263
- Edge extrapolation for a subset of attributes based on a Cubic spline fit.
257264
- Near-axis smoothing for a subset of attributes based on a Savitzky-Golay
258265
filter with an appropriate polynominal order based on the attribute.
259266
"""
260267

268+
if self.psi[-1] < self.psi[0]:
269+
object.__setattr__(self, 'psi', -self.psi)
270+
271+
if self.Ip_profile[-1] < 0:
272+
object.__setattr__(self, 'Ip_profile', -self.Ip_profile)
273+
274+
object.__setattr__(self, 'Phi', np.abs(self.Phi))
275+
object.__setattr__(self, 'F', np.abs(self.F))
276+
object.__setattr__(
277+
self, 'int_dl_over_Bp', np.abs(self.int_dl_over_Bp)
278+
)
279+
object.__setattr__(self, 'vpr', np.abs(self.vpr))
280+
object.__setattr__(
281+
self, 'flux_surf_avg_grad_psi', np.abs(self.flux_surf_avg_grad_psi)
282+
)
283+
object.__setattr__(
284+
self, 'flux_surf_avg_grad_psi2', np.abs(self.flux_surf_avg_grad_psi2)
285+
)
286+
object.__setattr__(
287+
self,
288+
'flux_surf_avg_grad_psi2_over_R2',
289+
np.abs(self.flux_surf_avg_grad_psi2_over_R2),
290+
)
291+
261292
# Check if last flux surface is diverted and correct via spline fit if so
262293
if self.diverted or self.flux_surf_avg_grad_psi2_over_R2[-1] < 1e-10:
263294
# Calculate rhon

torax/_src/geometry/tests/standard_geometry_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,85 @@ def test_stack_geometries_standard_geometries(self):
146146
stacked_geo.g1_over_vpr2_face[:, 0], 1 / stacked_geo.rho_b**2
147147
)
148148

149+
def _make_intermediates(self, **overrides):
150+
defaults = dict(
151+
geometry_type=geometry.GeometryType.FBT,
152+
Ip_from_parameters=True,
153+
R_major=6.2,
154+
a_minor=2.0,
155+
B_0=5.3,
156+
psi=np.linspace(0, 1.0, 100),
157+
Ip_profile=np.linspace(0, 1e6, 100),
158+
Phi=np.linspace(0, 1.0, 100),
159+
R_in=np.linspace(4.0, 4.2, 100),
160+
R_out=np.linspace(8.0, 8.4, 100),
161+
F=np.linspace(30.0, 33.0, 100),
162+
int_dl_over_Bp=np.linspace(0.01, 1.0, 100),
163+
flux_surf_avg_1_over_R=np.linspace(0.1, 0.2, 100),
164+
flux_surf_avg_1_over_R2=np.linspace(0.01, 0.04, 100),
165+
flux_surf_avg_grad_psi2=np.linspace(0.01, 1.0, 100),
166+
flux_surf_avg_grad_psi=np.linspace(0.01, 1.0, 100),
167+
flux_surf_avg_grad_psi2_over_R2=np.linspace(0.01, 1.0, 100),
168+
flux_surf_avg_B2=np.linspace(25.0, 30.0, 100),
169+
flux_surf_avg_1_over_B2=np.linspace(0.03, 0.04, 100),
170+
delta_upper_face=np.linspace(0.0, 0.3, 100),
171+
delta_lower_face=np.linspace(0.0, 0.3, 100),
172+
elongation=np.linspace(1.0, 1.7, 100),
173+
vpr=np.linspace(0.01, 1.0, 100),
174+
face_centers=interpolated_param_2d.get_face_centers(25),
175+
hires_factor=4,
176+
z_magnetic_axis=np.array(0.0),
177+
diverted=None,
178+
connection_length_target=None,
179+
connection_length_divertor=None,
180+
angle_of_incidence_target=None,
181+
R_OMP=None,
182+
R_target=None,
183+
B_pol_OMP=None,
184+
)
185+
defaults.update(overrides)
186+
return standard_geometry.StandardGeometryIntermediates(**defaults)
187+
188+
def test_post_init_flips_psi_when_decreasing(self):
189+
psi_decreasing = np.linspace(1.0, 0.0, 100)
190+
intermediates = self._make_intermediates(psi=psi_decreasing)
191+
self.assertGreater(intermediates.psi[-1], intermediates.psi[0])
192+
193+
def test_post_init_preserves_psi_when_increasing(self):
194+
psi_increasing = np.linspace(0.0, 1.0, 100)
195+
intermediates = self._make_intermediates(psi=psi_increasing)
196+
np.testing.assert_array_equal(intermediates.psi, psi_increasing)
197+
198+
def test_post_init_flips_negative_Ip(self):
199+
Ip_negative = np.linspace(0, -1e6, 100)
200+
intermediates = self._make_intermediates(Ip_profile=Ip_negative)
201+
self.assertGreater(intermediates.Ip_profile[-1], 0)
202+
203+
def test_post_init_preserves_positive_Ip(self):
204+
Ip_positive = np.linspace(0, 1e6, 100)
205+
intermediates = self._make_intermediates(Ip_profile=Ip_positive)
206+
np.testing.assert_array_equal(intermediates.Ip_profile, Ip_positive)
207+
208+
def test_post_init_enforces_positive_definite_quantities(self):
209+
intermediates = self._make_intermediates(
210+
Phi=-np.linspace(0, 1.0, 100),
211+
F=-np.linspace(30.0, 33.0, 100),
212+
int_dl_over_Bp=-np.linspace(0.01, 1.0, 100),
213+
vpr=-np.linspace(0.01, 1.0, 100),
214+
flux_surf_avg_grad_psi=-np.linspace(0.01, 1.0, 100),
215+
flux_surf_avg_grad_psi2=-np.linspace(0.01, 1.0, 100),
216+
flux_surf_avg_grad_psi2_over_R2=-np.linspace(0.01, 1.0, 100),
217+
)
218+
np.testing.assert_array_less(-1e-15, intermediates.Phi)
219+
np.testing.assert_array_less(0, intermediates.F)
220+
np.testing.assert_array_less(0, intermediates.int_dl_over_Bp)
221+
np.testing.assert_array_less(0, intermediates.vpr)
222+
np.testing.assert_array_less(0, intermediates.flux_surf_avg_grad_psi)
223+
np.testing.assert_array_less(0, intermediates.flux_surf_avg_grad_psi2)
224+
np.testing.assert_array_less(
225+
0, intermediates.flux_surf_avg_grad_psi2_over_R2
226+
)
227+
149228

150229
if __name__ == '__main__':
151230
absltest.main()

torax/_src/imas_tools/input/equilibrium.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,33 +118,38 @@ def geometry_from_IMAS(
118118
f"{len(equilibrium.time_slice)} time slices"
119119
)
120120
IMAS_data = equilibrium.time_slice[slice_index]
121+
# IMAS python API returns custom primitive types (e.g. IDSFloat0D,
122+
# IDSNumericArray) instead of standard python floats or numpy arrays. We must
123+
# cast these to standard numpy arrays using np.asarray() (or via implicit
124+
# numpy operations like np.abs), otherwise JAX JIT compilation will fail with
125+
# TypeErrors during PyTree tracing.
121126
R_major = np.asarray(equilibrium.vacuum_toroidal_field.r0)
122-
B_0 = np.asarray(np.abs(equilibrium.vacuum_toroidal_field.b0[0]))
127+
B_0 = np.abs(equilibrium.vacuum_toroidal_field.b0[0])
123128

124129
# Poloidal flux.
125-
psi = np.abs(IMAS_data.profiles_1d.psi)
130+
psi = np.asarray(IMAS_data.profiles_1d.psi)
126131

127132
# Toroidal flux.
128-
phi = np.abs(IMAS_data.profiles_1d.phi)
133+
phi = np.asarray(IMAS_data.profiles_1d.phi)
129134

130135
# Midplane radii.
131136
R_in = IMAS_data.profiles_1d.r_inboard
132137
R_out = IMAS_data.profiles_1d.r_outboard
133138
R_major_profile = (R_in + R_out) / 2.0
134139
# toroidal field flux function
135-
F = np.abs(IMAS_data.profiles_1d.f)
140+
F = np.asarray(IMAS_data.profiles_1d.f)
136141

137142
# Flux surface integrals of various geometry quantities.
138143
# IDS Contour integrals.
139144
if IMAS_data.profiles_1d.dvolume_dpsi:
140-
dvoldpsi = np.abs(IMAS_data.profiles_1d.dvolume_dpsi)
145+
dvoldpsi = np.asarray(IMAS_data.profiles_1d.dvolume_dpsi)
141146
else:
142147
dvoldpsi = np.gradient(
143148
IMAS_data.profiles_1d.volume, IMAS_data.profiles_1d.psi
144149
)
145150
# dpsi_drho_tor
146151
if IMAS_data.profiles_1d.dpsi_drho_tor:
147-
dpsidrhotor = np.abs(IMAS_data.profiles_1d.dpsi_drho_tor)
152+
dpsidrhotor = np.asarray(IMAS_data.profiles_1d.dpsi_drho_tor)
148153
else:
149154
rho_tor = IMAS_data.profiles_1d.rho_tor
150155
if not rho_tor:

torax/tests/sim_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class SimTest(sim_test_case.SimTestCase):
296296
(
297297
'test_iterhybrid_lh_transition',
298298
'test_iterhybrid_lh_transition.py',
299+
_ALL_PROFILES,
300+
1e-7,
299301
),
300302
# Tests used for testing changing configs without recompiling.
301303
# Based on test_iterhybrid_predictor_corrector

0 commit comments

Comments
 (0)