@@ -438,82 +438,83 @@ function thermo_vars(::NonEquilMoistModel, ::Any, ᶜY, K, Φ)
438438 return (; e_int, q_pt = TD. PhasePartition(q_pt_args... ))
439439end
440440
441- """
442- thermo_state_gs(thermo_params, moisture_model, microphysics_model, ᶜY, K, Φ, ρ)
441+ # Individual getter functions for thermodynamic state variables.
442+ # These avoid struct construction which can cause allocations on GPU.
443443
444- Compute grid-scale thermodynamic state variables in one call.
444+ # Temperature getters
445+ function T_gs(thermo_params, :: DryModel , :: Any , ᶜY, K, Φ, ρ)
446+ e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
447+ return TD. air_temperature(thermo_params, e_int)
448+ end
445449
446- Returns a NamedTuple with:
447- - `T`: Temperature
448- - `q_tot_safe`: Total specific humidity (clipped to non-negative)
449- - `q_liq_rai`: Liquid + rain specific humidity
450- - `q_ice_sno`: Ice + snow specific humidity
450+ function T_gs(thermo_params, :: EquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
451+ e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
452+ q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
453+ sa_result = TD. saturation_adjustment(thermo_params, TD. ρe(), ρ, e_int, q_tot)
454+ return sa_result. T
455+ end
451456
452- For DryModel: T from e_int, all q's are zero
453- For EquilMoistModel: Uses `saturation_adjustment` once for all quantities
454- For NonEquilMoistModel: T from e_int and prognostic q's, q's from prognostic values
455- """
456- function thermo_state_gs(
457- thermo_params,
458- moisture_model:: DryModel ,
459- microphysics_model,
460- ᶜY,
461- K,
462- Φ,
463- ρ,
464- )
457+ function T_gs(thermo_params, :: NonEquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
458+ e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
459+ q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
460+ q_liq_rai = specific(ᶜY. ρq_liq, ᶜY. ρ) + specific(ᶜY. ρq_rai, ᶜY. ρ)
461+ q_ice_sno = specific(ᶜY. ρq_ice, ᶜY. ρ) + specific(ᶜY. ρq_sno, ᶜY. ρ)
462+ return TD. air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
463+ end
464+
465+ # q_tot_safe getters
466+ function q_tot_safe_gs(thermo_params, :: DryModel , :: Any , ᶜY, K, Φ, ρ)
465467 e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
466468 T = TD. air_temperature(thermo_params, e_int)
467- FT = eltype(thermo_params)
468- return (;
469- T = T,
470- q_tot_safe = zero(FT),
471- q_liq_rai = zero(FT),
472- q_ice_sno = zero(FT),
473- )
469+ return zero(T) # Use zero(T) for autodiff compatibility
474470end
475471
476- function thermo_state_gs(
477- thermo_params,
478- moisture_model:: EquilMoistModel ,
479- microphysics_model,
480- ᶜY,
481- K,
482- Φ,
483- ρ,
484- )
472+ function q_tot_safe_gs(thermo_params, :: EquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
473+ q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
474+ return max(0 , q_tot)
475+ end
476+
477+ function q_tot_safe_gs(thermo_params, :: NonEquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
478+ q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
479+ return max(0 , q_tot)
480+ end
481+
482+ # q_liq_rai getters
483+ function q_liq_rai_gs(thermo_params, :: DryModel , :: Any , ᶜY, K, Φ, ρ)
484+ e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
485+ T = TD. air_temperature(thermo_params, e_int)
486+ return zero(T) # Use zero(T) for autodiff compatibility
487+ end
488+
489+ function q_liq_rai_gs(thermo_params, :: EquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
485490 e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
486491 q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
487- # Use saturation_adjustment once for all quantities
488492 sa_result = TD. saturation_adjustment(thermo_params, TD. ρe(), ρ, e_int, q_tot)
489- return (;
490- T = sa_result. T,
491- q_tot_safe = max(0 , q_tot),
492- q_liq_rai = sa_result. q_liq,
493- q_ice_sno = sa_result. q_ice,
494- )
493+ return sa_result. q_liq
495494end
496495
497- function thermo_state_gs(
498- thermo_params,
499- moisture_model:: NonEquilMoistModel ,
500- microphysics_model,
501- ᶜY,
502- K,
503- Φ,
504- ρ,
505- )
496+ function q_liq_rai_gs(thermo_params, :: NonEquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
497+ q_liq_rai = specific(ᶜY. ρq_liq, ᶜY. ρ) + specific(ᶜY. ρq_rai, ᶜY. ρ)
498+ return max(0 , q_liq_rai)
499+ end
500+
501+ # q_ice_sno getters
502+ function q_ice_sno_gs(thermo_params, :: DryModel , :: Any , ᶜY, K, Φ, ρ)
503+ e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
504+ T = TD. air_temperature(thermo_params, e_int)
505+ return zero(T) # Use zero(T) for autodiff compatibility
506+ end
507+
508+ function q_ice_sno_gs(thermo_params, :: EquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
506509 e_int = specific(ᶜY. ρe_tot, ᶜY. ρ) - K - Φ
507510 q_tot = specific(ᶜY. ρq_tot, ᶜY. ρ)
508- q_liq_rai = specific(ᶜY. ρq_liq, ᶜY. ρ) + specific(ᶜY. ρq_rai, ᶜY. ρ)
511+ sa_result = TD. saturation_adjustment(thermo_params, TD. ρe(), ρ, e_int, q_tot)
512+ return sa_result. q_ice
513+ end
514+
515+ function q_ice_sno_gs(thermo_params, :: NonEquilMoistModel , :: Any , ᶜY, K, Φ, ρ)
509516 q_ice_sno = specific(ᶜY. ρq_ice, ᶜY. ρ) + specific(ᶜY. ρq_sno, ᶜY. ρ)
510- T = TD. air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
511- return (;
512- T = T,
513- q_tot_safe = max(0 , q_tot),
514- q_liq_rai = max(0 , q_liq_rai),
515- q_ice_sno = max(0 , q_ice_sno),
516- )
517+ return max(0 , q_ice_sno)
517518end
518519
519520function eddy_diffusivity_coefficient_H(D₀, H, z_sfc, z)
@@ -576,13 +577,13 @@ NVTX.@annotate function set_implicit_precomputed_quantities!(Y, p, t)
576577 # @. ᶜK += Y.c.ρtke / Y.c.ρ
577578 # TODO : We should think more about these increments before we use them.
578579 end
579- # Compute all thermodynamic state variables in one call (avoids multiple saturation_adjustment calls)
580- ᶜthermo_state = p . scratch . ᶜtemp_thermo_state
581- @. ᶜthermo_state = thermo_state_gs(thermo_args ... , Y . c, ᶜK, ᶜΦ, Y . c . ρ)
582- @. ᶜT = ᶜthermo_state . T
583- @. ᶜq_tot_safe = ᶜthermo_state . q_tot_safe
584- @. ᶜq_liq_rai = ᶜthermo_state . q_liq_rai
585- @. ᶜq_ice_sno = ᶜthermo_state . q_ice_sno
580+ # Compute thermodynamic state variables using individual getter functions.
581+ # Note: For EquilMoistModel, this calls saturation_adjustment 3 times per grid point
582+ # (T, q_liq_rai, q_ice_sno each call it; q_tot_safe doesn't need it).
583+ @. ᶜT = T_gs(thermo_args ... , Y . c, ᶜK, ᶜΦ, Y . c . ρ)
584+ @. ᶜq_tot_safe = q_tot_safe_gs(thermo_args ... , Y . c, ᶜK, ᶜΦ, Y . c . ρ)
585+ @. ᶜq_liq_rai = q_liq_rai_gs(thermo_args ... , Y . c, ᶜK, ᶜΦ, Y . c . ρ)
586+ @. ᶜq_ice_sno = q_ice_sno_gs(thermo_args ... , Y . c, ᶜK, ᶜΦ, Y . c . ρ)
586587 ᶜe_tot = @. lazy(specific(Y. c. ρe_tot, Y. c. ρ))
587588 @. ᶜh_tot = TD. total_enthalpy(thermo_params, ᶜe_tot, ᶜT, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
588589 @. ᶜp = TD. air_pressure(thermo_params, ᶜT, Y. c. ρ, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
0 commit comments