Skip to content

Commit 8c4ae9e

Browse files
jcitrinTorax team
authored andcommitted
Extend return value of charge_states.get_average_charge_state.
Now outputs a dataclass including the previous average charge state (now called "Z_mixture", as well as various intermediate calculated quantities. This is useful for the upcoming extended impurity API, as well as reducing some code duplication elsewhere, like in the Mavrin radiation calculation. PiperOrigin-RevId: 794141009
1 parent 68d7481 commit 8c4ae9e

File tree

6 files changed

+79
-31
lines changed

6 files changed

+79
-31
lines changed

torax/_src/config/tests/plasma_composition_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_ion_mixture_averaging(self, species, time, expected_Z, expected_A):
223223
ion_symbols=tuple(species.keys()), # pytype: disable=attribute-error
224224
ion_mixture=dynamic_mixture,
225225
T_e=np.array(10.0), # Ensure that all ions in test are fully ionized
226-
)
226+
).Z_mixture
227227
np.testing.assert_allclose(calculated_Z, expected_Z)
228228
np.testing.assert_allclose(dynamic_mixture.avg_A, expected_A)
229229

@@ -248,7 +248,7 @@ def test_ion_mixture_override(self, Z_override, A_override, Z, A):
248248
ion_symbols=tuple(mixture.species.keys()),
249249
ion_mixture=dynamic_mixture,
250250
T_e=np.array(1.0), # arbitrary temperature, won't be used for D
251-
)
251+
).Z_mixture
252252
Z_expected = Z if Z_override is None else Z_override
253253
A_expected = A if A_override is None else A_override
254254
np.testing.assert_allclose(calculated_Z, Z_expected)

torax/_src/core_profiles/getters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,23 +293,23 @@ def _get_charge_states(
293293
ion_symbols=static_runtime_params_slice.main_ion_names,
294294
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
295295
T_e=T_e.value,
296-
)
296+
).Z_mixture
297297
Z_i_face = charge_states.get_average_charge_state(
298298
ion_symbols=static_runtime_params_slice.main_ion_names,
299299
ion_mixture=dynamic_runtime_params_slice.plasma_composition.main_ion,
300300
T_e=T_e.face_value(),
301-
)
301+
).Z_mixture
302302

303303
Z_impurity = charge_states.get_average_charge_state(
304304
ion_symbols=static_runtime_params_slice.impurity_names,
305305
ion_mixture=dynamic_runtime_params_slice.plasma_composition.impurity,
306306
T_e=T_e.value,
307-
)
307+
).Z_mixture
308308
Z_impurity_face = charge_states.get_average_charge_state(
309309
ion_symbols=static_runtime_params_slice.impurity_names,
310310
ion_mixture=dynamic_runtime_params_slice.plasma_composition.impurity,
311311
T_e=T_e.face_value(),
312-
)
312+
).Z_mixture
313313

314314
return Z_i, Z_i_face, Z_impurity, Z_impurity_face
315315

torax/_src/core_profiles/updaters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,12 +383,12 @@ def compute_boundary_conditions_for_t_plus_dt(
383383
static_runtime_params_slice.main_ion_names,
384384
ion_mixture=dynamic_runtime_params_slice_t_plus_dt.plasma_composition.main_ion,
385385
T_e=profile_conditions_t_plus_dt.T_e_right_bc,
386-
)
386+
).Z_mixture
387387
Z_impurity_edge = charge_states.get_average_charge_state(
388388
static_runtime_params_slice.impurity_names,
389389
ion_mixture=dynamic_runtime_params_slice_t_plus_dt.plasma_composition.impurity,
390390
T_e=profile_conditions_t_plus_dt.T_e_right_bc,
391-
)
391+
).Z_mixture
392392

393393
dilution_factor_edge = formulas.calculate_main_ion_dilution_factor(
394394
Z_i_edge,

torax/_src/physics/charge_states.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@
1414

1515
"""Routines for calculating impurity charge states."""
1616

17+
import dataclasses
1718
from typing import Final, Mapping, Sequence
1819

1920
import immutabledict
21+
import jax
2022
from jax import numpy as jnp
2123
import numpy as np
2224
from torax._src import array_typing
2325
from torax._src import constants
2426
from torax._src.config import plasma_composition
2527

28+
# pylint: disable=invalid-name
29+
2630
# Polynomial fit coefficients from A. A. Mavrin (2018):
2731
# Improved fits of coronal radiative cooling rates for high-temperature plasmas,
2832
# Radiation Effects and Defects in Solids, 173:5-6, 388-398,
@@ -86,6 +90,33 @@
8690
)
8791

8892

93+
@jax.tree_util.register_dataclass
94+
@dataclasses.dataclass(frozen=True)
95+
class ChargeStateInfo:
96+
"""Container for average charge state calculations.
97+
98+
Attributes:
99+
Z_avg: Average charge of the mixture, weighted by ion fractions. <Z> =
100+
sum(fraction_i * Z_i).
101+
Z2_avg: Average squared charge of the mixture, weighted by ion fractions.
102+
<Z^2> = sum(fraction_i * Z_i^2).
103+
Z_per_species: Charge state for each individual ion species in the mixture.
104+
For impurities, this is the outcome of a temperature dependent charge
105+
state calculation.
106+
Z_mixture: Effective charge of the mixture, defined as <Z^2> / <Z>. This is
107+
the charge used in quasineutrality calculations when treating the mixture
108+
as a single effective species.
109+
"""
110+
111+
Z_avg: array_typing.ArrayFloat
112+
Z2_avg: array_typing.ArrayFloat
113+
Z_per_species: array_typing.ArrayFloat
114+
115+
@property
116+
def Z_mixture(self) -> array_typing.ArrayFloat:
117+
return self.Z2_avg / self.Z_avg
118+
119+
89120
# pylint: disable=invalid-name
90121
def calculate_average_charge_state_single_species(
91122
T_e: array_typing.ArrayFloat,
@@ -135,7 +166,7 @@ def get_average_charge_state(
135166
ion_symbols: Sequence[str],
136167
ion_mixture: plasma_composition.DynamicIonMixture,
137168
T_e: array_typing.ArrayFloat,
138-
) -> array_typing.ArrayFloat:
169+
) -> ChargeStateInfo:
139170
"""Calculates or prescribes average impurity charge state profile (JAX-compatible).
140171
141172
Equations for quasineutrality and Zeff are the following:
@@ -171,18 +202,37 @@ def get_average_charge_state(
171202
face grid, or a single scalar.
172203
173204
Returns:
174-
avg_Z: Average charge state profile [amu].
175-
The shape of avg_Z is the same as T_e.
205+
AverageChargeState: dataclass with average charge state info.
176206
"""
177207

178208
if ion_mixture.Z_override is not None:
179-
return jnp.ones_like(T_e) * ion_mixture.Z_override
180-
181-
avg_Z = jnp.zeros_like(T_e)
182-
avg_Z2 = jnp.zeros_like(T_e)
183-
for ion_symbol, fraction in zip(ion_symbols, ion_mixture.fractions):
184-
Z_species = calculate_average_charge_state_single_species(T_e, ion_symbol)
185-
avg_Z += fraction * Z_species
186-
avg_Z2 += fraction * Z_species**2
209+
override_val = jnp.ones_like(T_e) * ion_mixture.Z_override
210+
return ChargeStateInfo(
211+
Z_avg=override_val,
212+
Z2_avg=override_val**2,
213+
Z_per_species=jnp.stack([override_val for _ in ion_symbols]),
214+
)
187215

188-
return avg_Z2 / avg_Z
216+
Z_per_species = jnp.stack([
217+
calculate_average_charge_state_single_species(T_e, ion_symbol)
218+
for ion_symbol in ion_symbols
219+
])
220+
221+
# ion_mixture.fractions has shape (n_species,).
222+
# Z_per_species has shape (n_species,) if T_e is a scalar, or
223+
# (n_species, n_grid) if T_e is an array.
224+
# We need to broadcast fractions for element-wise multiplication.
225+
# Reshape fractions to be broadcastable with Z_per_species.
226+
fractions_reshaped = jnp.reshape(
227+
ion_mixture.fractions,
228+
ion_mixture.fractions.shape + (1,) * (Z_per_species.ndim - 1),
229+
)
230+
231+
Z_avg = jnp.sum(fractions_reshaped * Z_per_species, axis=0)
232+
Z2_avg = jnp.sum(fractions_reshaped * Z_per_species**2, axis=0)
233+
234+
return ChargeStateInfo(
235+
Z_avg=Z_avg,
236+
Z2_avg=Z2_avg,
237+
Z_per_species=Z_per_species,
238+
)

torax/_src/physics/tests/charge_states_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_get_average_charge_state(
179179
)
180180
Z_calculated = charge_states.get_average_charge_state(
181181
ion_symbols, ion_mixture, T_e
182-
)
182+
).Z_mixture
183183

184184
np.testing.assert_allclose(Z_calculated, expected_Z, rtol=1e-5)
185185

@@ -196,7 +196,7 @@ def test_Z_override_in_get_average_charge_state(self):
196196
)
197197
Z_calculated = charge_states.get_average_charge_state(
198198
ion_symbols, ion_mixture, T_e
199-
)
199+
).Z_mixture
200200
np.testing.assert_allclose(Z_calculated, Z_override)
201201

202202

torax/_src/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,13 @@ def impurity_radiation_mavrin_fit(
225225
# impurity density, not the effective one.
226226

227227
# ion_symbols is a static argument so can use the for loop under jit
228-
Z_per_species = jnp.stack([
229-
charge_states.calculate_average_charge_state_single_species(
230-
core_profiles.T_e.value, ion_symbol
231-
)
232-
for ion_symbol in ion_symbols
233-
])
234-
235-
avg_Z = jnp.sum(ion_mixture.fractions[:, jnp.newaxis] * Z_per_species, axis=0)
236-
impurity_density_scaling = core_profiles.Z_impurity / avg_Z
228+
charge_state_info = charge_states.get_average_charge_state(
229+
ion_symbols=ion_symbols,
230+
ion_mixture=ion_mixture,
231+
T_e=core_profiles.T_e.value,
232+
)
233+
Z_avg = charge_state_info.Z_avg
234+
impurity_density_scaling = core_profiles.Z_impurity / Z_avg
237235

238236
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
239237
source_name

0 commit comments

Comments
 (0)