Skip to content

Commit 1a6fdd4

Browse files
Merge pull request #318 from CliMA/ck/callback_inference
Fix inference failure in callbacks
2 parents cbcd1f9 + 46bc872 commit 1a6fdd4

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.7.37"
4+
version = "0.7.38"
55

66
[deps]
77
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"

perf/jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,5 @@ end
6060
JET.@test_opt CTS.step_u!(integrator, integrator.cache)
6161

6262
CTS.__step!(integrator) # compile first, and make sure it runs
63-
JET.@test_opt broken = true CTS.__step!(integrator)
63+
JET.@test_opt CTS.__step!(integrator)
6464
end

src/integrators.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,17 @@ is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) < zero(integrat
225225
reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) =
226226
integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop))
227227

228+
229+
@inline unrolled_foreach(::Tuple{}, integrator) = nothing
230+
@inline unrolled_foreach(callback, integrator) =
231+
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
232+
@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) =
233+
unrolled_foreach(first(discrete_callbacks), integrator)
234+
@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator)
235+
unrolled_foreach(first(discrete_callbacks), integrator)
236+
unrolled_foreach(Base.tail(discrete_callbacks), integrator)
237+
end
238+
228239
function __step!(integrator)
229240
(; _dt, dtchangeable, tstops) = integrator
230241

@@ -246,13 +257,7 @@ function __step!(integrator)
246257

247258
# apply callbacks
248259
discrete_callbacks = integrator.callback.discrete_callbacks
249-
for (ncb, callback) in enumerate(discrete_callbacks)
250-
if callback.condition(integrator.u, integrator.t, integrator)::Bool
251-
NVTX.@range "Callback $ncb of $(length(discrete_callbacks))" color = colorant"yellow" begin
252-
callback.affect!(integrator)
253-
end
254-
end
255-
end
260+
unrolled_foreach(discrete_callbacks, integrator)
256261

257262
# remove tstops that were just reached
258263
while !isempty(tstops) && reached_tstop(integrator, first(tstops))

0 commit comments

Comments
 (0)