Skip to content

Commit 4fd4cd6

Browse files
committed
autodiff ahhhhh
1 parent 934de93 commit 4fd4cd6

File tree

3 files changed

+26
-32
lines changed

3 files changed

+26
-32
lines changed

src/cache/precomputed_quantities.jl

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -464,13 +464,8 @@ function thermo_state_gs(
464464
)
465465
e_int = specific(ᶜY.ρe_tot, ᶜY.ρ) - K - Φ
466466
T = TD.air_temperature(thermo_params, e_int)
467-
FT = eltype(thermo_params)
468-
return (;
469-
T = T,
470-
q_tot_safe = zero(FT),
471-
q_liq_rai = zero(FT),
472-
q_ice_sno = zero(FT),
473-
)
467+
# Use zero(T) to match type of T (supports autodiff with dual numbers)
468+
return ThermoState(T, zero(T), zero(T), zero(T))
474469
end
475470

476471
function thermo_state_gs(
@@ -486,12 +481,7 @@ function thermo_state_gs(
486481
q_tot = specific(ᶜY.ρq_tot, ᶜY.ρ)
487482
# Use saturation_adjustment once for all quantities
488483
sa_result = TD.saturation_adjustment(thermo_params, TD.ρe(), ρ, e_int, q_tot)
489-
return (;
490-
T = sa_result.T,
491-
q_tot_safe = max(0, q_tot),
492-
q_liq_rai = sa_result.q_liq,
493-
q_ice_sno = sa_result.q_ice,
494-
)
484+
return ThermoState(sa_result.T, max(0, q_tot), sa_result.q_liq, sa_result.q_ice)
495485
end
496486

497487
function thermo_state_gs(
@@ -508,12 +498,7 @@ function thermo_state_gs(
508498
q_liq_rai = specific(ᶜY.ρq_liq, ᶜY.ρ) + specific(ᶜY.ρq_rai, ᶜY.ρ)
509499
q_ice_sno = specific(ᶜY.ρq_ice, ᶜY.ρ) + specific(ᶜY.ρq_sno, ᶜY.ρ)
510500
T = TD.air_temperature(thermo_params, e_int, q_tot, q_liq_rai, q_ice_sno)
511-
return (;
512-
T = T,
513-
q_tot_safe = max(0, q_tot),
514-
q_liq_rai = max(0, q_liq_rai),
515-
q_ice_sno = max(0, q_ice_sno),
516-
)
501+
return ThermoState(T, max(0, q_tot), max(0, q_liq_rai), max(0, q_ice_sno))
517502
end
518503

519504
function eddy_diffusivity_coefficient_H(D₀, H, z_sfc, z)
@@ -526,6 +511,11 @@ function eddy_diffusivity_coefficient(C_E, norm_v_a, z_a, p)
526511
return p > p_pbl ? K_E : K_E * exp(-((p_pbl - p) / p_strato)^2)
527512
end
528513

514+
@inline get_thermo_state_value(thermo_state::ThermoState, ::Val{:T}) = thermo_state.T
515+
@inline get_thermo_state_value(thermo_state::ThermoState, ::Val{:q_tot_safe}) = thermo_state.q_tot_safe
516+
@inline get_thermo_state_value(thermo_state::ThermoState, ::Val{:q_liq_rai}) = thermo_state.q_liq_rai
517+
@inline get_thermo_state_value(thermo_state::ThermoState, ::Val{:q_ice_sno}) = thermo_state.q_ice_sno
518+
529519
"""
530520
set_implicit_precomputed_quantities!(Y, p, t)
531521
@@ -576,13 +566,13 @@ NVTX.@annotate function set_implicit_precomputed_quantities!(Y, p, t)
576566
# @. ᶜK += Y.c.ρtke / Y.c.ρ
577567
# TODO: We should think more about these increments before we use them.
578568
end
579-
# Compute all thermodynamic state variables in one call (avoids multiple saturation_adjustment calls)
580-
ᶜthermo_state = p.scratch.ᶜtemp_thermo_state
581-
@. ᶜthermo_state = thermo_state_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ)
582-
@. ᶜT = ᶜthermo_state.T
583-
@. ᶜq_tot_safe = ᶜthermo_state.q_tot_safe
584-
@. ᶜq_liq_rai = ᶜthermo_state.q_liq_rai
585-
@. ᶜq_ice_sno = ᶜthermo_state.q_ice_sno
569+
# Compute thermodynamic state variables
570+
# Note: For EquilMoistModel, this calls saturation_adjustment 3 times per grid point.
571+
ᶜthermo_state = @. lazy(thermo_state_gs(thermo_args..., Y.c, ᶜK, ᶜΦ, Y.c.ρ))
572+
@. ᶜT = get_thermo_state_value(ᶜthermo_state, Val(:T))
573+
@. ᶜq_tot_safe = get_thermo_state_value(ᶜthermo_state, Val(:q_tot_safe))
574+
@. ᶜq_liq_rai = get_thermo_state_value(ᶜthermo_state, Val(:q_liq_rai))
575+
@. ᶜq_ice_sno = get_thermo_state_value(ᶜthermo_state, Val(:q_ice_sno))
586576
ᶜe_tot = @. lazy(specific(Y.c.ρe_tot, Y.c.ρ))
587577
@. ᶜh_tot = TD.total_enthalpy(thermo_params, ᶜe_tot, ᶜT, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)
588578
@. ᶜp = TD.air_pressure(thermo_params, ᶜT, Y.c.ρ, ᶜq_tot_safe, ᶜq_liq_rai, ᶜq_ice_sno)

src/cache/temporary_quantities.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ function temporary_quantities(Y, atmos)
2929
3030
FT = Spaces.undertype(center_space)
3131
uvw_vec = UVW(FT(0), FT(0), FT(0))
32-
# Type for thermo state: (T, q_tot_safe, q_liq_rai, q_ice_sno)
33-
ThermoStateType = NamedTuple{
34-
(:T, :q_tot_safe, :q_liq_rai, :q_ice_sno),
35-
NTuple{4, FT},
36-
}
3732
return (;
3833
ᶠtemp_scalar = Fields.Field(FT, face_space), # ᶠp, ᶠρK_h
3934
ᶠtemp_scalar_2 = Fields.Field(FT, face_space), # ᶠρK_u
@@ -44,7 +39,6 @@ function temporary_quantities(Y, atmos)
4439
ᶜtemp_scalar_5 = Fields.Field(FT, center_space),
4540
ᶜtemp_scalar_6 = Fields.Field(FT, center_space),
4641
ᶜtemp_scalar_7 = Fields.Field(FT, center_space),
47-
ᶜtemp_thermo_state = Fields.Field(ThermoStateType, center_space), # thermo_state_gs result
4842
ᶠtemp_field_level = Fields.level(Fields.Field(FT, face_space), half),
4943
temp_field_level = Fields.level(Fields.Field(FT, center_space), 1),
5044
temp_field_level_2 = Fields.level(Fields.Field(FT, center_space), 1),

src/solver/types.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,16 @@ function MixingLength(master, wall, tke, buoy, l_grid)
379379
return MixingLength(promote(master, wall, tke, buoy, l_grid)...)
380380
end
381381

382+
struct ThermoState{FT}
383+
T::FT
384+
q_tot_safe::FT
385+
q_liq_rai::FT
386+
q_ice_sno::FT
387+
end
388+
389+
function ThermoState(T, q_tot_safe, q_liq_rai, q_ice_sno)
390+
return ThermoState(promote(T, q_tot_safe, q_liq_rai, q_ice_sno)...)
391+
end
382392

383393
abstract type AbstractEDMF end
384394

0 commit comments

Comments
 (0)