Skip to content

Commit 953e904

Browse files
committed
autodiff and allocation ahhhhh
1 parent 934de93 commit 953e904

File tree

3 files changed

+69
-75
lines changed

3 files changed

+69
-75
lines changed

src/cache/precomputed_quantities.jl

Lines changed: 69 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -438,82 +438,83 @@ function thermo_vars(::NonEquilMoistModel, ::Any, ᶜY, K, Φ)
438438
return (; e_int, q_pt = TD.PhasePartition(q_pt_args...))
439439
end
440440

441-
"""
442-
thermo_state_gs(thermo_params, moisture_model, microphysics_model, ᶜY, K, Φ, ρ)
441+
# Individual getter functions for thermodynamic state variables.
442+
# These avoid struct construction which can cause allocations on GPU.
443443

444-
Compute grid-scale thermodynamic state variables in one call.
444+
# Temperature getters
445+
function T_gs(thermo_params, ::DryModel, ::Any, ᶜY, K, Φ, ρ)
446+
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
447+
return TD.air_temperature(thermo_params, e_int)
448+
end
445449

446-
Returns a NamedTuple with:
447-
- `T`: Temperature
448-
- `q_tot_safe`: Total specific humidity (clipped to non-negative)
449-
- `q_liq_rai`: Liquid + rain specific humidity
450-
- `q_ice_sno`: Ice + snow specific humidity
450+
function T_gs(thermo_params, ::EquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
451+
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
452+
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
453+
sa_result = TD.saturation_adjustment(thermo_params, TD.ρe(), ρ, e_int, q_tot)
454+
return sa_result.T
455+
end
451456

452-
For DryModel: T from e_int, all q's are zero
453-
For EquilMoistModel: Uses `saturation_adjustment` once for all quantities
454-
For NonEquilMoistModel: T from e_int and prognostic q's, q's from prognostic values
455-
"""
456-
function thermo_state_gs(
457-
thermo_params,
458-
moisture_model::DryModel,
459-
microphysics_model,
460-
ᶜY,
461-
K,
462-
Φ,
463-
ρ,
464-
)
457+
function T_gs(thermo_params, ::NonEquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
458+
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
459+
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
460+
q_liq_rai = specific(ᶜY.ρq_liq, ᶜY.ρ) + specific(ᶜY.ρq_rai, ᶜY.ρ)
461+
q_ice_sno = specific(ᶜY.ρq_ice, ᶜY.ρ) + specific(ᶜY.ρq_sno, ᶜY.ρ)
462+
return TD.air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
463+
end
464+
465+
# q_tot_safe getters
466+
function q_tot_safe_gs(thermo_params, ::DryModel, ::Any, ᶜY, K, Φ, ρ)
465467
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
466468
T = TD.air_temperature(thermo_params, e_int)
467-
FT = eltype(thermo_params)
468-
return (;
469-
T = T,
470-
q_tot_safe = zero(FT),
471-
q_liq_rai = zero(FT),
472-
q_ice_sno = zero(FT),
473-
)
469+
return zero(T) # Use zero(T) for autodiff compatibility
474470
end
475471

476-
function thermo_state_gs(
477-
thermo_params,
478-
moisture_model::EquilMoistModel,
479-
microphysics_model,
480-
ᶜY,
481-
K,
482-
Φ,
483-
ρ,
484-
)
472+
function q_tot_safe_gs(thermo_params, ::EquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
473+
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
474+
return max(0, q_tot)
475+
end
476+
477+
function q_tot_safe_gs(thermo_params, ::NonEquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
478+
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
479+
return max(0, q_tot)
480+
end
481+
482+
# q_liq_rai getters
483+
function q_liq_rai_gs(thermo_params, ::DryModel, ::Any, ᶜY, K, Φ, ρ)
484+
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
485+
T = TD.air_temperature(thermo_params, e_int)
486+
return zero(T) # Use zero(T) for autodiff compatibility
487+
end
488+
489+
function q_liq_rai_gs(thermo_params, ::EquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
485490
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
486491
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
487-
# Use saturation_adjustment once for all quantities
488492
sa_result = TD.saturation_adjustment(thermo_params, TD.ρe(), ρ, e_int, q_tot)
489-
return (;
490-
T = sa_result.T,
491-
q_tot_safe = max(0, q_tot),
492-
q_liq_rai = sa_result.q_liq,
493-
q_ice_sno = sa_result.q_ice,
494-
)
493+
return sa_result.q_liq
495494
end
496495

497-
function thermo_state_gs(
498-
thermo_params,
499-
moisture_model::NonEquilMoistModel,
500-
microphysics_model,
501-
ᶜY,
502-
K,
503-
Φ,
504-
ρ,
505-
)
496+
function q_liq_rai_gs(thermo_params, ::NonEquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
497+
q_liq_rai = specific(ᶜY.ρq_liq, ᶜY.ρ) + specific(ᶜY.ρq_rai, ᶜY.ρ)
498+
return max(0, q_liq_rai)
499+
end
500+
501+
# q_ice_sno getters
502+
function q_ice_sno_gs(thermo_params, ::DryModel, ::Any, ᶜY, K, Φ, ρ)
503+
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
504+
T = TD.air_temperature(thermo_params, e_int)
505+
return zero(T) # Use zero(T) for autodiff compatibility
506+
end
507+
508+
function q_ice_sno_gs(thermo_params, ::EquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
506509
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
507510
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
508-
q_liq_rai = specific(ᶜY.ρq_liq, ᶜY.ρ) + specific(ᶜY.ρq_rai, ᶜY.ρ)
511+
sa_result = TD.saturation_adjustment(thermo_params, TD.ρe(), ρ, e_int, q_tot)
512+
return sa_result.q_ice
513+
end
514+
515+
function q_ice_sno_gs(thermo_params, ::NonEquilMoistModel, ::Any, ᶜY, K, Φ, ρ)
509516
q_ice_sno = specific(ᶜY.ρq_ice, ᶜY.ρ) + specific(ᶜY.ρq_sno, ᶜY.ρ)
510-
T = TD.air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
511-
return (;
512-
T = T,
513-
q_tot_safe = max(0, q_tot),
514-
q_liq_rai = max(0, q_liq_rai),
515-
q_ice_sno = max(0, q_ice_sno),
516-
)
517+
return max(0, q_ice_sno)
517518
end
518519

519520
function eddy_diffusivity_coefficient_H(D₀, H, z_sfc, z)
@@ -576,13 +577,13 @@ NVTX.@annotate function set_implicit_precomputed_quantities!(Y, p, t)
576577
# @. ᶜK += Y.c.ρtke / Y.c.ρ
577578
# TODO: We should think more about these increments before we use them.
578579
end
579-
# Compute all thermodynamic state variables in one call (avoids multiple saturation_adjustment calls)
580-
ᶜthermo_state = p.scratch.ᶜtemp_thermo_state
581-
@. ᶜthermo_state = thermo_state_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
582-
@. ᶜT = ᶜthermo_state.T
583-
@. ᶜq_tot_safe = ᶜthermo_state.q_tot_safe
584-
@. ᶜq_liq_rai = ᶜthermo_state.q_liq_rai
585-
@. ᶜq_ice_sno = ᶜthermo_state.q_ice_sno
580+
# Compute thermodynamic state variables using individual getter functions.
581+
# Note: For EquilMoistModel, this calls saturation_adjustment 3 times per grid point
582+
# (T, q_liq_rai, q_ice_sno each call it; q_tot_safe doesn't need it).
583+
@. ᶜT = T_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
584+
@. ᶜq_tot_safe = q_tot_safe_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
585+
@. ᶜq_liq_rai = q_liq_rai_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
586+
@. ᶜq_ice_sno = q_ice_sno_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
586587
ᶜe_tot = @. lazy(specific(Y.c.ρe_tot, Y.c.ρ))
587588
@. ᶜh_tot = TD.total_enthalpy(thermo_params, ᶜe_tot, ᶜT, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
588589
@. ᶜp = TD.air_pressure(thermo_params, ᶜT, Y.c.ρ, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)

src/cache/temporary_quantities.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ function temporary_quantities(Y, atmos)
2929
3030
FT = Spaces.undertype(center_space)
3131
uvw_vec = UVW(FT(0), FT(0), FT(0))
32-
# Type for thermo state: (T, q_tot_safe, q_liq_rai, q_ice_sno)
33-
ThermoStateType = NamedTuple{
34-
(:T, :q_tot_safe, :q_liq_rai, :q_ice_sno),
35-
NTuple{4, FT},
36-
}
3732
return (;
3833
ᶠtemp_scalar = Fields.Field(FT, face_space), # ᶠp, ᶠρK_h
3934
ᶠtemp_scalar_2 = Fields.Field(FT, face_space), # ᶠρK_u
@@ -44,7 +39,6 @@ function temporary_quantities(Y, atmos)
4439
ᶜtemp_scalar_5 = Fields.Field(FT, center_space),
4540
ᶜtemp_scalar_6 = Fields.Field(FT, center_space),
4641
ᶜtemp_scalar_7 = Fields.Field(FT, center_space),
47-
ᶜtemp_thermo_state = Fields.Field(ThermoStateType, center_space), # thermo_state_gs result
4842
ᶠtemp_field_level = Fields.level(Fields.Field(FT, face_space), half),
4943
temp_field_level = Fields.level(Fields.Field(FT, center_space), 1),
5044
temp_field_level_2 = Fields.level(Fields.Field(FT, center_space), 1),

src/solver/types.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ function MixingLength(master, wall, tke, buoy, l_grid)
379379
return MixingLength(promote(master, wall, tke, buoy, l_grid)...)
380380
end
381381

382-
383382
abstract type AbstractEDMF end
384383

385384
"""

0 commit comments

Comments
 (0)