Skip to content

Commit a256536

Browse files
Merge pull request #223 from CliMA/ck/rm_nvtx
Add `post_explicit!` call to end of step, rm nvtx range macros
2 parents efec3c3 + 57c3828 commit a256536

File tree

3 files changed

+90
-132
lines changed

3 files changed

+90
-132
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.7.11"
4+
version = "0.7.12"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/solvers/imex_ark.jl

Lines changed: 88 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -59,177 +59,134 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
5959
s = length(b_exp)
6060

6161
if !isnothing(T_imp!) && !isnothing(newtons_method)
62-
NVTX.@range "update!" color = colorant"yellow" begin
63-
(; update_j) = newtons_method
64-
jacobian = newtons_method_cache.j
65-
if (!isnothing(jacobian)) && needs_update!(update_j, NewTimeStep(t))
66-
if γ isa Nothing
67-
sdirk_error(name)
68-
else
69-
T_imp!.Wfact(jacobian, u, p, dt * γ, t)
70-
end
62+
(; update_j) = newtons_method
63+
jacobian = newtons_method_cache.j
64+
if (!isnothing(jacobian)) && needs_update!(update_j, NewTimeStep(t))
65+
if γ isa Nothing
66+
sdirk_error(name)
67+
else
68+
T_imp!.Wfact(jacobian, u, p, dt * γ, t)
7169
end
7270
end
7371
end
7472

7573
for i in 1:s
76-
NVTX.@range "stage" payload = i begin
77-
t_exp = t + dt * c_exp[i]
78-
t_imp = t + dt * c_imp[i]
74+
t_exp = t + dt * c_exp[i]
75+
t_imp = t + dt * c_imp[i]
7976

80-
NVTX.@range "U = u" color = colorant"yellow" begin
81-
@. U = u
82-
end
77+
@. U = u
8378

84-
if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
85-
NVTX.@range "U+=dt*a_exp*T_lim" color = colorant"yellow" begin
86-
for j in 1:(i - 1)
87-
iszero(a_exp[i, j]) && continue
88-
@. U += dt * a_exp[i, j] * T_lim[j]
89-
end
90-
end
91-
NVTX.@range "lim!" color = colorant"yellow" begin
92-
lim!(U, p, t_exp, u)
93-
end
94-
end
95-
96-
if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
97-
NVTX.@range "U+=dt*a_exp*T_exp" color = colorant"yellow" begin
98-
for j in 1:(i - 1)
99-
iszero(a_exp[i, j]) && continue
100-
@. U += dt * a_exp[i, j] * T_exp[j]
101-
end
102-
end
79+
if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
80+
for j in 1:(i - 1)
81+
iszero(a_exp[i, j]) && continue
82+
@. U += dt * a_exp[i, j] * T_lim[j]
10383
end
84+
lim!(U, p, t_exp, u)
85+
end
10486

105-
if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
106-
NVTX.@range "U+=dt*a_imp*T_imp" color = colorant"yellow" begin
107-
for j in 1:(i - 1)
108-
iszero(a_imp[i, j]) && continue
109-
@. U += dt * a_imp[i, j] * T_imp[j]
110-
end
111-
end
87+
if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
88+
for j in 1:(i - 1)
89+
iszero(a_exp[i, j]) && continue
90+
@. U += dt * a_exp[i, j] * T_exp[j]
11291
end
92+
end
11393

114-
NVTX.@range "dss!" color = colorant"yellow" begin
115-
dss!(U, p, t_exp)
94+
if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
95+
for j in 1:(i - 1)
96+
iszero(a_imp[i, j]) && continue
97+
@. U += dt * a_imp[i, j] * T_imp[j]
11698
end
99+
end
117100

118-
if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve
119-
post_explicit!(U, p, t_imp)
120-
else
121-
@assert !isnothing(newtons_method)
122-
NVTX.@range "temp = U" color = colorant"yellow" begin
123-
@. temp = U
101+
dss!(U, p, t_exp)
102+
103+
if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve
104+
post_explicit!(U, p, t_imp)
105+
else
106+
@assert !isnothing(newtons_method)
107+
@. temp = U
108+
post_explicit!(U, p, t_imp)
109+
# TODO: can/should we remove these closures?
110+
implicit_equation_residual! =
111+
(residual, Ui) -> begin
112+
T_imp!(residual, Ui, p, t_imp)
113+
@. residual = temp + dt * a_imp[i, i] * residual - Ui
124114
end
125-
post_explicit!(U, p, t_imp)
126-
# TODO: can/should we remove these closures?
127-
implicit_equation_residual! =
128-
(residual, Ui) -> begin
129-
NVTX.@range "T_imp!" color = colorant"yellow" begin
130-
T_imp!(residual, Ui, p, t_imp)
131-
end
132-
NVTX.@range "residual=temp+dt*a_imp*residual-Ui" color = colorant"yellow" begin
133-
@. residual = temp + dt * a_imp[i, i] * residual - Ui
134-
end
115+
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
116+
call_post_implicit! = Ui -> begin
117+
post_implicit!(Ui, p, t_imp)
118+
end
119+
call_post_implicit_last! =
120+
Ui -> begin
121+
if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
122+
# If T_imp[i] is being treated implicitly, ensure that it
123+
# exactly satisfies the implicit equation.
124+
@. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
135125
end
136-
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
137-
call_post_implicit! = Ui -> begin
138126
post_implicit!(Ui, p, t_imp)
139127
end
140-
call_post_implicit_last! =
141-
Ui -> begin
142-
if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
143-
# If T_imp[i] is being treated implicitly, ensure that it
144-
# exactly satisfies the implicit equation.
145-
@. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
146-
end
147-
post_implicit!(Ui, p, t_imp)
148-
end
149128

150-
NVTX.@range "solve_newton!" color = colorant"yellow" begin
151-
solve_newton!(
152-
newtons_method,
153-
newtons_method_cache,
154-
U,
155-
implicit_equation_residual!,
156-
implicit_equation_jacobian!,
157-
call_post_implicit!,
158-
call_post_implicit_last!,
159-
)
160-
end
161-
end
129+
solve_newton!(
130+
newtons_method,
131+
newtons_method_cache,
132+
U,
133+
implicit_equation_residual!,
134+
implicit_equation_jacobian!,
135+
call_post_implicit!,
136+
call_post_implicit_last!,
137+
)
138+
end
162139

163-
# We do not need to DSS U again because the implicit solve should
164-
# give the same results for redundant columns (as long as the implicit
165-
# tendency only acts in the vertical direction).
140+
# We do not need to DSS U again because the implicit solve should
141+
# give the same results for redundant columns (as long as the implicit
142+
# tendency only acts in the vertical direction).
166143

167-
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
168-
if iszero(a_imp[i, i]) && !isnothing(T_imp!)
169-
# If its coefficient is 0, T_imp[i] is effectively being
170-
# treated explicitly.
171-
NVTX.@range "T_imp!" color = colorant"yellow" begin
172-
T_imp!(T_imp[i], U, p, t_imp)
173-
end
174-
end
144+
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
145+
if iszero(a_imp[i, i]) && !isnothing(T_imp!)
146+
# If its coefficient is 0, T_imp[i] is effectively being
147+
# treated explicitly.
148+
T_imp!(T_imp[i], U, p, t_imp)
175149
end
150+
end
176151

177-
if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
178-
if !isnothing(T_lim!)
179-
NVTX.@range "T_lim!" color = colorant"yellow" begin
180-
T_lim!(T_lim[i], U, p, t_exp)
181-
end
182-
end
183-
if !isnothing(T_exp!)
184-
NVTX.@range "T_exp!" color = colorant"yellow" begin
185-
T_exp!(T_exp[i], U, p, t_exp)
186-
end
187-
end
152+
if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
153+
if !isnothing(T_lim!)
154+
T_lim!(T_lim[i], U, p, t_exp)
155+
end
156+
if !isnothing(T_exp!)
157+
T_exp!(T_exp[i], U, p, t_exp)
188158
end
189159
end
190160
end
191161

192162
t_final = t + dt
193163

194164
if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
195-
NVTX.@range "temp=u" color = colorant"yellow" begin
196-
@. temp = u
197-
end
198-
NVTX.@range "temp+=dt*b_exp*T_lim" color = colorant"yellow" begin
199-
for j in 1:s
200-
iszero(b_exp[j]) && continue
201-
@. temp += dt * b_exp[j] * T_lim[j]
202-
end
203-
end
204-
NVTX.@range "lim!" color = colorant"yellow" begin
205-
lim!(temp, p, t_final, u)
206-
end
207-
NVTX.@range "u=temp" color = colorant"yellow" begin
208-
@. u = temp
165+
@. temp = u
166+
for j in 1:s
167+
iszero(b_exp[j]) && continue
168+
@. temp += dt * b_exp[j] * T_lim[j]
209169
end
170+
lim!(temp, p, t_final, u)
171+
@. u = temp
210172
end
211173

212174
if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
213-
NVTX.@range "u+=dt*b_exp*T_exp" color = colorant"yellow" begin
214-
for j in 1:s
215-
iszero(b_exp[j]) && continue
216-
@. u += dt * b_exp[j] * T_exp[j]
217-
end
175+
for j in 1:s
176+
iszero(b_exp[j]) && continue
177+
@. u += dt * b_exp[j] * T_exp[j]
218178
end
219179
end
220180

221181
if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
222-
NVTX.@range "u+=dt*b_imp*T_imp" color = colorant"yellow" begin
223-
for j in 1:s
224-
iszero(b_imp[j]) && continue
225-
@. u += dt * b_imp[j] * T_imp[j]
226-
end
182+
for j in 1:s
183+
iszero(b_imp[j]) && continue
184+
@. u += dt * b_imp[j] * T_imp[j]
227185
end
228186
end
229187

230-
NVTX.@range "dss!" color = colorant"yellow" begin
231-
dss!(u, p, t_final)
232-
end
188+
dss!(u, p, t_final)
189+
post_explicit!(u, p, t_final)
233190

234191
return u
235192
end

src/solvers/imex_ssprk.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
185185
end
186186

187187
dss!(u, p, t_final)
188+
post_explicit!(u, p, t_final)
188189

189190
return u
190191
end

0 commit comments

Comments
 (0)