Skip to content

Commit cbcd1f9

Browse files
Merge pull request #317 from CliMA/ck/rm_nvtx_range
Add broken test for inference failure in callbacks loop
2 parents cdcab7f + 420f606 commit cbcd1f9

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

perf/jet.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ArgParse, JET, Test, BenchmarkTools, DiffEqBase, ClimaTimeSteppers
1+
# using Revise; include("perf/jet.jl")
2+
using ArgParse, JET, Test, BenchmarkTools, SciMLBase, ClimaTimeSteppers
23
import ClimaTimeSteppers as CTS
34
function parse_commandline()
45
s = ArgParse.ArgParseSettings()
@@ -15,10 +16,29 @@ end
1516
cts = joinpath(dirname(@__DIR__));
1617
include(joinpath(cts, "test", "problems.jl"))
1718
config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob)
19+
20+
struct Foo end
21+
foo!(integrator) = nothing
22+
(::Foo)(integrator) = foo!(integrator)
23+
struct Bar end
24+
bar!(integrator) = nothing
25+
(::Bar)(integrator) = bar!(integrator)
26+
27+
function discrete_cb(cb!, n)
28+
cond = if n == 1
29+
(u, t, integrator) -> isnothing(cb!(integrator))
30+
else
31+
(u, t, integrator) -> isnothing(cb!(integrator)) || rand() 0.5
32+
end
33+
SciMLBase.DiscreteCallback(cond, cb!;)
34+
end
1835
function config_integrators(problem)
1936
algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
2037
dt = 0.01
21-
integrator = DiffEqBase.init(problem, algorithm; dt)
38+
discrete_callbacks = (discrete_cb(Foo(), 0), discrete_cb(Bar(), 0), discrete_cb(Foo(), 1), discrete_cb(Bar(), 1))
39+
callback = SciMLBase.CallbackSet((), discrete_callbacks)
40+
41+
integrator = SciMLBase.init(problem, algorithm; dt, callback)
2242
integrator.cache = CTS.init_cache(problem, algorithm)
2343
return (; integrator)
2444
end
@@ -33,7 +53,12 @@ else
3353
end
3454
(; integrator) = config_integrators(prob)
3555

36-
CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs
37-
step_allocs = @allocated CTS.step_u!(integrator, integrator.cache)
38-
@show step_allocs
39-
JET.@test_opt CTS.step_u!(integrator, integrator.cache)
56+
@testset "JET / allocations" begin
57+
CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs
58+
step_allocs = @allocated CTS.step_u!(integrator, integrator.cache)
59+
@show step_allocs
60+
JET.@test_opt CTS.step_u!(integrator, integrator.cache)
61+
62+
CTS.__step!(integrator) # compile first, and make sure it runs
63+
JET.@test_opt broken = true CTS.__step!(integrator)
64+
end

0 commit comments

Comments
 (0)