Skip to content

Commit 3d1e57a

Browse files
committed
Separate implicit methods for CUDA only
1 parent 09085d7 commit 3d1e57a

File tree

2 files changed

+34
-29
lines changed

2 files changed

+34
-29
lines changed

test/ensemblegpuarray.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ solve(monteprob,TRBDF2(),EnsembleGPUArray(backend),dt=0.1,trajectories=2,saveat=
4343
@test_broken solve(monteprob,TRBDF2(),EnsembleGPUArray(backend),dt=0.1,trajectories=2,saveat=1.0f0)
4444
=#
4545

46-
GROUP == "AMDGPU" && return
47-
4846
@info "Implicit Methods"
4947

5048
function lorenz_jac(J, u, p, t)
@@ -76,15 +74,18 @@ monteprob_jac = EnsembleProblem(prob_jac, prob_func = prob_func)
7674
@time solve(monteprob_jac, Rodas5(), EnsembleCPUArray(), dt = 0.1,
7775
trajectories = 10,
7876
saveat = 1.0f0)
79-
@time solve(monteprob_jac, Rodas5(), EnsembleGPUArray(backend), dt = 0.1,
80-
trajectories = 10,
81-
saveat = 1.0f0)
8277
@time solve(monteprob_jac, TRBDF2(), EnsembleCPUArray(), dt = 0.1,
8378
trajectories = 10,
8479
saveat = 1.0f0)
85-
@time solve(monteprob_jac, TRBDF2(), EnsembleGPUArray(backend), dt = 0.1,
86-
trajectories = 10,
87-
saveat = 1.0f0)
80+
81+
if GROUP == "CUDA"
82+
@time solve(monteprob_jac, Rodas5(), EnsembleGPUArray(backend), dt = 0.1,
83+
trajectories = 10,
84+
saveat = 1.0f0)
85+
@time solve(monteprob_jac, TRBDF2(), EnsembleGPUArray(backend), dt = 0.1,
86+
trajectories = 10,
87+
saveat = 1.0f0)
88+
end
8889

8990
@info "Callbacks"
9091

@@ -185,18 +186,19 @@ sol = solve(rober_prob, Rodas5(), abstol = 1.0f-8, reltol = 1.0f-8)
185186
sol = solve(rober_prob, TRBDF2(), abstol = 1.0f-4, reltol = 1.0f-1)
186187
rober_monteprob = EnsembleProblem(rober_prob, prob_func = prob_func)
187188

188-
# TODO: Does not work with Linearsolve.jl v1.35.0 https://github.com/SciML/DiffEqGPU.jl/pull/229
189+
if GROUP == "CUDA"
190+
@time sol = solve(rober_monteprob, Rodas5(),
191+
EnsembleGPUArray(backend), trajectories = 10,
192+
saveat = 1.0f0,
193+
abstol = 1.0f-8,
194+
reltol = 1.0f-8)
195+
@time sol = solve(rober_monteprob, TRBDF2(),
196+
EnsembleGPUArray(backend), trajectories = 10,
197+
saveat = 1.0f0,
198+
abstol = 1.0f-4,
199+
reltol = 1.0f-1)
200+
end
189201

190-
@time sol = solve(rober_monteprob, Rodas5(),
191-
EnsembleGPUArray(backend), trajectories = 10,
192-
saveat = 1.0f0,
193-
abstol = 1.0f-8,
194-
reltol = 1.0f-8)
195-
@time sol = solve(rober_monteprob, TRBDF2(),
196-
EnsembleGPUArray(backend), trajectories = 10,
197-
saveat = 1.0f0,
198-
abstol = 1.0f-4,
199-
reltol = 1.0f-1)
200202
@time sol = solve(rober_monteprob, TRBDF2(), EnsembleThreads(),
201203
trajectories = 10,
202204
abstol = 1e-4, reltol = 1e-1, saveat = 1.0f0)
@@ -243,5 +245,8 @@ monteprob = EnsembleProblem(prob_jac,
243245
sol = solve(monteprob, Tsit5(), EnsembleGPUArray(backend, 0.0), trajectories = 10,
244246
adaptive = false, dt = 0.01f0, save_everystep = false)
245247

246-
sol = solve(monteprob, Rosenbrock23(), EnsembleGPUArray(backend, 0.0), trajectories = 10,
247-
adaptive = false, dt = 0.01f0, save_everystep = false)
248+
if GROUP == "CUDA"
249+
sol = solve(monteprob, Rosenbrock23(), EnsembleGPUArray(backend, 0.0),
250+
trajectories = 10,
251+
adaptive = false, dt = 0.01f0, save_everystep = false)
252+
end

test/ensemblegpuarray_oop.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
3535
@time sol = solve(monteprob, Tsit5(), EnsembleGPUArray(backend), trajectories = 10_000,
3636
saveat = 1.0f0)
3737

38-
GROUP == "AMDGPU" && return
39-
40-
@time sol = solve(monteprob, Rosenbrock23(), EnsembleGPUArray(backend),
41-
trajectories = 10_000,
42-
saveat = 1.0f0)
43-
@time sol = solve(monteprob, TRBDF2(), EnsembleGPUArray(backend),
44-
trajectories = 10_000,
45-
saveat = 1.0f0)
38+
if GROUP == "CUDA"
39+
@time sol = solve(monteprob, Rosenbrock23(), EnsembleGPUArray(backend),
40+
trajectories = 10_000,
41+
saveat = 1.0f0)
42+
@time sol = solve(monteprob, TRBDF2(), EnsembleGPUArray(backend),
43+
trajectories = 10_000,
44+
saveat = 1.0f0)
45+
end

0 commit comments

Comments
 (0)