Skip to content

Commit 2923d37

Browse files
jcitrinTorax team
authored andcommitted
Improve j_total calculation to avoid numerical artifacts.
Calculate j_total on the cell grid from finite differences of Ip_profile face. This significantly improves smoothness particularly towards the boundaries. Minor differences in some integration test cases on order of O(1e-3) due to different ohmic heating. PiperOrigin-RevId: 852211436
1 parent 8513344 commit 2923d37

File tree

51 files changed

+387
-391
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+387
-391
lines changed

torax/_src/mhd/sawtooth/tests/sawtooth_model_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def test_no_subsequent_sawtooth_crashes(self):
267267
])
268268

269269
_POST_CRASH_PSI = np.array([
270-
8.245389,
271-
9.864265,
272-
12.989683,
273-
17.52163,
274-
23.362746,
275-
30.278108,
276-
37.587465,
277-
44.522205,
278-
50.597804,
270+
8.882378,
271+
10.483339,
272+
13.574243,
273+
18.056257,
274+
23.83316,
275+
30.671823,
276+
37.89592,
277+
44.739235,
278+
50.713549,
279279
55.729866,
280280
])
281281

torax/_src/physics/psi_calculations.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
- j_toroidal_to_j_parallel: Calculates <j.B>/B0 from j_toroidal = dI/dS.
3030
- j_parallel_to_j_toroidal: Calculates j_toroidal = dI/dS from <j.B>/B0.
3131
"""
32+
3233
from typing import Final
3334

3435
import jax
@@ -168,30 +169,25 @@ def calc_j_total(
168169
/ (16 * jnp.pi**3 * constants.CONSTANTS.mu_0)
169170
)
170171

171-
Ip_profile = (
172-
psi.grad()
173-
* geo.g2g3_over_rhon
174-
* geo.F
175-
/ geo.Phi_b
176-
/ (16 * jnp.pi**3 * constants.CONSTANTS.mu_0)
177-
)
178-
179-
dI_drhon_face = jnp.gradient(Ip_profile_face, geo.rho_face_norm)
180-
dI_drhon = jnp.gradient(Ip_profile, geo.rho_norm)
172+
# Calculate dI/drhon on the cell grid using finite difference of the face
173+
# values. This ensures that the current density is consistent with the
174+
# enclosed current at the boundaries of the cells.
175+
dI_drhon = jnp.diff(Ip_profile_face) / geo.drho_norm
181176

182177
j_total = dI_drhon / geo.spr
183-
# Note: On-axis face values will be overwritten by extrapolation below, but we
184-
# need to avoid division by zero
185-
j_total_face_bulk = dI_drhon_face[1:] / geo.spr_face[1:]
186-
j_total_face = jnp.concatenate([jnp.array([j_total[0]]), j_total_face_bulk])
187178

188-
# Extrapolate the axis values to avoid numerical artifacts
179+
# Extrapolate the axis values to avoid numerical artifacts.
189180
j_total = _extrapolate_cell_profile_to_axis(j_total, geo)
190-
j_total_face = _extrapolate_face_profile_to_axis(
191-
j_total_face,
181+
182+
# Convert to face grid using linear interpolation in the bulk, and set right
183+
# edge while preserving total current. This provides a smoother profile than
184+
# calculating gradients directly on faces.
185+
j_total_face = math_utils.cell_to_face(
192186
j_total,
193187
geo,
188+
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
194189
)
190+
195191
return j_total, j_total_face, Ip_profile_face
196192

197193

0 commit comments

Comments
 (0)