Skip to content

Commit 419a2b5

Browse files
Merge pull request #1108 from mattsignorelli/gtpsa-2
Add `ODE_DEFAULT_NORM` overload for GTPSA
2 parents a9f73bb + 0f8481c commit 419a2b5

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

ext/DiffEqBaseGTPSAExt.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,34 @@ module DiffEqBaseGTPSAExt
22

33
if isdefined(Base, :get_extension)
44
using DiffEqBase
5-
import DiffEqBase: value
5+
import DiffEqBase: value, ODE_DEFAULT_NORM
66
using GTPSA
77
else
88
using ..DiffEqBase
9-
import ..DiffEqBase: value
9+
import ..DiffEqBase: value, ODE_DEFAULT_NORM
1010
using ..GTPSA
1111
end
1212

13-
value(x::TPS) = scalar(x);
13+
value(x::TPS) = scalar(x)
1414
value(::Type{TPS{T}}) where {T} = T
1515

16+
ODE_DEFAULT_NORM(u::TPS, t) = normTPS(u)
17+
ODE_DEFAULT_NORM(f::F, u::TPS, t) where {F} = normTPS(f(u))
18+
19+
function ODE_DEFAULT_NORM(u::AbstractArray{TPS{T}}, t) where {T}
20+
x = zero(real(T))
21+
@inbounds @fastmath for ui in u
22+
x += normTPS(ui)^2
23+
end
24+
Base.FastMath.sqrt_fast(x / max(length(u), 1))
25+
end
26+
27+
function ODE_DEFAULT_NORM(f::F, u::AbstractArray{TPS{T}}, t) where {F, T}
28+
x = zero(real(T))
29+
@inbounds @fastmath for ui in u
30+
x += normTPS(f(ui))^2
31+
end
32+
Base.FastMath.sqrt_fast(x / max(length(u), 1))
33+
end
1634

1735
end

test/downstream/gtpsa.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using OrdinaryDiffEq, ForwardDiff, GTPSA, Test
22

3+
# ODEProblem 1 =======================
4+
35
f!(du, u, p, t) = du .= p .* u
46

57
# Initial variables and parameters
@@ -37,3 +39,38 @@ for i in 1:3
3739
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
3840
end
3941

42+
43+
# ODEProblem 2 =======================
44+
pdot!(dq, p, q, params, t) = dq .= [0.0, 0.0, 0.0]
45+
qdot!(dp, p, q, params, t) = dp .= [p[1] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
46+
p[2] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
47+
p[3] / sqrt(1 + p[3]^2) - (p[3] + 1)/sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2)]
48+
49+
prob = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0, 25.0))
50+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
51+
52+
desc = Descriptor(6, 2) # 6 variables to 2nd order
53+
dx = vars(desc) # identity map
54+
prob_GTPSA = DynamicalODEProblem(pdot!, qdot!, dx[1:3], dx[4:6], (0.0, 25.0))
55+
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
56+
57+
@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
58+
59+
# Compare Jacobian against ForwardDiff
60+
J_FD = ForwardDiff.jacobian(zeros(6)) do t
61+
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
62+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
63+
sol.u[end]
64+
end
65+
66+
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)
67+
68+
# Compare Hessians against ForwardDiff
69+
for i in 1:6
70+
Hi_FD = ForwardDiff.hessian(zeros(6)) do t
71+
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
72+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
73+
sol.u[end][i]
74+
end
75+
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
76+
end

0 commit comments

Comments
 (0)