Skip to content

Commit 156a715

Browse files
Merge pull request #2926 from CliMA/ck/gpu_prog_edmf
Hoist BidiagonalMatrixRow expressions for prog edmf
2 parents 61c6dc0 + d266493 commit 156a715

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

.buildkite/pipeline.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,6 @@ steps:
762762
agents:
763763
slurm_gpus: 1
764764
slurm_mem: 20G
765-
soft_fail: true
766765

767766
- group: "GPU Performance"
768767
steps:

src/cache/precomputed_quantities.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ function precomputed_quantities(Y, atmos)
9393
ᶜKʲs = similar(Y.c, NTuple{n, FT}),
9494
ᶠKᵥʲs = similar(Y.f, NTuple{n, FT}),
9595
ᶜtsʲs = similar(Y.c, NTuple{n, TST}),
96+
bdmr_l = similar(Y.c, BidiagonalMatrixRow{FT}),
97+
bdmr_r = similar(Y.c, BidiagonalMatrixRow{FT}),
98+
bdmr = similar(Y.c, BidiagonalMatrixRow{FT}),
9699
ᶜρʲs = similar(Y.c, NTuple{n, FT}),
97100
ᶜentrʲs = similar(Y.c, NTuple{n, FT}),
98101
ᶜdetrʲs = similar(Y.c, NTuple{n, FT}),

src/prognostic_equations/implicit/implicit_solver.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,9 @@ NVTX.@annotate function Wfact!(A, Y, p, dtγ, t)
429429
p.precomputed.ᶜρʲs,
430430
p.precomputed.ᶠu³ʲs,
431431
p.precomputed.ᶜtsʲs,
432+
p.precomputed.bdmr_l,
433+
p.precomputed.bdmr_r,
434+
p.precomputed.bdmr,
432435
) : (;)
433436
)...,
434437
p.core.ᶜΦ,
@@ -724,6 +727,7 @@ function update_implicit_equation_jacobian!(A, Y, p, dtγ, colidx)
724727
if p.atmos.turbconv_model isa PrognosticEDMFX
725728
if use_derivative(sgs_advection_flag)
726729
(; ᶜgradᵥ_ᶠΦ, ᶜρʲs, ᶠu³ʲs, ᶜtsʲs) = p
730+
(; bdmr_l, bdmr_r, bdmr) = p
727731
is_third_order = edmfx_upwinding == Val(:third_order)
728732
ᶠupwind = is_third_order ? ᶠupwind3 : ᶠupwind1
729733
ᶠset_upwind_bcs = Operators.SetBoundaryOperator(;
@@ -884,12 +888,19 @@ function update_implicit_equation_jacobian!(A, Y, p, dtγ, colidx)
884888
matrix[@name(f.sgsʲs.:(1).u₃), @name(f.sgsʲs.:(1).u₃)]
885889
ᶜu₃ʲ = ᶜtemp_C3
886890
@. ᶜu₃ʲ[colidx] = ᶜinterp(Y.f.sgsʲs.:(1).u₃[colidx])
891+
892+
@. bdmr_l[colidx] =
893+
convert(BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix())
894+
@. bdmr_r[colidx] =
895+
convert(BidiagonalMatrixRow{FT}, ᶜright_bias_matrix())
896+
@. bdmr[colidx] = ifelse(
897+
ᶜu₃ʲ[colidx].components.data.:1 > 0,
898+
bdmr_l[colidx],
899+
bdmr_r[colidx],
900+
)
901+
887902
@. ᶠtridiagonal_matrix_c3[colidx] =
888-
-(ᶠgradᵥ_matrix()) ifelse(
889-
ᶜu₃ʲ[colidx].components.data.:1 > 0,
890-
convert(BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix()),
891-
convert(BidiagonalMatrixRow{FT}, ᶜright_bias_matrix()),
892-
)
903+
-(ᶠgradᵥ_matrix()) bdmr[colidx]
893904
if p.atmos.rayleigh_sponge isa RayleighSponge
894905
@. ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ[colidx] =
895906
dtγ * (

0 commit comments

Comments
 (0)