Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using LinearSolve
using ForwardDiff
using Test
using SparseArrays
using StaticArrays

function h(p)
(A = [p[1] p[2]+1 p[2]^3;
Expand Down Expand Up @@ -193,3 +194,39 @@ overload_x_p = solve(prob, UMFPACKFactorization())
backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)


# Test StaticArrays
# They don't go through the overloads
# But we should test that the overloads don't mess anything up
function static_h(p)
A = @SMatrix [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]]

b = SA[p[1] + 1, p[2] * 2, p[1]^2]

(A, b)
end

function static_linsolve(p)
A, b = static_h(p)
prob = LinearProblem(A, b)
solve(prob)
end

function static_backslash(p)
A, b = static_h(p)
A \ b
end

@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0]) ≈
ForwardDiff.jacobian(static_backslash, [5.0, 5.0]))

#Test to make sure that the cache is not a DualLinearCache
A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

static_dual_prob = LinearProblem(A, b)
static_dual_cache = init(static_dual_prob)

@test static_dual_cache isa LinearSolve.LinearCache
Loading