Skip to content

Commit c2d5509

Browse files
Add GTPSA integration tests
1 parent a26018c commit c2d5509

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
55
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
66
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
89
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
910
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1011
MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"

test/downstream/gtpsa.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test
2+
3+
f!(du, u, p, t) = du .= p .* u
4+
5+
# Initial variables and parameters
6+
x = [1.0, 2.0, 3.0]
7+
p = [4.0, 5.0, 6.0]
8+
9+
prob = ODEProblem(f!, x, (0.0, 1.0), p)
10+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
11+
12+
# Parametric GTPSA map
13+
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
14+
dx = vars(desc)
15+
dp = params(desc)
16+
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
17+
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)
18+
19+
@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
20+
21+
# Compare Jacobian against ForwardDiff
22+
J_FD = ForwardDiff.jacobian([x..., p...]) do t
23+
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
24+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
25+
sol.u[end]
26+
end
27+
28+
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)
29+
30+
# Compare Hessians against ForwardDiff
31+
for i in 1:3
32+
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
33+
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
34+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
35+
sol.u[end][i]
36+
end
37+
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
38+
end
39+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ end
5555
@time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl")
5656
@time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl")
5757
@time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
58+
@time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
5859
end
5960

6061
if !is_APPVEYOR && GROUP == "Downstream2"

0 commit comments

Comments
 (0)