Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,49 @@ overload_x_p = solve(prob, UMFPACKFactorization())
backslash_x_p = A \ b

@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)


# Test StaticArrays
# They don't go through the overloads
# But we should test that the overloads don't mess anything up
function static_h(p)
A = @SMatrix [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]]

b = SA[p[1] + 1, p[2] * 2, p[1]^2]

(A, b)
end

# Check to make sure it's not a DualLinearCache
function static_linsolve_cache_check(p)
A,b = static_h(p)
prob = LinearProblem(A, b)
cache = init(prob)
cache isa LinearCache
end

@test static_linsolve_cache_check([5.0, 5.0])

function static_linsolve(p)
A, b = static_h(p)
prob = LinearProblem(A, b)
solve(prob)
end

function static_backslash(p)
A, b = static_h(p)
A \ b
end

@test (ForwardDiff.jacobian(static_linsolve, [5.0, 5.0]) ≈
ForwardDiff.jacobian(static_backslash, [5.0, 5.0]))

#Test to make sure that the cache is not a DualLinearCache
A, b = static_h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

static_dual_prob = LinearProblem(A, b)
static_dual_cache = init(static_dual_prob)

@test static_dual_cache isa LinearSolve.LinearCache
Loading