@@ -195,6 +195,7 @@ class PostProcessedOutputs:
195195
196196 pprime : array_typing .FloatVector
197197 # pylint: disable=invalid-name
198+ # TODO(b/434175938) Remove W_thermal terms in favor of terms from energy_state
198199 W_thermal_i : array_typing .FloatScalar
199200 W_thermal_e : array_typing .FloatScalar
200201 W_thermal_total : array_typing .FloatScalar
@@ -257,6 +258,7 @@ class PostProcessedOutputs:
257258 q95 : array_typing .FloatScalar
258259 W_pol : array_typing .FloatScalar
259260 li3 : array_typing .FloatScalar
261+ # TODO(b/434175938) Remove dW_dt terms in favor of terms from energy_state.
260262 dW_thermal_dt : array_typing .FloatScalar
261263 dW_thermal_dt_smoothed : array_typing .FloatScalar
262264 dW_thermal_i_dt_smoothed : array_typing .FloatScalar
@@ -632,6 +634,11 @@ def make_post_processed_outputs(
632634 Returns:
633635 post_processed_outputs: The post_processed_outputs for the given state.
634636 """
637+ internal_plasma_energy = sim_state .core_profiles .internal_plasma_energy
638+ if internal_plasma_energy is None :
639+ raise ValueError (
640+ 'internal_plasma_energy is None, but is required for post-processing.'
641+ )
635642 # TODO(b/444380540): Remove once aux outputs from sources are exposed.
636643 impurity_radiation_outputs = (
637644 impurity_radiation .calculate_impurity_species_output (
@@ -640,15 +647,6 @@ def make_post_processed_outputs(
640647 )
641648
642649 pprime_face = formulas .calc_pprime (sim_state .core_profiles )
643- # pylint: disable=invalid-name
644- W_thermal_el , W_thermal_ion , W_thermal_tot = (
645- formulas .calculate_stored_thermal_energy (
646- sim_state .core_profiles .pressure_thermal_e ,
647- sim_state .core_profiles .pressure_thermal_i ,
648- sim_state .core_profiles .pressure_thermal_total ,
649- sim_state .geometry ,
650- )
651- )
652650 FFprime_face = formulas .calc_FFprime (
653651 sim_state .core_profiles , sim_state .geometry
654652 )
@@ -672,67 +670,15 @@ def make_post_processed_outputs(
672670 )
673671 )
674672
675- # Calculate dW/dt.
676- # We perform raw calculation and smoothing inside a conditional block to
677- # prevent division by zero on the first step (when dt=0) and to avoid
678- # large transients (since previous W is initialized to 0).
679- def _calculate_dW_dt_terms ():
680- # Raw values
681- dW_i_dt_raw = (
682- W_thermal_ion - previous_post_processed_outputs .W_thermal_i
683- ) / sim_state .dt
684- dW_e_dt_raw = (
685- W_thermal_el - previous_post_processed_outputs .W_thermal_e
686- ) / sim_state .dt
687- dW_total_dt_raw = dW_i_dt_raw + dW_e_dt_raw
688-
689- # Calculate smoothing parameter
690- alpha = jax .lax .cond (
691- runtime_params .numerics .dW_dt_smoothing_time_scale > 0.0 ,
692- lambda : jnp .array (1.0 , dtype = jax_utils .get_dtype ())
693- - jnp .exp (
694- - sim_state .dt / runtime_params .numerics .dW_dt_smoothing_time_scale
695- ),
696- lambda : jnp .array (1.0 , dtype = jax_utils .get_dtype ()),
697- )
698-
699- dW_i_dt_smoothed = _exponential_smoothing (
700- dW_i_dt_raw ,
701- previous_post_processed_outputs .dW_thermal_i_dt_smoothed ,
702- alpha ,
703- )
704- dW_e_dt_smoothed = _exponential_smoothing (
705- dW_e_dt_raw ,
706- previous_post_processed_outputs .dW_thermal_e_dt_smoothed ,
707- alpha ,
708- )
709- dW_total_dt_smoothed = dW_i_dt_smoothed + dW_e_dt_smoothed
710-
711- return (
712- dW_total_dt_raw ,
713- dW_total_dt_smoothed ,
714- dW_i_dt_smoothed ,
715- dW_e_dt_smoothed ,
716- )
717-
718- (
719- dW_thermal_total_dt_raw ,
720- dW_thermal_total_dt_smoothed ,
721- dW_thermal_i_dt_smoothed ,
722- dW_thermal_e_dt_smoothed ,
723- ) = jax .lax .cond (
724- previous_post_processed_outputs .first_step ,
725- lambda : (0.0 , 0.0 , 0.0 , 0.0 ),
726- _calculate_dW_dt_terms ,
727- )
728-
729673 # Calculate P_SOL (Power crossing separatrix) = P_sources - P_sinks - dW/dt
730674 integrated_sources ['P_SOL_i' ] = (
731- integrated_sources ['P_heat_i' ] - dW_thermal_i_dt_smoothed
675+ integrated_sources ['P_heat_i' ]
676+ - internal_plasma_energy .dW_thermal_i_dt_smoothed
732677 )
733678
734679 integrated_sources ['P_SOL_e' ] = (
735- integrated_sources ['P_heat_e' ] - dW_thermal_e_dt_smoothed
680+ integrated_sources ['P_heat_e' ]
681+ - internal_plasma_energy .dW_thermal_e_dt_smoothed
736682 )
737683
738684 integrated_sources ['P_SOL_total' ] = (
@@ -747,7 +693,7 @@ def _calculate_dW_dt_terms():
747693 integrated_sources ['P_alpha_total' ]
748694 + integrated_sources ['P_aux_total' ]
749695 + integrated_sources ['P_ohmic_e' ]
750- - dW_thermal_total_dt_smoothed
696+ - internal_plasma_energy . dW_thermal_dt_smoothed
751697 + constants .CONSTANTS .eps # Division guard.
752698 )
753699
@@ -818,7 +764,7 @@ def cumulative_values():
818764 cumulative_values ,
819765 )
820766
821- tau_E = W_thermal_tot / P_loss
767+ tau_E = internal_plasma_energy . W_thermal_total / P_loss
822768
823769 tauH89P = scaling_laws .calculate_scaling_law_confinement_time (
824770 sim_state .geometry , sim_state .core_profiles , P_loss , 'H89P'
@@ -971,9 +917,9 @@ def cumulative_values():
971917
972918 return PostProcessedOutputs (
973919 pprime = pprime_face ,
974- W_thermal_i = W_thermal_ion ,
975- W_thermal_e = W_thermal_el ,
976- W_thermal_total = W_thermal_tot ,
920+ W_thermal_i = internal_plasma_energy . W_thermal_i ,
921+ W_thermal_e = internal_plasma_energy . W_thermal_e ,
922+ W_thermal_total = internal_plasma_energy . W_thermal_total ,
977923 tau_E = tau_E ,
978924 H89P = H89P ,
979925 H98 = H98 ,
@@ -1003,10 +949,10 @@ def cumulative_values():
1003949 q95 = q95 ,
1004950 W_pol = W_pol ,
1005951 li3 = li3 ,
1006- dW_thermal_dt = dW_thermal_total_dt_raw ,
1007- dW_thermal_dt_smoothed = dW_thermal_total_dt_smoothed ,
1008- dW_thermal_i_dt_smoothed = dW_thermal_i_dt_smoothed ,
1009- dW_thermal_e_dt_smoothed = dW_thermal_e_dt_smoothed ,
952+ dW_thermal_dt = internal_plasma_energy . dW_thermal_dt ,
953+ dW_thermal_dt_smoothed = internal_plasma_energy . dW_thermal_dt_smoothed ,
954+ dW_thermal_i_dt_smoothed = internal_plasma_energy . dW_thermal_i_dt_smoothed ,
955+ dW_thermal_e_dt_smoothed = internal_plasma_energy . dW_thermal_e_dt_smoothed ,
1010956 rho_q_min = safety_factor_fit_outputs .rho_q_min ,
1011957 q_min = safety_factor_fit_outputs .q_min ,
1012958 rho_q_3_2_first = safety_factor_fit_outputs .rho_q_3_2_first ,
0 commit comments