Skip to content

Commit 00431ad

Browse files
committed
add tests for nested duals
1 parent f829926 commit 00431ad

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using RecursiveArrayTools
1313
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P}
1414

1515
# Define type for nested dual numbers
16-
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual, P}
16+
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P}
1717

1818
const SingleDualLinearProblem = LinearProblem{
1919
<:Union{Number, <:AbstractArray, Nothing}, iip,

test/forwarddiff_overloads.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,60 @@ x_p = solve!(cache)
8080
backslash_x_p = A \ new_b
8181

8282
@test (x_p, backslash_x_p, rtol = 1e-9)
83+
84+
# Nested Duals
85+
function h(p)
86+
(A = [p[1] p[2]+1 p[2]^3;
87+
3*p[1] p[1]+5 p[2] * p[1]-4;
88+
p[2]^2 9*p[1] p[2]],
89+
b = [p[1] + 1, p[2] * 2, p[1]^2])
90+
end
91+
92+
A, b = h([ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 1.0, 0.0),
93+
ForwardDiff.Dual(ForwardDiff.Dual(5.0, 1.0, 0.0), 0.0, 1.0)])
94+
95+
prob = LinearProblem(A, b)
96+
overload_x_p = solve(prob)
97+
98+
original_x_p = A \ b
99+
100+
(overload_x_p, original_x_p, rtol = 1e-9)
101+
102+
function linprob_f(p)
103+
A, b = h(p)
104+
prob = LinearProblem(A, b)
105+
solve(prob)
106+
end
107+
108+
function slash_f(p)
109+
A, b = h(p)
110+
A \ b
111+
end
112+
113+
(ForwardDiff.jacobian(slash_f, [5.0, 5.0]), ForwardDiff.jacobian(linprob_f, [5.0, 5.0]))
114+
115+
(ForwardDiff.jacobian(p -> ForwardDiff.jacobian(slash_f, [5.0, p[1]]), [5.0]),
116+
ForwardDiff.jacobian(p -> ForwardDiff.jacobian(linprob_f, [5.0, p[1]]), [5.0]))
117+
118+
function g(p)
119+
(A = [p[1] p[1]+1 p[1]^3;
120+
3*p[1] p[1]+5 p[1] * p[1]-4;
121+
p[1]^2 9*p[1] p[1]],
122+
b = [p[1] + 1, p[1] * 2, p[1]^2])
123+
end
124+
125+
function slash_f_hes(p)
126+
A, b = g(p)
127+
x = A \ b
128+
sum(x)
129+
end
130+
131+
function linprob_f_hes(p)
132+
A, b = g(p)
133+
prob = LinearProblem(A, b)
134+
x = solve(prob)
135+
sum(x)
136+
end
137+
138+
(ForwardDiff.hessian(slash_f_hes, [5.0]),
139+
ForwardDiff.hessian(linprob_f_hes, [5.0]))

0 commit comments

Comments
 (0)