Skip to content

Commit 3f83f89

Browse files
tamaranormanTorax team
authored andcommitted
Add a data structure for the energy state info to core profiles
And update/initialise this as expected PiperOrigin-RevId: 876290487
1 parent 40ec3c8 commit 3f83f89

File tree

5 files changed

+169
-5
lines changed

5 files changed

+169
-5
lines changed

torax/_src/core_profiles/initialization.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.geometry import standard_geometry
3333
from torax._src.neoclassical import neoclassical_models as neoclassical_models_lib
3434
from torax._src.neoclassical.bootstrap_current import base as bootstrap_current_base
35+
from torax._src.physics import formulas
3536
from torax._src.physics import psi_calculations
3637
from torax._src.sources import source_models as source_models_lib
3738
from torax._src.sources import source_profile_builders
@@ -129,6 +130,13 @@ def initial_core_profiles(
129130
charge_state_info_face=ions.charge_state_info_face,
130131
)
131132

133+
# TODO(b/398816463): Clean up energy state initialization as part of V2
134+
# core profiles refactor.
135+
core_profiles = dataclasses.replace(
136+
core_profiles,
137+
internal_plasma_energy=_initialise_internal_energy(core_profiles, geo),
138+
)
139+
132140
return _init_psi_and_psi_derived(
133141
runtime_params,
134142
geo,
@@ -138,6 +146,30 @@ def initial_core_profiles(
138146
)
139147

140148

149+
def _initialise_internal_energy(
150+
core_profiles: state.CoreProfiles,
151+
geo: geometry.Geometry,
152+
) -> state.PlasmaInternalEnergy:
153+
"""Initializes the energy stored in the plasma."""
154+
W_thermal_e, W_thermal_i, W_thermal_total = (
155+
formulas.calculate_stored_thermal_energy(
156+
core_profiles.pressure_thermal_e,
157+
core_profiles.pressure_thermal_i,
158+
core_profiles.pressure_thermal_total,
159+
geo,
160+
)
161+
)
162+
return state.PlasmaInternalEnergy(
163+
W_thermal_i=W_thermal_i,
164+
W_thermal_e=W_thermal_e,
165+
W_thermal_total=W_thermal_total,
166+
dW_thermal_i_dt=jnp.array(0.0, dtype=jax_utils.get_dtype()),
167+
dW_thermal_e_dt=jnp.array(0.0, dtype=jax_utils.get_dtype()),
168+
dW_thermal_i_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
169+
dW_thermal_e_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
170+
)
171+
172+
141173
def update_psi_from_j(
142174
Ip: array_typing.FloatScalar,
143175
geo: geometry.Geometry,

torax/_src/core_profiles/updaters.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@
3131
import dataclasses
3232

3333
import jax
34+
import jax.numpy as jnp
3435
from torax._src import array_typing
36+
from torax._src import jax_utils
3537
from torax._src import state
3638
from torax._src.config import runtime_params as runtime_params_lib
3739
from torax._src.core_profiles import convertors
3840
from torax._src.core_profiles import getters
3941
from torax._src.fvm import cell_variable
4042
from torax._src.geometry import geometry
4143
from torax._src.neoclassical import neoclassical_models as neoclassical_models_lib
44+
from torax._src.physics import formulas
4245
from torax._src.physics import psi_calculations
4346
from torax._src.sources import source_models as source_models_lib
4447
from torax._src.sources import source_profile_builders
@@ -68,15 +71,12 @@ def update_core_profiles_during_step(
6871
x_new: The new values of the evolving variables.
6972
runtime_params: The runtime params slice.
7073
geo: Magnetic geometry.
71-
core_profiles: The old set of core plasma profiles.
74+
core_profiles: The old set of core plasma profiles for this timestep.
7275
prev_core_profiles: Core plasma profiles from the previous timestep if
7376
available, used to update the energy state.
7477
dt: The size of the last timestep, used to update the energy state.
7578
evolving_names: The names of the evolving variables.
7679
"""
77-
# Currently unused but will be used to update the energy state soon
78-
del prev_core_profiles, dt
79-
8080
updated_core_profiles = convertors.solver_x_tuple_to_core_profiles(
8181
x_new, evolving_names, core_profiles
8282
)
@@ -88,7 +88,7 @@ def update_core_profiles_during_step(
8888
updated_core_profiles.T_e,
8989
)
9090

91-
return dataclasses.replace(
91+
updated_core_profiles = dataclasses.replace(
9292
updated_core_profiles,
9393
n_i=ions.n_i,
9494
n_impurity=ions.n_impurity,
@@ -109,6 +109,24 @@ def update_core_profiles_during_step(
109109
charge_state_info_face=ions.charge_state_info_face,
110110
)
111111

112+
if prev_core_profiles is not None:
113+
if dt is None:
114+
raise ValueError('dt must be provided when updating the energy state.')
115+
energy_state = _update_energy_state(
116+
runtime_params,
117+
geo,
118+
updated_core_profiles,
119+
prev_core_profiles.internal_plasma_energy,
120+
dt,
121+
)
122+
else:
123+
energy_state = core_profiles.internal_plasma_energy
124+
125+
return dataclasses.replace(
126+
updated_core_profiles,
127+
internal_plasma_energy=energy_state,
128+
)
129+
112130

113131
def update_core_and_source_profiles_after_step(
114132
dt: array_typing.FloatScalar,
@@ -208,6 +226,13 @@ def update_core_and_source_profiles_after_step(
208226
charge_state_info=ions.charge_state_info,
209227
charge_state_info_face=ions.charge_state_info_face,
210228
)
229+
energy_state = _update_energy_state(
230+
runtime_params_t_plus_dt,
231+
geo,
232+
intermediate_core_profiles,
233+
core_profiles_t.internal_plasma_energy,
234+
dt,
235+
)
211236

212237
conductivity = neoclassical_models.conductivity.calculate_conductivity(
213238
geo, intermediate_core_profiles
@@ -217,6 +242,7 @@ def update_core_and_source_profiles_after_step(
217242
intermediate_core_profiles,
218243
sigma=conductivity.sigma,
219244
sigma_face=conductivity.sigma_face,
245+
internal_plasma_energy=energy_state,
220246
)
221247

222248
# build_source_profiles calculates the union with explicit + implicit
@@ -316,3 +342,55 @@ def provide_core_profiles_t_plus_dt(
316342
toroidal_angular_velocity=toroidal_angular_velocity,
317343
)
318344
return core_profiles_t_plus_dt
345+
346+
347+
def _update_energy_state(
348+
runtime_params: runtime_params_lib.RuntimeParams,
349+
geo: geometry.Geometry,
350+
core_profiles: state.CoreProfiles,
351+
prev_energy_state: state.PlasmaInternalEnergy,
352+
dt: array_typing.FloatScalar,
353+
) -> state.PlasmaInternalEnergy:
354+
"""Updates the energy state."""
355+
W_thermal_e, W_thermal_i, W_thermal_total = (
356+
formulas.calculate_stored_thermal_energy(
357+
core_profiles.pressure_thermal_e,
358+
core_profiles.pressure_thermal_i,
359+
core_profiles.pressure_thermal_total,
360+
geo,
361+
)
362+
)
363+
dW_i_dt_raw = (W_thermal_i - prev_energy_state.W_thermal_i) / dt
364+
dW_e_dt_raw = (W_thermal_e - prev_energy_state.W_thermal_e) / dt
365+
366+
exponential_smoothing_alpha = jax.lax.cond(
367+
runtime_params.numerics.dW_dt_smoothing_time_scale > 0.0,
368+
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype())
369+
- jnp.exp(-dt / runtime_params.numerics.dW_dt_smoothing_time_scale),
370+
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype()),
371+
)
372+
dW_i_dt_smoothed = _exponential_smoothing(
373+
dW_i_dt_raw,
374+
prev_energy_state.dW_thermal_i_dt_smoothed,
375+
exponential_smoothing_alpha,
376+
)
377+
dW_e_dt_smoothed = _exponential_smoothing(
378+
dW_e_dt_raw,
379+
prev_energy_state.dW_thermal_e_dt_smoothed,
380+
exponential_smoothing_alpha,
381+
)
382+
383+
return state.PlasmaInternalEnergy(
384+
W_thermal_i=W_thermal_i,
385+
W_thermal_e=W_thermal_e,
386+
W_thermal_total=W_thermal_total,
387+
dW_thermal_i_dt=dW_i_dt_raw,
388+
dW_thermal_e_dt=dW_e_dt_raw,
389+
dW_thermal_i_dt_smoothed=dW_i_dt_smoothed,
390+
dW_thermal_e_dt_smoothed=dW_e_dt_smoothed,
391+
)
392+
393+
394+
def _exponential_smoothing(new_raw, old_smoothed, alpha):
395+
"""Exponential moving average (EMA)."""
396+
return (1.0 - alpha) * old_smoothed + alpha * new_raw

torax/_src/orchestration/initial_state.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,20 @@ def get_initial_state_and_post_processed_outputs_from_file(
259259
'dW_thermal_e_dt_smoothed'
260260
].to_numpy(),
261261
)
262+
energy_state = initial_state.core_profiles.internal_plasma_energy
263+
energy_state = dataclasses.replace(
264+
energy_state,
265+
dW_thermal_i_dt_smoothed=scalars_dataset.data_vars[
266+
'dW_thermal_i_dt_smoothed'
267+
].to_numpy(),
268+
dW_thermal_e_dt_smoothed=scalars_dataset.data_vars[
269+
'dW_thermal_e_dt_smoothed'
270+
].to_numpy(),
271+
)
262272
core_profiles = dataclasses.replace(
263273
initial_state.core_profiles,
264274
v_loop_lcfs=scalars_dataset.v_loop_lcfs.values,
275+
internal_plasma_energy=energy_state,
265276
)
266277
numerics_dataset = data_tree.children[output.NUMERICS].dataset
267278
numerics_dataset = numerics_dataset.squeeze()

torax/_src/output_tools/output.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,11 @@ def _save_core_profiles(
581581
if attr_name == "main_ion_fractions":
582582
continue
583583

584+
# Skip internal_plasma_energy as it is in post_processed_outputs.
585+
# TODO(b/434175938): Remove once we move to V2.
586+
if attr_name == "internal_plasma_energy":
587+
continue
588+
584589
attr_value = getattr(stacked_core_profiles, attr_name)
585590

586591
output_key = output_name_map.get(attr_name, attr_name)

torax/_src/state.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,41 @@
3232

3333

3434
# pylint: disable=invalid-name
35+
@jax.tree_util.register_dataclass
36+
@dataclasses.dataclass(frozen=True)
37+
class PlasmaInternalEnergy:
38+
"""Internal energy stored in the plasma and its time derivatives.
39+
40+
Attributes:
41+
W_thermal_i: Ion thermal stored energy [J]
42+
W_thermal_e: Electron thermal stored energy [J]
43+
W_thermal_total: Total thermal stored energy [J]
44+
dW_thermal_i_dt: Time derivative of ion thermal stored energy [W]
45+
dW_thermal_e_dt: Time derivative of electron thermal stored energy [W]
46+
dW_thermal_i_dt_smoothed: Smoothed time derivative of ion thermal stored
47+
energy [W]
48+
dW_thermal_e_dt_smoothed: Smoothed time derivative of electron thermal
49+
stored energy [W]
50+
"""
51+
W_thermal_i: array_typing.FloatScalar
52+
W_thermal_e: array_typing.FloatScalar
53+
W_thermal_total: array_typing.FloatScalar
54+
dW_thermal_i_dt: array_typing.FloatScalar
55+
dW_thermal_e_dt: array_typing.FloatScalar
56+
dW_thermal_i_dt_smoothed: array_typing.FloatScalar
57+
dW_thermal_e_dt_smoothed: array_typing.FloatScalar
58+
59+
@property
60+
def dW_thermal_dt(self) -> array_typing.FloatScalar:
61+
"""Total thermal stored energy time derivative [W], raw unsmoothed."""
62+
return self.dW_thermal_i_dt + self.dW_thermal_e_dt
63+
64+
@property
65+
def dW_thermal_dt_smoothed(self) -> array_typing.FloatScalar:
66+
"""Smoothed total thermal stored energy time derivative [W]."""
67+
return self.dW_thermal_i_dt_smoothed + self.dW_thermal_e_dt_smoothed
68+
69+
3570
@jax.tree_util.register_dataclass
3671
@dataclasses.dataclass(frozen=True, eq=False)
3772
class CoreProfiles:
@@ -78,6 +113,8 @@ class CoreProfiles:
78113
state information. See `charge_states.ChargeStateInfo`. Cell grid.
79114
charge_state_info_face: Container with averaged and per-species ion charge
80115
state information. See `charge_states.ChargeStateInfo`. Face grid.
116+
internal_plasma_energy: Container with energy variables. See
117+
`PlasmaInternalEnergy`.
81118
"""
82119

83120
T_i: cell_variable.CellVariable
@@ -109,6 +146,7 @@ class CoreProfiles:
109146
toroidal_angular_velocity: cell_variable.CellVariable
110147
charge_state_info: charge_states.ChargeStateInfo
111148
charge_state_info_face: charge_states.ChargeStateInfo
149+
internal_plasma_energy: PlasmaInternalEnergy | None = None
112150

113151
@functools.cached_property
114152
def impurity_density_scaling(self) -> jax.Array:

0 commit comments

Comments
 (0)