Skip to content

Commit 248718f

Browse files
committed
add tests for sparse arrays and sparse solvers
1 parent 762bfb1 commit 248718f

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

test/forwarddiff_overloads.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using LinearSolve
22
using ForwardDiff
33
using Test
4+
using SparseArrays
45

56
function h(p)
67
(A = [p[1] p[2]+1 p[2]^3;
@@ -48,6 +49,9 @@ new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.
4849
cache.A = new_A
4950
cache.b = new_b
5051

52+
@test cache.A == new_A
53+
@test cache.b == new_b
54+
5155
x_p = solve!(cache)
5256
backslash_x_p = new_A \ new_b
5357

@@ -61,7 +65,7 @@ cache = init(prob)
6165

6266
new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
6367
cache.A = new_A
64-
@test cache.A = new_A
68+
@test cache.A == new_A
6569

6670
x_p = solve!(cache)
6771
backslash_x_p = new_A \ b
@@ -139,3 +143,40 @@ end
139143

140144
@test (ForwardDiff.hessian(slash_f_hes, [5.0]),
141145
ForwardDiff.hessian(linprob_f_hes, [5.0]))
146+
147+
148+
# Test aliasing
149+
150+
prob = LinearProblem(A, b)
151+
cache = init(prob)
152+
153+
new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
154+
cache.A = new_A
155+
cache.b = new_b
156+
157+
linu = [ForwardDiff.Dual(0.0, 0.0, 0.0), ForwardDiff.Dual(0.0, 0.0, 0.0),
158+
ForwardDiff.Dual(0.0, 0.0, 0.0)]
159+
cache.u = linu
160+
x_p = solve!(cache)
161+
backslash_x_p = new_A \ new_b
162+
163+
@test linu == cache.u
164+
165+
166+
# Test Float Only solvers
167+
168+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
169+
170+
prob = LinearProblem(sparse(A), sparse(b))
171+
overload_x_p = solve(prob, KLUFactorization())
172+
backslash_x_p = A \ b
173+
174+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)
175+
176+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
177+
178+
prob = LinearProblem(A, b)
179+
overload_x_p = solve(prob, UMFPACKFactorization())
180+
backslash_x_p = A \ b
181+
182+
@test (overload_x_p, backslash_x_p, rtol = 1e-9)

0 commit comments

Comments
 (0)