Skip to content

Commit 277c4f8

Browse files
committed
bring in linalg, add tols to tests
1 parent 9aa8b19 commit 277c4f8

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveForwardDiffExt
22

33
using LinearSolve
4+
using LinearAlgebra
45
using ForwardDiff
56
using ForwardDiff: Dual, Partials
67
using SciMLBase
@@ -53,6 +54,8 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5354
partial_prob = LinearProblem(cache.cache.A, rhs_list[1])
5455
partial_cache = init(partial_prob, alg, args...; kwargs...)
5556

57+
Main.@infiltrate
58+
5659
for i in eachindex(rhs_list)
5760
partial_cache.b = rhs_list[i]
5861
rhs_list[i] = copy(solve!(partial_cache, alg).u)
@@ -107,7 +110,6 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials
107110
b_list = partials_to_list(∂_b)
108111

109112
Auu = [A * uu for A in A_list]
110-
111113
b_list .- Auu
112114
end
113115

@@ -117,13 +119,13 @@ function xp_linsolve_rhs(
117119

118120
Auu = [A * uu for A in A_list]
119121

120-
Auu
122+
-Auu
121123
end
122124

123125
function xp_linsolve_rhs(
124126
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
125127
b_list = partials_to_list(∂_b)
126-
128+
Main.@infiltrate
127129
b_list
128130
end
129131

test/forwarddiff_overloads.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
1313

1414
prob = LinearProblem(A, b)
1515
overload_x_p = solve(prob)
16-
original_x_p = solve!(init(prob))
16+
original_x_p = A \ b
1717

18-
@test overload_x_p original_x_p
18+
@test (overload_x_p, original_x_p, rtol = 1e-9)
1919

2020
A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
2121
prob = LinearProblem(A, [6.0, 10.0, 25.0])
22-
@test solve(prob).retcode == ReturnCode.Default
22+
@test (solve(prob).u, A \ [6.0, 10.0, 25.0], rtol = 1e-9)
2323

2424
_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
2525
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0]
2626
prob = LinearProblem(A, b)
27-
@test solve(prob).retcode == ReturnCode.Default
27+
@test (solve(prob).u, A \ b, rtol = 1e-9)

0 commit comments

Comments
 (0)