Skip to content

Commit 173b3df

Browse files
all GPU implicit tests pass
1 parent 807f5f4 commit 173b3df

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

test/gpu/simple_gpu.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using OrdinaryDiffEq, CuArrays, LinearAlgebra, Test
2+
Base.:+(A::CuArray,I::UniformScaling) = A + CuArray(I,size(A,1),size(A,2))
3+
Base.:-(A::CuArray,I::UniformScaling) = A + CuArray(I,size(A,1),size(A,2))
24
function f(u,p,t)
35
A*u
46
end
@@ -11,35 +13,40 @@ end
1113
function jac(u,p,t)
1214
A
1315
end
14-
ff = ODEFunction(f,jac=jac)
16+
function tgrad(du,u,p,t)
17+
du .= 0
18+
end
19+
function tgrad(u,p,t)
20+
zero(u)
21+
end
22+
ff = ODEFunction(f,jac=jac,tgrad=tgrad)
1523
CuArrays.allowscalar(false)
1624
A = cu(-rand(3,3))
1725
u0 = cu([1.0;0.0;0.0])
1826
tspan = (0f0,100f0)
1927

2028
prob = ODEProblem(ff,u0,tspan)
2129
sol = solve(prob,Tsit5())
22-
@test_broken solve(prob,Rosenbrock23()).retcode == :Success
30+
@test solve(prob,Rosenbrock23()).retcode == :Success
2331
solve(prob,Rosenbrock23(autodiff=false))
2432

2533
prob_oop = ODEProblem{false}(ff,u0,tspan)
2634
CuArrays.allowscalar(false)
2735
sol = solve(prob_oop,Tsit5())
28-
@test_broken solve(prob_oop,Rosenbrock23()).retcode == :Success
29-
@test_broken solve(prob_oop,Rosenbrock23(autodiff=false))
36+
@test solve(prob_oop,Rosenbrock23()).retcode == :Success
37+
@test solve(prob_oop,Rosenbrock23(autodiff=false)).retcode == :Success
3038

3139
prob_nojac = ODEProblem(f,u0,tspan)
32-
@test_broken solve(prob_nojac,Rosenbrock23()).retcode == :Success
40+
@test solve(prob_nojac,Rosenbrock23()).retcode == :Success
3341
@test solve(prob_nojac,Rosenbrock23(autodiff=false)).retcode == :Success
3442
@test solve(prob_nojac,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success
3543
@test solve(prob_nojac,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success
3644

3745
prob_nojac_oop = ODEProblem{false}(f,u0,tspan)
38-
@test_broken solve(prob_nojac_oop,Rosenbrock23()).retcode == :Success
39-
@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false)).retcode == :Success
40-
@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success
41-
# hits a generic matmul fallback
42-
@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success
46+
@test solve(prob_nojac_oop,Rosenbrock23()).retcode == :Success
47+
@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false)).retcode == :Success
48+
@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success
49+
@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success
4350

4451
# Test auto-offload
4552
_A = -rand(3,3)

0 commit comments

Comments
 (0)