@@ -464,13 +464,8 @@ function thermo_state_gs(
464464)
465465 e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
466466 T = TD. air_temperature(thermo_params, e_int)
467- FT = eltype(thermo_params)
468- return (;
469- T = T,
470- q_tot_safe = zero(FT),
471- q_liq_rai = zero(FT),
472- q_ice_sno = zero(FT),
473- )
467+ # Use zero(T) to match type of T (supports autodiff with dual numbers)
468+ return ThermoState(T, zero(T), zero(T), zero(T))
474469end
475470
476471function thermo_state_gs(
@@ -486,12 +481,7 @@ function thermo_state_gs(
486481 q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
487482 # Use saturation_adjustment once for all quantities
488483 sa_result = TD. saturation_adjustment(thermo_params, TD. ρe(), ρ, e_int, q_tot)
489- return (;
490- T = sa_result. T,
491- q_tot_safe = max(0 , q_tot),
492- q_liq_rai = sa_result. q_liq,
493- q_ice_sno = sa_result. q_ice,
494- )
484+ return ThermoState(sa_result. T, max(0 , q_tot), sa_result. q_liq, sa_result. q_ice)
495485end
496486
497487function thermo_state_gs(
@@ -508,12 +498,7 @@ function thermo_state_gs(
508498 q_liq_rai = specific(ᶜY. ρq_liq, ᶜY. ρ) + specific(ᶜY. ρq_rai, ᶜY. ρ)
509499 q_ice_sno = specific(ᶜY. ρq_ice, ᶜY. ρ) + specific(ᶜY. ρq_sno, ᶜY. ρ)
510500 T = TD. air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
511- return (;
512- T = T,
513- q_tot_safe = max(0 , q_tot),
514- q_liq_rai = max(0 , q_liq_rai),
515- q_ice_sno = max(0 , q_ice_sno),
516- )
501+ return ThermoState(T, max(0 , q_tot), max(0 , q_liq_rai), max(0 , q_ice_sno))
517502end
518503
519504function eddy_diffusivity_coefficient_H(D₀, H, z_sfc, z)
@@ -526,6 +511,11 @@ function eddy_diffusivity_coefficient(C_E, norm_v_a, z_a, p)
526511 return p > p_pbl ? K_E : K_E * exp(- ((p_pbl - p) / p_strato)^ 2 )
527512end
528513
514+ @inline get_thermo_state_value(thermo_state:: ThermoState , :: Val{:T} ) = thermo_state. T
515+ @inline get_thermo_state_value(thermo_state:: ThermoState , :: Val{:q_tot_safe} ) = thermo_state. q_tot_safe
516+ @inline get_thermo_state_value(thermo_state:: ThermoState , :: Val{:q_liq_rai} ) = thermo_state. q_liq_rai
517+ @inline get_thermo_state_value(thermo_state:: ThermoState , :: Val{:q_ice_sno} ) = thermo_state. q_ice_sno
518+
529519"""
530520 set_implicit_precomputed_quantities!(Y, p, t)
531521
@@ -576,13 +566,13 @@ NVTX.@annotate function set_implicit_precomputed_quantities!(Y, p, t)
576566 # @. ᶜK += Y.c.ρtke / Y.c.ρ
577567 # TODO : We should think more about these increments before we use them.
578568 end
579- # Compute all thermodynamic state variables in one call (avoids multiple saturation_adjustment calls)
580- ᶜthermo_state = p . scratch . ᶜtemp_thermo_state
581- @. ᶜthermo_state = thermo_state_gs(thermo_args... , Y. c, ᶜK, ᶜΦ, Y. c. ρ)
582- @. ᶜT = ᶜthermo_state. T
583- @. ᶜq_tot_safe = ᶜthermo_state. q_tot_safe
584- @. ᶜq_liq_rai = ᶜthermo_state. q_liq_rai
585- @. ᶜq_ice_sno = ᶜthermo_state. q_ice_sno
569+ # Compute thermodynamic state variables
570+ # Note: For EquilMoistModel, this calls saturation_adjustment 3 times per grid point.
571+ ᶜthermo_state = @. lazy( thermo_state_gs(thermo_args... , Y. c, ᶜK, ᶜΦ, Y. c. ρ) )
572+ @. ᶜT = get_thermo_state_value( ᶜthermo_state, Val(:T))
573+ @. ᶜq_tot_safe = get_thermo_state_value( ᶜthermo_state, Val(: q_tot_safe))
574+ @. ᶜq_liq_rai = get_thermo_state_value( ᶜthermo_state, Val(: q_liq_rai))
575+ @. ᶜq_ice_sno = get_thermo_state_value( ᶜthermo_state, Val(: q_ice_sno))
586576 ᶜe_tot = @. lazy(specific(Y. c. ρe_tot, Y. c. ρ))
587577 @. ᶜh_tot = TD. total_enthalpy(thermo_params, ᶜe_tot, ᶜT, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
588578 @. ᶜp = TD. air_pressure(thermo_params, ᶜT, Y. c. ρ, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
0 commit comments