Skip to content

Commit 0928441

Browse files
Add SciML integration tests (#2593)
* Add SciML integration tests * add compats * Update test/integration/SciML/Project.toml * Update runtests.jl * enzyme fails in a testset * move a bit
1 parent 8a2fa2e commit 0928441

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

.github/workflows/Integration.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ jobs:
5050
- Distributions
5151
- DynamicExpressions
5252
- Lux
53+
- SciML
5354
steps:
5455
- uses: actions/checkout@v5
5556
- uses: julia-actions/setup-julia@v2
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[deps]
2+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
5+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
8+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
10+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
11+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
14+
15+
[sources]
16+
Enzyme = {path = "../../.."}
17+
EnzymeCore = {path = "../../../lib/EnzymeCore"}
18+
19+
[compat]
20+
DiffEqBase = "6.190"
21+
ForwardDiff = "0.10.36, 1"
22+
LinearSolve = "3.12"
23+
OrdinaryDiffEq = "6.89"
24+
OrdinaryDiffEqTsit5 = "1.1"
25+
SciMLSensitivity = "7.69"
26+
StaticArrays = "1.9"
27+
Zygote = "0.7.10"

test/integration/SciML/runtests.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
using Enzyme, OrdinaryDiffEqTsit5, StaticArrays, DiffEqBase, ForwardDiff, Test
2+
using OrdinaryDiffEq, SciMLSensitivity, Zygote
3+
using LinearSolve, LinearAlgebra
4+
5+
@testset "Direct Differentiation of Explicit ODE Solve" begin
6+
function lorenz!(du, u, p, t)
7+
du[1] = 10.0(u[2] - u[1])
8+
du[2] = u[1] * (28.0 - u[3]) - u[2]
9+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
10+
end
11+
12+
_saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
13+
14+
function f_dt(y::Array{Float64}, u0::Array{Float64})
15+
tspan = (0.0, 3.0)
16+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
17+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
18+
y .= sol[1,:]
19+
return nothing
20+
end;
21+
22+
function f_dt(u0)
23+
tspan = (0.0, 3.0)
24+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
25+
sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
26+
sol[1,:]
27+
end;
28+
29+
u0 = [1.0; 0.0; 0.0]
30+
fdj = ForwardDiff.jacobian(f_dt, u0)
31+
32+
ezj = stack(map(1:3) do i
33+
d_u0 = zeros(3)
34+
dy = zeros(13)
35+
y = zeros(13)
36+
d_u0[i] = 1.0
37+
Enzyme.autodiff(Forward, f_dt, Duplicated(y, dy), Duplicated(u0, d_u0));
38+
dy
39+
end)
40+
41+
@test ezj fdj
42+
43+
function f_dt2(u0)
44+
tspan = (0.0, 3.0)
45+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
46+
sol = DiffEqBase.solve(prob, Tsit5(), dt=0.1, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)
47+
sum(sol[1,:])
48+
end
49+
50+
fdg = ForwardDiff.gradient(f_dt2, u0)
51+
d_u0 = zeros(3)
52+
Enzyme.autodiff(Reverse, f_dt2, Active, Duplicated(u0, d_u0));
53+
54+
@test d_u0 fdg
55+
end
56+
57+
odef(du, u, p, t) = du .= u .* p
58+
prob = ODEProblem(odef, [2.0], (0.0, 1.0), [3.0])
59+
struct senseloss0{T}
60+
sense::T
61+
end
62+
function (f::senseloss0)(u0p)
63+
prob = ODEProblem{true}(odef, u0p[1:1], (0.0, 1.0), u0p[2:2])
64+
sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
65+
end
66+
67+
@testset "SciMLSensitivity Adjoint Interface" begin
68+
u0p = [2.0, 3.0]
69+
du0p = zeros(2)
70+
@test senseloss0(InterpolatingAdjoint())(u0p) isa Number
71+
dup = Zygote.gradient(senseloss0(InterpolatingAdjoint()), u0p)[1]
72+
Enzyme.autodiff(Reverse, senseloss0(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))
73+
@test du0p dup
74+
end
75+
76+
@testset "LinearSolve Adjoints" begin
77+
n = 4
78+
A = rand(n, n);
79+
dA = zeros(n, n);
80+
b1 = rand(n);
81+
db1 = zeros(n);
82+
83+
function f(A, b1; alg = LUFactorization())
84+
prob = LinearProblem(A, b1)
85+
86+
sol1 = solve(prob, alg)
87+
88+
s1 = sol1.u
89+
norm(s1)
90+
end
91+
92+
f(A, b1) # Uses BLAS
93+
94+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
95+
dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1)), copy(A))
96+
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
97+
98+
@test dA dA2
99+
@test db1 db12
100+
end

0 commit comments

Comments
 (0)