Skip to content

Commit 447fb61

Browse files
sharanryChrisRackauckas
authored andcommitted
Avoid refactorization and cleanup tests to allow for numerical errors due to summation ordering
1 parent 8b8f8c4 commit 447fb61

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,22 @@ end
2828
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
2929
@assert !(linsolve isa Const)
3030

31-
A = deepcopy(linsolve.val.A) #mutates after function is applied
31+
linsolve = deepcopy(linsolve) #mutates after function is applied
3232
res = func.val(linsolve.val; kwargs...)
3333

3434
if RT <: Const
3535
return res
3636
end
37-
38-
dres = deepcopy(res)
37+
38+
b = deepcopy(linsolve.val.b)
39+
3940
db = linsolve.dval.b
4041
dA = linsolve.dval.A
41-
dres.u .= A \ (db - dA * res.u)
42+
43+
linsolve.val.b = db - dA * res.u
44+
dres = func.val(linsolve.val; kwargs...)
45+
46+
linsolve.val.b = b
4247

4348
if RT <: DuplicatedNoNeed
4449
return dres

test/enzyme.jl

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,6 @@ function fb(b; alg = LUFactorization())
177177
end
178178
fb(b1)
179179

180-
manual_jac = map(onehot(b1)) do db
181-
y = A \ b1
182-
sum(inv(A) * (db - dA*y))
183-
end |> collect
184-
@show manual_jac
185-
186180
fd_jac = FiniteDiff.finite_difference_jacobian(fb, b1) |> vec
187181
@show fd_jac
188182

@@ -192,8 +186,7 @@ en_jac = map(onehot(b1)) do db1
192186
end |> collect
193187
@show en_jac
194188

195-
@test_broken en_jac manual_jac
196-
@test_broken en_jac fd_jac
189+
@test en_jac fd_jac atol=1e-6
197190

198191
function fA(A; alg = LUFactorization())
199192
prob = LinearProblem(A, b1)
@@ -204,12 +197,6 @@ function fA(A; alg = LUFactorization())
204197
end
205198
fA(A)
206199

207-
manual_jac = map(onehot(A)) do dA
208-
y = A \ b1
209-
sum(inv(A) * (db1 - dA*y))
210-
end |> collect
211-
@show manual_jac
212-
213200
fd_jac = FiniteDiff.finite_difference_jacobian(fA, A) |> vec
214201
@show fd_jac
215202

@@ -219,5 +206,4 @@ en_jac = map(onehot(A)) do dA
219206
end |> collect
220207
@show en_jac
221208

222-
@test_broken en_jac manual_jac
223-
@test_broken en_jac fd_jac
209+
@test en_jac fd_jac atol=1e-6

0 commit comments

Comments
 (0)