|
1 | | -using ArgParse, JET, Test, BenchmarkTools, DiffEqBase, ClimaTimeSteppers |
| 1 | +# using Revise; include("perf/jet.jl") |
| 2 | +using ArgParse, JET, Test, BenchmarkTools, SciMLBase, ClimaTimeSteppers |
2 | 3 | import ClimaTimeSteppers as CTS |
3 | 4 | function parse_commandline() |
4 | 5 | s = ArgParse.ArgParseSettings() |
|
15 | 16 | cts = joinpath(dirname(@__DIR__)); |
16 | 17 | include(joinpath(cts, "test", "problems.jl")) |
17 | 18 | config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob) |
| 19 | + |
| 20 | +struct Foo end |
| 21 | +foo!(integrator) = nothing |
| 22 | +(::Foo)(integrator) = foo!(integrator) |
| 23 | +struct Bar end |
| 24 | +bar!(integrator) = nothing |
| 25 | +(::Bar)(integrator) = bar!(integrator) |
| 26 | + |
| 27 | +function discrete_cb(cb!, n) |
| 28 | + cond = if n == 1 |
| 29 | + (u, t, integrator) -> isnothing(cb!(integrator)) |
| 30 | + else |
| 31 | + (u, t, integrator) -> isnothing(cb!(integrator)) || rand() ≤ 0.5 |
| 32 | + end |
| 33 | + SciMLBase.DiscreteCallback(cond, cb!;) |
| 34 | +end |
18 | 35 | function config_integrators(problem) |
19 | 36 | algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2)) |
20 | 37 | dt = 0.01 |
21 | | - integrator = DiffEqBase.init(problem, algorithm; dt) |
| 38 | + discrete_callbacks = (discrete_cb(Foo(), 0), discrete_cb(Bar(), 0), discrete_cb(Foo(), 1), discrete_cb(Bar(), 1)) |
| 39 | + callback = SciMLBase.CallbackSet((), discrete_callbacks) |
| 40 | + |
| 41 | + integrator = SciMLBase.init(problem, algorithm; dt, callback) |
22 | 42 | integrator.cache = CTS.init_cache(problem, algorithm) |
23 | 43 | return (; integrator) |
24 | 44 | end |
|
33 | 53 | end |
34 | 54 | (; integrator) = config_integrators(prob) |
35 | 55 |
|
36 | | -CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs |
37 | | -step_allocs = @allocated CTS.step_u!(integrator, integrator.cache) |
38 | | -@show step_allocs |
39 | | -JET.@test_opt CTS.step_u!(integrator, integrator.cache) |
| 56 | +@testset "JET / allocations" begin |
| 57 | + CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs |
| 58 | + step_allocs = @allocated CTS.step_u!(integrator, integrator.cache) |
| 59 | + @show step_allocs |
| 60 | + JET.@test_opt CTS.step_u!(integrator, integrator.cache) |
| 61 | + |
| 62 | + CTS.__step!(integrator) # compile first, and make sure it runs |
| 63 | + JET.@test_opt broken = true CTS.__step!(integrator) |
| 64 | +end |
0 commit comments