Skip to content

Commit c2c0850

Browse files
format
1 parent 1a870e8 commit c2c0850

8 files changed

+40
-30
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module DiffEqBaseForwardDiffExt
22

33
using DiffEqBase, ForwardDiff
44
using DiffEqBase.ArrayInterface
5-
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag, AbstractTimeseriesSolution,
6-
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
5+
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
6+
AbstractTimeseriesSolution,
7+
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
78
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
89
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
910
InternalITP, nextfloat_tdir, DualEltypeChecker, sse
@@ -502,7 +503,8 @@ unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.unitfulvalue(x))
502503

503504
sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
504505
function DiffEqBase.totallength(x::ForwardDiff.Dual)
505-
return DiffEqBase.totallength(ForwardDiff.value(x)) + sum(DiffEqBase.totallength, ForwardDiff.partials(x))
506+
return DiffEqBase.totallength(ForwardDiff.value(x)) +
507+
sum(DiffEqBase.totallength, ForwardDiff.partials(x))
506508
end
507509

508510
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u))

ext/DiffEqBaseGTPSAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ function ODE_DEFAULT_NORM(f::F, u::AbstractArray{<:TPS}, t) where {F}
3535
Base.FastMath.sqrt_fast(x / max(length(u), 1))
3636
end
3737

38-
end
38+
end

ext/DiffEqBaseMonteCarloMeasurementsExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ end
2626

2727
DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
2828
DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
29-
DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
29+
function DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{
30+
T, N}}) where {T, N}
31+
T
32+
end
3033
DiffEqBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
3134

3235
# Support adaptive steps should be errorless

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ end
2222
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
2323
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
2424

25-
2625
DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
2726
function DiffEqBase.unitfulvalue(x::Type{
2827
ReverseDiff.TrackedArray{V, D, N, VA, DA},

ext/DiffEqBaseTrackerExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ DiffEqBase.value(x::Tracker.TrackedReal) = x.data
1616
DiffEqBase.value(x::Tracker.TrackedArray) = x.data
1717

1818
DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T
19-
DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
19+
function DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A}
20+
Array{T, N}
21+
end
2022
DiffEqBase.unitfulvalue(x::Tracker.TrackedReal) = x.data
2123
DiffEqBase.unitfulvalue(x::Tracker.TrackedArray) = x.data
2224

src/norecompile.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ end
2323

2424
# Default dispatch assumes no ForwardDiff, gets added in the new dispatch
2525
function wrapfun_iip(ff, inputs)
26-
FunctionWrappersWrappers.FunctionWrappersWrapper(Void(ff), (typeof(inputs),), (Nothing,))
26+
FunctionWrappersWrappers.FunctionWrappersWrapper(
27+
Void(ff), (typeof(inputs),), (Nothing,))
2728
end
2829

2930
function wrapfun_oop(ff, inputs)
30-
FunctionWrappersWrappers.FunctionWrappersWrapper(ff, (typeof(inputs),), (typeof(inputs[1]),))
31-
end
31+
FunctionWrappersWrappers.FunctionWrappersWrapper(
32+
ff, (typeof(inputs),), (typeof(inputs[1]),))
33+
end

test/downstream/gtpsa.jl

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,68 +9,69 @@ x = [1.0, 2.0, 3.0]
99
p = [4.0, 5.0, 6.0]
1010

1111
prob = ODEProblem(f!, x, (0.0, 1.0), p)
12-
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
12+
sol = solve(prob, Tsit5(), reltol = 1e-16, abstol = 1e-16)
1313

1414
# Parametric GTPSA map
1515
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
1616
dx = @vars(desc)
1717
dp = @params(desc)
1818
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
19-
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)
19+
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol = 1e-16, abstol = 1e-16)
2020

2121
@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
2222

2323
# Compare Jacobian against ForwardDiff
2424
J_FD = ForwardDiff.jacobian([x..., p...]) do t
2525
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
26-
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
26+
sol = solve(prob, Tsit5(), reltol = 1e-16, abstol = 1e-16)
2727
sol.u[end]
2828
end
2929

30-
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)
30+
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params = true)
3131

3232
# Compare Hessians against ForwardDiff
3333
for i in 1:3
3434
Hi_FD = ForwardDiff.hessian([x..., p...]) do t
3535
prob = ODEProblem(f!, t[1:3], (0.0, 1.0), t[4:6])
36-
sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)
36+
sol = solve(prob, Tsit5(), reltol = 1e-16, abstol = 1e-16)
3737
sol.u[end][i]
3838
end
39-
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
39+
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params = true)
4040
end
4141

42-
4342
# ODEProblem 2 =======================
44-
pdot!(dq, p, q, params, t) = dq .= [0.0, 0.0, 0.0]
45-
qdot!(dp, p, q, params, t) = dp .= [p[1] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
46-
p[2] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
47-
p[3] / sqrt(1 + p[3]^2) - (p[3] + 1)/sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2)]
43+
pdot!(dq, p, q, params, t) = dq .= [0.0, 0.0, 0.0]
44+
function qdot!(dp, p, q, params, t)
45+
dp .= [p[1] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
46+
p[2] / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2),
47+
p[3] / sqrt(1 + p[3]^2) - (p[3] + 1) / sqrt((1 + p[3])^2 - p[1]^2 - p[2]^2)]
48+
end
4849

4950
prob = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0, 25.0))
50-
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
51+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1e-16, abstol = 1e-16)
5152

5253
desc = Descriptor(6, 2) # 6 variables to 2nd order
53-
dx = @vars(desc) # identity map
54+
dx = @vars(desc) # identity map
5455
prob_GTPSA = DynamicalODEProblem(pdot!, qdot!, dx[1:3], dx[4:6], (0.0, 25.0))
55-
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
56+
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol = 1e-16, abstol = 1e-16)
5657

5758
@test sol.u[end] scalar.(sol_GTPSA.u[end]) # scalar gets 0th order part
5859

5960
# Compare Jacobian against ForwardDiff
6061
J_FD = ForwardDiff.jacobian(zeros(6)) do t
6162
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
62-
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
63+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1e-16, abstol = 1e-16)
6364
sol.u[end]
6465
end
6566

66-
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params=true)
67+
@test J_FD GTPSA.jacobian(sol_GTPSA.u[end], include_params = true)
6768

6869
# Compare Hessians against ForwardDiff
6970
for i in 1:6
7071
Hi_FD = ForwardDiff.hessian(zeros(6)) do t
71-
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
72-
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)
72+
prob = DynamicalODEProblem(pdot!, qdot!, t[1:3], t[4:6], (0.0, 25.0))
73+
sol = solve(prob, Yoshida6(), dt = 1.0, reltol = 1e-16, abstol = 1e-16)
7374
sol.u[end][i]
7475
end
75-
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params=true)
76+
@test Hi_FD GTPSA.hessian(sol_GTPSA.u[end][i], include_params = true)
7677
end

test/forwarddiff_dual_detection.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,5 +387,6 @@ u = Dual.(val, par)
387387
@test DiffEqBase.totallength(val[1]) == 1
388388
@test DiffEqBase.totallength(val) == length(val)
389389
@test DiffEqBase.totallength(par) == length(par)
390-
@test DiffEqBase.totallength(u[1]) == DiffEqBase.totallength(val[1]) + DiffEqBase.totallength(par[1])
390+
@test DiffEqBase.totallength(u[1]) ==
391+
DiffEqBase.totallength(val[1]) + DiffEqBase.totallength(par[1])
391392
@test DiffEqBase.totallength(u) == sum(DiffEqBase.totallength, u)

0 commit comments

Comments
 (0)