|
1 | 1 | using OrdinaryDiffEq, CuArrays, LinearAlgebra, Test
|
| 2 | +Base.:+(A::CuArray,I::UniformScaling) = A + CuArray(I,size(A,1),size(A,2)) |
| 3 | +Base.:-(A::CuArray,I::UniformScaling) = A + CuArray(I,size(A,1),size(A,2)) |
2 | 4 | function f(u,p,t)
|
3 | 5 | A*u
|
4 | 6 | end
|
|
11 | 13 | function jac(u,p,t)
|
12 | 14 | A
|
13 | 15 | end
|
14 |
| -ff = ODEFunction(f,jac=jac) |
| 16 | +function tgrad(du,u,p,t) |
| 17 | + du .= 0 |
| 18 | +end |
| 19 | +function tgrad(u,p,t) |
| 20 | + zero(u) |
| 21 | +end |
| 22 | +ff = ODEFunction(f,jac=jac,tgrad=tgrad) |
15 | 23 | CuArrays.allowscalar(false)
|
16 | 24 | A = cu(-rand(3,3))
|
17 | 25 | u0 = cu([1.0;0.0;0.0])
|
18 | 26 | tspan = (0f0,100f0)
|
19 | 27 |
|
20 | 28 | prob = ODEProblem(ff,u0,tspan)
|
21 | 29 | sol = solve(prob,Tsit5())
|
22 |
| -@test_broken solve(prob,Rosenbrock23()).retcode == :Success |
| 30 | +@test solve(prob,Rosenbrock23()).retcode == :Success |
23 | 31 | solve(prob,Rosenbrock23(autodiff=false))
|
24 | 32 |
|
25 | 33 | prob_oop = ODEProblem{false}(ff,u0,tspan)
|
26 | 34 | CuArrays.allowscalar(false)
|
27 | 35 | sol = solve(prob_oop,Tsit5())
|
28 |
| -@test_broken solve(prob_oop,Rosenbrock23()).retcode == :Success |
29 |
| -@test_broken solve(prob_oop,Rosenbrock23(autodiff=false)) |
| 36 | +@test solve(prob_oop,Rosenbrock23()).retcode == :Success |
| 37 | +@test solve(prob_oop,Rosenbrock23(autodiff=false)).retcode == :Success |
30 | 38 |
|
31 | 39 | prob_nojac = ODEProblem(f,u0,tspan)
|
32 |
| -@test_broken solve(prob_nojac,Rosenbrock23()).retcode == :Success |
| 40 | +@test solve(prob_nojac,Rosenbrock23()).retcode == :Success |
33 | 41 | @test solve(prob_nojac,Rosenbrock23(autodiff=false)).retcode == :Success
|
34 | 42 | @test solve(prob_nojac,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success
|
35 | 43 | @test solve(prob_nojac,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success
|
36 | 44 |
|
37 | 45 | prob_nojac_oop = ODEProblem{false}(f,u0,tspan)
|
38 |
| -@test_broken solve(prob_nojac_oop,Rosenbrock23()).retcode == :Success |
39 |
| -@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false)).retcode == :Success |
40 |
| -@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success |
41 |
| -# hits a generic matmul fallback |
42 |
| -@test_broken solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success |
| 46 | +@test solve(prob_nojac_oop,Rosenbrock23()).retcode == :Success |
| 47 | +@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false)).retcode == :Success |
| 48 | +@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:central})).retcode == :Success |
| 49 | +@test solve(prob_nojac_oop,Rosenbrock23(autodiff=false,diff_type = Val{:complex})).retcode == :Success |
43 | 50 |
|
44 | 51 | # Test auto-offload
|
45 | 52 | _A = -rand(3,3)
|
|
0 commit comments