3131import dataclasses
3232
3333import jax
34+ import jax .numpy as jnp
3435from torax ._src import array_typing
36+ from torax ._src import jax_utils
3537from torax ._src import state
3638from torax ._src .config import runtime_params as runtime_params_lib
3739from torax ._src .core_profiles import convertors
3840from torax ._src .core_profiles import getters
3941from torax ._src .fvm import cell_variable
4042from torax ._src .geometry import geometry
4143from torax ._src .neoclassical import neoclassical_models as neoclassical_models_lib
44+ from torax ._src .physics import formulas
4245from torax ._src .physics import psi_calculations
4346from torax ._src .sources import source_models as source_models_lib
4447from torax ._src .sources import source_profile_builders
@@ -68,15 +71,12 @@ def update_core_profiles_during_step(
6871 x_new: The new values of the evolving variables.
6972 runtime_params: The runtime params slice.
7073 geo: Magnetic geometry.
71- core_profiles: The old set of core plasma profiles.
74+ core_profiles: The old set of core plasma profiles for this timestep .
7275 prev_core_profiles: Core plasma profiles from the previous timestep if
7376 available, used to update the energy state.
7477 dt: The size of the last timestep, used to update the energy state.
7578 evolving_names: The names of the evolving variables.
7679 """
77- # Currently unused but will be used to update the energy state soon
78- del prev_core_profiles , dt
79-
8080 updated_core_profiles = convertors .solver_x_tuple_to_core_profiles (
8181 x_new , evolving_names , core_profiles
8282 )
@@ -88,7 +88,7 @@ def update_core_profiles_during_step(
8888 updated_core_profiles .T_e ,
8989 )
9090
91- return dataclasses .replace (
91+ updated_core_profiles = dataclasses .replace (
9292 updated_core_profiles ,
9393 n_i = ions .n_i ,
9494 n_impurity = ions .n_impurity ,
@@ -109,6 +109,24 @@ def update_core_profiles_during_step(
109109 charge_state_info_face = ions .charge_state_info_face ,
110110 )
111111
112+ if prev_core_profiles is not None :
113+ if dt is None :
114+ raise ValueError ('dt must be provided when updating the energy state.' )
115+ energy_state = _update_energy_state (
116+ runtime_params ,
117+ geo ,
118+ updated_core_profiles ,
119+ prev_core_profiles .internal_plasma_energy ,
120+ dt ,
121+ )
122+ else :
123+ energy_state = core_profiles .internal_plasma_energy
124+
125+ return dataclasses .replace (
126+ updated_core_profiles ,
127+ internal_plasma_energy = energy_state ,
128+ )
129+
112130
113131def update_core_and_source_profiles_after_step (
114132 dt : array_typing .FloatScalar ,
@@ -208,6 +226,13 @@ def update_core_and_source_profiles_after_step(
208226 charge_state_info = ions .charge_state_info ,
209227 charge_state_info_face = ions .charge_state_info_face ,
210228 )
229+ energy_state = _update_energy_state (
230+ runtime_params_t_plus_dt ,
231+ geo ,
232+ intermediate_core_profiles ,
233+ core_profiles_t .internal_plasma_energy ,
234+ dt ,
235+ )
211236
212237 conductivity = neoclassical_models .conductivity .calculate_conductivity (
213238 geo , intermediate_core_profiles
@@ -217,6 +242,7 @@ def update_core_and_source_profiles_after_step(
217242 intermediate_core_profiles ,
218243 sigma = conductivity .sigma ,
219244 sigma_face = conductivity .sigma_face ,
245+ internal_plasma_energy = energy_state ,
220246 )
221247
222248 # build_source_profiles calculates the union with explicit + implicit
@@ -316,3 +342,55 @@ def provide_core_profiles_t_plus_dt(
316342 toroidal_angular_velocity = toroidal_angular_velocity ,
317343 )
318344 return core_profiles_t_plus_dt
345+
346+
347+ def _update_energy_state (
348+ runtime_params : runtime_params_lib .RuntimeParams ,
349+ geo : geometry .Geometry ,
350+ core_profiles : state .CoreProfiles ,
351+ prev_energy_state : state .PlasmaInternalEnergy ,
352+ dt : array_typing .FloatScalar ,
353+ ) -> state .PlasmaInternalEnergy :
354+ """Updates the energy state."""
355+ W_thermal_e , W_thermal_i , W_thermal_total = (
356+ formulas .calculate_stored_thermal_energy (
357+ core_profiles .pressure_thermal_e ,
358+ core_profiles .pressure_thermal_i ,
359+ core_profiles .pressure_thermal_total ,
360+ geo ,
361+ )
362+ )
363+ dW_i_dt_raw = (W_thermal_i - prev_energy_state .W_thermal_i ) / dt
364+ dW_e_dt_raw = (W_thermal_e - prev_energy_state .W_thermal_e ) / dt
365+
366+ exponential_smoothing_alpha = jax .lax .cond (
367+ runtime_params .numerics .dW_dt_smoothing_time_scale > 0.0 ,
368+ lambda : jnp .array (1.0 , dtype = jax_utils .get_dtype ())
369+ - jnp .exp (- dt / runtime_params .numerics .dW_dt_smoothing_time_scale ),
370+ lambda : jnp .array (1.0 , dtype = jax_utils .get_dtype ()),
371+ )
372+ dW_i_dt_smoothed = _exponential_smoothing (
373+ dW_i_dt_raw ,
374+ prev_energy_state .dW_thermal_i_dt_smoothed ,
375+ exponential_smoothing_alpha ,
376+ )
377+ dW_e_dt_smoothed = _exponential_smoothing (
378+ dW_e_dt_raw ,
379+ prev_energy_state .dW_thermal_e_dt_smoothed ,
380+ exponential_smoothing_alpha ,
381+ )
382+
383+ return state .PlasmaInternalEnergy (
384+ W_thermal_i = W_thermal_i ,
385+ W_thermal_e = W_thermal_e ,
386+ W_thermal_total = W_thermal_total ,
387+ dW_thermal_i_dt = dW_i_dt_raw ,
388+ dW_thermal_e_dt = dW_e_dt_raw ,
389+ dW_thermal_i_dt_smoothed = dW_i_dt_smoothed ,
390+ dW_thermal_e_dt_smoothed = dW_e_dt_smoothed ,
391+ )
392+
393+
394+ def _exponential_smoothing (new_raw , old_smoothed , alpha ):
395+ """Exponential moving average (EMA)."""
396+ return (1.0 - alpha ) * old_smoothed + alpha * new_raw
0 commit comments