Skip to content

Commit 8454b96

Browse files
jcitrinTorax team
authored andcommitted
Add dW/dt terms to P_SOL and P_loss calculations.
Introduces a dW_dt_smoothing_time_scale parameter in the numerics config to control an exponential moving average for dW/dt. The smoothed dW/dt is now used in the calculation of P_SOL terms and the energy confinement time (tauE). The raw dW/dt is still available as a separate output, for backwards compatibility and also helping with judging the impact of the smoothing. Sim tests regenerated due to modified post processed outputs. P_SOL and P_loss terms with dW/dt verified against RAPTOR. PiperOrigin-RevId: 852692319
1 parent f3a791e commit 8454b96

File tree

51 files changed

+137
-44
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+137
-44
lines changed

torax/_src/config/numerics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class RuntimeParams:
4545
resistivity_multiplier: array_typing.FloatScalar
4646
adaptive_T_source_prefactor: float
4747
adaptive_n_source_prefactor: float
48+
dW_dt_smoothing_time_scale: float
4849
evolve_ion_heat: bool = dataclasses.field(metadata={'static': True})
4950
evolve_electron_heat: bool = dataclasses.field(metadata={'static': True})
5051
evolve_current: bool = dataclasses.field(metadata={'static': True})
@@ -102,6 +103,9 @@ class Numerics(torax_pydantic.BaseModelFrozen):
102103
temperature internal boundary conditions.
103104
adaptive_n_source_prefactor: Prefactor for adaptive source term for setting
104105
density internal boundary conditions.
106+
dW_dt_smoothing_time_scale: Time scale [s] for the exponential moving
107+
average smoothing of dW/dt terms used in P_SOL and confinement time
108+
calculations. If 0.0, no smoothing is applied and raw dW/dt is used.
105109
"""
106110

107111
t_initial: torax_pydantic.Second = 0.0
@@ -125,6 +129,7 @@ class Numerics(torax_pydantic.BaseModelFrozen):
125129
)
126130
adaptive_T_source_prefactor: pydantic.PositiveFloat = 2.0e10
127131
adaptive_n_source_prefactor: pydantic.PositiveFloat = 2.0e8
132+
dW_dt_smoothing_time_scale: pydantic.NonNegativeFloat = 0.3
128133

129134
T_minimum_eV: pydantic.PositiveFloat = 5.0
130135

@@ -170,6 +175,7 @@ def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams:
170175
resistivity_multiplier=self.resistivity_multiplier.get_value(t),
171176
adaptive_T_source_prefactor=self.adaptive_T_source_prefactor,
172177
adaptive_n_source_prefactor=self.adaptive_n_source_prefactor,
178+
dW_dt_smoothing_time_scale=self.dW_dt_smoothing_time_scale,
173179
evolve_ion_heat=self.evolve_ion_heat,
174180
evolve_electron_heat=self.evolve_electron_heat,
175181
evolve_current=self.evolve_current,

torax/_src/output_tools/post_processing.py

Lines changed: 122 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ class PostProcessedOutputs:
6666
FFprime: FF' on the face grid, where F is the toroidal flux function
6767
psi_norm: Normalized poloidal flux on the face grid [Wb]
6868
P_SOL_i: Total ion heating power exiting the plasma with all sources:
69-
auxiliary heating + ion-electron exchange + fusion [W]
69+
auxiliary heating + ion-electron exchange + fusion [W]. Includes smoothed
70+
dW/dt correction.
7071
P_SOL_e: Total electron heating power exiting the plasma with all sources
7172
and sinks: auxiliary heating + ion-electron exchange + Ohmic + fusion +
72-
radiation sinks [W]
73+
radiation sinks [W]. Includes smoothed dW/dt correction.
7374
P_SOL_total: Total heating power exiting the plasma with all sources and
74-
sinks
75+
sinks. Includes smoothed dW/dt correction.
7576
P_aux_i: Total auxiliary ion heating power [W]
7677
P_aux_e: Total auxiliary electron heating power [W]
7778
P_aux_total: Total auxiliary heating power [W]
@@ -115,7 +116,6 @@ class PostProcessedOutputs:
115116
T_e_volume_avg: Volume average electron temperature [keV]
116117
T_i_volume_avg: Volume average ion temperature [keV]
117118
n_e_volume_avg: Volume average electron density [m^-3]
118-
n_e_volume_avg: Volume average electron density [m^-3]
119119
n_i_volume_avg: Volume average main ion density [m^-3]
120120
n_e_line_avg: Line averaged electron density [m^-3]
121121
n_i_line_avg: Line averaged main ion density [m^-3]
@@ -126,7 +126,14 @@ class PostProcessedOutputs:
126126
q95: q at 95% of the normalized poloidal flux
127127
W_pol: Total magnetic energy [J]
128128
li3: Normalized plasma internal inductance, ITER convention [dimensionless]
129-
dW_thermal_dt: Time derivative of the total stored thermal energy [W]
129+
dW_thermal_dt: Time derivative of the total stored thermal energy [W], raw
130+
unsmoothed value.
131+
dW_thermal_dt_smoothed: Smoothed time derivative of total stored thermal
132+
energy [W].
133+
dW_thermal_i_dt_smoothed: Smoothed time derivative of ion stored thermal
134+
energy [W].
135+
dW_thermal_e_dt_smoothed: Smoothed time derivative of electron stored
136+
thermal energy [W].
130137
q_min: Minimum q value
131138
rho_q_min: rho_norm at the minimum q
132139
rho_q_3_2_first: First outermost rho_norm value that intercepts the q=3/2
@@ -235,6 +242,9 @@ class PostProcessedOutputs:
235242
W_pol: array_typing.FloatScalar
236243
li3: array_typing.FloatScalar
237244
dW_thermal_dt: array_typing.FloatScalar
245+
dW_thermal_dt_smoothed: array_typing.FloatScalar
246+
dW_thermal_i_dt_smoothed: array_typing.FloatScalar
247+
dW_thermal_e_dt_smoothed: array_typing.FloatScalar
238248
rho_q_min: array_typing.FloatScalar
239249
q_min: array_typing.FloatScalar
240250
rho_q_3_2_first: array_typing.FloatScalar
@@ -332,6 +342,9 @@ def zeros(cls, geo: geometry.Geometry) -> typing_extensions.Self:
332342
W_pol=jnp.array(0.0, dtype=jax_utils.get_dtype()),
333343
li3=jnp.array(0.0, dtype=jax_utils.get_dtype()),
334344
dW_thermal_dt=jnp.array(0.0, dtype=jax_utils.get_dtype()),
345+
dW_thermal_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
346+
dW_thermal_i_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
347+
dW_thermal_e_dt_smoothed=jnp.array(0.0, dtype=jax_utils.get_dtype()),
335348
rho_q_min=jnp.array(0.0, dtype=jax_utils.get_dtype()),
336349
q_min=jnp.array(0.0, dtype=jax_utils.get_dtype()),
337350
rho_q_3_2_first=jnp.array(0.0, dtype=jax_utils.get_dtype()),
@@ -460,12 +473,7 @@ def _calculate_integrated_sources(
460473
integrated['P_ei_exchange_i'] = math_utils.volume_integration(qei, geo)
461474
integrated['P_ei_exchange_e'] = -integrated['P_ei_exchange_i']
462475

463-
# Initialize total electron and ion powers
464-
# TODO(b/380848256): P_sol is now correct for stationary state. However,
465-
# for generality need to add dWth/dt to the equation (time dependence of
466-
# stored energy).
467-
integrated['P_SOL_i'] = integrated['P_ei_exchange_i']
468-
integrated['P_SOL_e'] = integrated['P_ei_exchange_e']
476+
# Initialize total electron and ion auxiliary powers.
469477
integrated['P_aux_i'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
470478
integrated['P_aux_e'] = jnp.array(0.0, dtype=jax_utils.get_dtype())
471479
integrated['P_external_injected'] = jnp.array(
@@ -492,8 +500,7 @@ def _calculate_integrated_sources(
492500
integrated[f'{value}_total'] = (
493501
integrated[f'{value}_i'] + integrated[f'{value}_e']
494502
)
495-
integrated['P_SOL_i'] += integrated[f'{value}_i']
496-
integrated['P_SOL_e'] += integrated[f'{value}_e']
503+
497504
if key in EXTERNAL_HEATING_SOURCES:
498505
integrated['P_aux_i'] += integrated[f'{value}_i']
499506
integrated['P_aux_e'] += integrated[f'{value}_e']
@@ -521,7 +528,6 @@ def _calculate_integrated_sources(
521528
integrated[f'{value}'] = _get_integrated_source_value(
522529
core_sources.T_e, key, geo, math_utils.volume_integration
523530
)
524-
integrated['P_SOL_e'] += integrated[f'{value}']
525531
if key in EXTERNAL_HEATING_SOURCES:
526532
integrated['P_aux_e'] += integrated[f'{value}']
527533
integrated['P_external_injected'] += integrated[f'{value}']
@@ -543,7 +549,6 @@ def _calculate_integrated_sources(
543549
)
544550
integrated['S_total'] += integrated[f'{value}']
545551

546-
integrated['P_SOL_total'] = integrated['P_SOL_i'] + integrated['P_SOL_e']
547552
integrated['P_aux_total'] = integrated['P_aux_i'] + integrated['P_aux_e']
548553
integrated['P_fusion'] = 5 * integrated['P_alpha_total']
549554
integrated['P_external_total'] = (
@@ -614,24 +619,100 @@ def make_post_processed_outputs(
614619
)
615620
)
616621

617-
# Thermal energy confinement time is the stored energy divided by the total
618-
# input power into the plasma.
622+
# Calculate dW/dt.
623+
# We perform raw calculation and smoothing inside a conditional block to
624+
# prevent division by zero on the first step (when dt=0) and to avoid
625+
# large transients (since previous W is initialized to 0).
626+
def _calculate_dW_dt_terms():
627+
# Raw values
628+
dW_i_dt_raw = (
629+
W_thermal_ion - previous_post_processed_outputs.W_thermal_i
630+
) / sim_state.dt
631+
dW_e_dt_raw = (
632+
W_thermal_el - previous_post_processed_outputs.W_thermal_e
633+
) / sim_state.dt
634+
dW_total_dt_raw = dW_i_dt_raw + dW_e_dt_raw
635+
636+
# Smoothing
637+
alpha = jax.lax.cond(
638+
runtime_params.numerics.dW_dt_smoothing_time_scale > 0.0,
639+
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype())
640+
- jnp.exp(
641+
-sim_state.dt / runtime_params.numerics.dW_dt_smoothing_time_scale
642+
),
643+
lambda: jnp.array(1.0, dtype=jax_utils.get_dtype()),
644+
)
645+
646+
def _smooth(new_raw, old_smoothed):
647+
return (1.0 - alpha) * old_smoothed + alpha * new_raw
648+
649+
dW_i_dt_smoothed = _smooth(
650+
dW_i_dt_raw,
651+
previous_post_processed_outputs.dW_thermal_i_dt_smoothed,
652+
)
653+
dW_e_dt_smoothed = _smooth(
654+
dW_e_dt_raw,
655+
previous_post_processed_outputs.dW_thermal_e_dt_smoothed,
656+
)
657+
dW_total_dt_smoothed = dW_i_dt_smoothed + dW_e_dt_smoothed
658+
659+
return (
660+
dW_total_dt_raw,
661+
dW_total_dt_smoothed,
662+
dW_i_dt_smoothed,
663+
dW_e_dt_smoothed,
664+
)
665+
666+
(
667+
dW_thermal_total_dt_raw,
668+
dW_thermal_total_dt_smoothed,
669+
dW_thermal_i_dt_smoothed,
670+
dW_thermal_e_dt_smoothed,
671+
) = jax.lax.cond(
672+
previous_post_processed_outputs.first_step,
673+
lambda: (0.0, 0.0, 0.0, 0.0),
674+
_calculate_dW_dt_terms,
675+
)
676+
677+
# Calculate P_SOL (Power crossing separatrix) = P_sources - P_sinks - dW/dt
678+
integrated_sources['P_SOL_i'] = (
679+
integrated_sources['P_aux_i']
680+
+ integrated_sources['P_alpha_i']
681+
+ integrated_sources['P_ei_exchange_i']
682+
- dW_thermal_i_dt_smoothed
683+
)
684+
685+
# Note: P_bremsstrahlung_e, P_cyclotron_e, P_radiation_e are sink terms
686+
# (negative) in integrated_sources, so we add them.
687+
integrated_sources['P_SOL_e'] = (
688+
integrated_sources['P_aux_e']
689+
+ integrated_sources['P_alpha_e']
690+
+ integrated_sources['P_ei_exchange_e']
691+
+ integrated_sources['P_ohmic_e']
692+
+ integrated_sources['P_bremsstrahlung_e']
693+
+ integrated_sources['P_cyclotron_e']
694+
+ integrated_sources['P_radiation_e']
695+
- dW_thermal_e_dt_smoothed
696+
)
697+
698+
integrated_sources['P_SOL_total'] = (
699+
integrated_sources['P_SOL_i'] + integrated_sources['P_SOL_e']
700+
)
619701

620-
# Ploss term here does not include the reduction of radiated power. Most
621-
# analysis of confinement times from databases have not included this term.
702+
# Calculate P_loss term used for confinement time calculations.
703+
# As per standard definitions, P_loss does not include radiation terms.
622704
# Therefore highly radiative scenarios can lead to skewed results.
623705

624-
Ploss = (
706+
P_loss = (
625707
integrated_sources['P_alpha_total']
626708
+ integrated_sources['P_aux_total']
627709
+ integrated_sources['P_ohmic_e']
710+
- dW_thermal_total_dt_smoothed
628711
+ constants.CONSTANTS.eps # Division guard.
629712
)
630713

631714
def cumulative_values():
632-
dW_th_dt = (
633-
W_thermal_tot - previous_post_processed_outputs.W_thermal_total
634-
) / sim_state.dt
715+
635716
E_fusion = (
636717
previous_post_processed_outputs.E_fusion
637718
+ sim_state.dt
@@ -678,7 +759,6 @@ def cumulative_values():
678759
/ 2.0
679760
)
680761
return (
681-
dW_th_dt,
682762
E_fusion,
683763
E_aux_total,
684764
E_ohmic_e,
@@ -687,37 +767,36 @@ def cumulative_values():
687767
)
688768

689769
(
690-
dW_th_dt,
691770
E_fusion,
692771
E_aux_total,
693772
E_ohmic_e,
694773
E_external_injected,
695774
E_external_total,
696775
) = jax.lax.cond(
697776
previous_post_processed_outputs.first_step,
698-
lambda: (0.0,) * 6,
777+
lambda: (0.0,) * 5,
699778
cumulative_values,
700779
)
701780

702-
tauE = W_thermal_tot / Ploss
781+
tau_E = W_thermal_tot / P_loss
703782

704783
tauH89P = scaling_laws.calculate_scaling_law_confinement_time(
705-
sim_state.geometry, sim_state.core_profiles, Ploss, 'H89P'
784+
sim_state.geometry, sim_state.core_profiles, P_loss, 'H89P'
706785
)
707786
tauH98 = scaling_laws.calculate_scaling_law_confinement_time(
708-
sim_state.geometry, sim_state.core_profiles, Ploss, 'H98'
787+
sim_state.geometry, sim_state.core_profiles, P_loss, 'H98'
709788
)
710789
tauH97L = scaling_laws.calculate_scaling_law_confinement_time(
711-
sim_state.geometry, sim_state.core_profiles, Ploss, 'H97L'
790+
sim_state.geometry, sim_state.core_profiles, P_loss, 'H97L'
712791
)
713792
tauH20 = scaling_laws.calculate_scaling_law_confinement_time(
714-
sim_state.geometry, sim_state.core_profiles, Ploss, 'H20'
793+
sim_state.geometry, sim_state.core_profiles, P_loss, 'H20'
715794
)
716795

717-
H89P = tauE / tauH89P
718-
H98 = tauE / tauH98
719-
H97L = tauE / tauH97L
720-
H20 = tauE / tauH20
796+
H89P = tau_E / tauH89P
797+
H98 = tau_E / tauH98
798+
H97L = tau_E / tauH97L
799+
H20 = tau_E / tauH20
721800

722801
# Calculate q at 95% of the normalized poloidal flux
723802
q95 = psi_calculations.calc_q95(psi_norm_face, sim_state.core_profiles.q_face)
@@ -749,12 +828,12 @@ def cumulative_values():
749828
fgw_n_e_line_avg = formulas.calculate_greenwald_fraction(
750829
n_e_line_avg, sim_state.core_profiles, sim_state.geometry
751830
)
752-
Wpol = psi_calculations.calc_Wpol(
831+
W_pol = psi_calculations.calc_Wpol(
753832
sim_state.geometry, sim_state.core_profiles.psi
754833
)
755834
li3 = psi_calculations.calc_li3(
756835
sim_state.geometry.R_major,
757-
Wpol,
836+
W_pol,
758837
sim_state.core_profiles.Ip_profile_face[-1],
759838
)
760839

@@ -841,7 +920,7 @@ def cumulative_values():
841920
W_thermal_i=W_thermal_ion,
842921
W_thermal_e=W_thermal_el,
843922
W_thermal_total=W_thermal_tot,
844-
tau_E=tauE,
923+
tau_E=tau_E,
845924
H89P=H89P,
846925
H98=H98,
847926
H97L=H97L,
@@ -868,9 +947,12 @@ def cumulative_values():
868947
fgw_n_e_volume_avg=fgw_n_e_volume_avg,
869948
fgw_n_e_line_avg=fgw_n_e_line_avg,
870949
q95=q95,
871-
W_pol=Wpol,
950+
W_pol=W_pol,
872951
li3=li3,
873-
dW_thermal_dt=dW_th_dt,
952+
dW_thermal_dt=dW_thermal_total_dt_raw,
953+
dW_thermal_dt_smoothed=dW_thermal_total_dt_smoothed,
954+
dW_thermal_i_dt_smoothed=dW_thermal_i_dt_smoothed,
955+
dW_thermal_e_dt_smoothed=dW_thermal_e_dt_smoothed,
874956
rho_q_min=safety_factor_fit_outputs.rho_q_min,
875957
q_min=safety_factor_fit_outputs.q_min,
876958
rho_q_3_2_first=safety_factor_fit_outputs.rho_q_3_2_first,
@@ -924,4 +1006,3 @@ def _convert_j_parallel_face_to_j_toroidal_face(
9241006
j_parallel_to_j_toroidal_factor_cell,
9251007
)
9261008
return j_parallel_to_j_toroidal_factor_face * j_parallel_face
927-

torax/_src/physics/scaling_laws.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ def calculate_plh_scaling_factor(
105105
def calculate_scaling_law_confinement_time(
106106
geo: geometry.Geometry,
107107
core_profiles: state.CoreProfiles,
108-
Ploss: jax.Array,
108+
P_loss: jax.Array,
109109
scaling_law: str,
110110
) -> jax.Array:
111111
"""Calculates the thermal energy confinement time for a given empirical scaling law.
112112
113113
Args:
114114
geo: Torus geometry.
115115
core_profiles: Core plasma profiles.
116-
Ploss: Plasma power loss in W.
116+
P_loss: Plasma power loss in W.
117117
scaling_law: Scaling law to use.
118118
119119
Returns:
@@ -184,8 +184,14 @@ def calculate_scaling_law_confinement_time(
184184

185185
params = scaling_params[scaling_law]
186186

187+
# Ensure P_loss is positive to avoid NaNs in power laws (e.g. x^-0.69).
188+
# Physically, scaling laws are derived for steady state (P_loss ~ P_heat).
189+
# During transients where dW/dt > P_heat, P_loss can be negative.
190+
# We clamp it to a small positive value.
191+
P_loss = jnp.maximum(P_loss, 1.0)
192+
187193
scaled_Ip = core_profiles.Ip_profile_face[-1] / 1e6 # convert to MA
188-
scaled_Ploss = Ploss / 1e6 # convert to MW
194+
scaled_Ploss = P_loss / 1e6 # convert to MW
189195
B = geo.B_0
190196
line_avg_n_e = ( # convert to 10^19 m^-3
191197
math_utils.line_average(core_profiles.n_e.value, geo) / 1e19
-2.7 KB
Binary file not shown.
-1.5 KB
Binary file not shown.
201 Bytes
Binary file not shown.
169 Bytes
Binary file not shown.
63 Bytes
Binary file not shown.
-4.29 KB
Binary file not shown.
1.67 KB
Binary file not shown.

0 commit comments

Comments
 (0)