Skip to content

Commit a9f73bb

Browse files
Merge pull request #1107 from mattsignorelli/addgtpsa
Add GTPSA extension
2 parents ceec4c0 + c2d5509 commit a9f73bb

File tree

5 files changed

+61
-0
lines changed

5 files changed

+61
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4141
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4242
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4343
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
44+
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
4445
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4546
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
4647
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -55,6 +56,7 @@ DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
5556
DiffEqBaseDistributionsExt = "Distributions"
5657
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
5758
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
59+
DiffEqBaseGTPSAExt = "GTPSA"
5860
DiffEqBaseMPIExt = "MPI"
5961
DiffEqBaseMeasurementsExt = "Measurements"
6062
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
@@ -81,6 +83,7 @@ ForwardDiff = "0.10"
8183
FunctionWrappers = "1.0"
8284
FunctionWrappersWrappers = "0.1"
8385
GeneralizedGenerated = "0.3"
86+
GTPSA = "1.3"
8487
LinearAlgebra = "1.9"
8588
Logging = "1.9"
8689
MPI = "0.20"

ext/DiffEqBaseGTPSAExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module DiffEqBaseGTPSAExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DiffEqBase
5+
import DiffEqBase: value
6+
using GTPSA
7+
else
8+
using ..DiffEqBase
9+
import ..DiffEqBase: value
10+
using ..GTPSA
11+
end
12+
13+
value(x::TPS) = scalar(x);
14+
value(::Type{TPS{T}}) where {T} = T
15+
16+
17+
end

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
55
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
66
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
89
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
910
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1011
MultiScaleArrays = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"

test/downstream/gtpsa.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test
2+
3+
f!(du, u, p, t) = du .= p .* u
4+
5+
# Initial variables and parameters
6+
x = [1.0, 2.0, 3.0]
7+
p = [4.0, 5.0, 6.0]
8+
9+
prob = ODEProblem(f!, x, (0.0, 1.0), p)
10+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
11+
12+
# Parametric GTPSA map
13+
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
14+
dx = vars(desc)
15+
dp = params(desc)
16+
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
17+
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)
18+
19+
@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
20+
21+
# Compare Jacobian against ForwardDiff
22+
J_FD = ForwardDiff.jacobian([x..., p...]) do t
23+
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
24+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
25+
sol.u[end]
26+
end
27+
28+
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)
29+
30+
# Compare Hessians against ForwardDiff
31+
for i in 1:3
32+
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
33+
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
34+
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
35+
sol.u[end][i]
36+
end
37+
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
38+
end
39+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ end
5555
@time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl")
5656
@time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl")
5757
@time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl")
58+
@time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl")
5859
end
5960

6061
if !is_APPVEYOR && GROUP == "Downstream2"

0 commit comments

Comments
 (0)