Skip to content

Commit 3b39753

Browse files
fix multiple solve handling
1 parent f9b0784 commit 3b39753

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveEnzymeExt
22

33
using LinearSolve
4+
using LinearSolve.LinearAlgebra
45
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
56

67

@@ -60,9 +61,6 @@ end
6061
# dA −= z y^T
6162
# dB += z, where z = inv(A^T) dy
6263
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
63-
@assert linsolve.val.isfresh
64-
A_cache = copy(linsolve.val.A)
65-
6664
res = func.val(linsolve.val; kwargs...)
6765

6866
dres = if EnzymeRules.width(config) == 1
@@ -85,12 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8583
(dr.u for dr in dres)
8684
end
8785

88-
cache = (A_cache, res, resvals)
86+
cache = (res, resvals, linsolve.val)
8987
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
9088
end
9189

9290
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
93-
A, y, dys = cache
91+
y, dys, _linsolve = cache
9492

9593
@assert !(typeof(linsolve) <: Const)
9694
@assert !(typeof(linsolve) <: Active)
@@ -112,11 +110,21 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
112110
end
113111

114112
for (dA, db, dy) in zip(dAs, dbs, dys)
115-
invprob = LinearSolve.LinearProblem(transpose(A), dy)
116-
z = solve(invprob;
117-
abstol = linsolve.val.abstol,
118-
reltol = linsolve.val.reltol,
119-
verbose = linsolve.val.verbose)
113+
z = if linsolve.cacheval isa Factorization
114+
linsolve.cacheval' \ dy
115+
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
116+
linsolve.cacheval[1]' \ dy
117+
elseif linsolve.alg isa AbstractKrylovSubspaceMethod
118+
# Doesn't modify `A`, so it's safe to just reuse it
119+
invprob = LinearSolve.LinearProblem(transpose(linsolve.A), dy)
120+
solve(invprob;
121+
abstol = linsolve.val.abstol,
122+
reltol = linsolve.val.reltol,
123+
verbose = linsolve.val.verbose,
124+
isfresh = freshbefore)
125+
else
126+
error("Algorithm $(linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
127+
end
120128

121129
dA .-= z * transpose(y)
122130
db .+= z

test/enzyme.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS
2020

2121
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
2222

23-
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1, b2), copy(A))
24-
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1))
23+
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
24+
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
2525

2626
@test dA dA2
2727
@test db1 db12
@@ -33,6 +33,39 @@ b1 = rand(n);
3333
db1 = zeros(n);
3434
db12 = zeros(n);
3535

36-
# This is not legal, all args need to be batch'd at the same size
37-
@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)))
38-
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1))
36+
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
37+
38+
dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
39+
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
40+
41+
@test_broken dA dA_2
42+
@test_broken dA2 dA_2
43+
@test_broken db1 db1_2
44+
@test_broken db12 db1_2
45+
46+
function f(A, b1, b2; alg = LUFactorization())
47+
prob = LinearProblem(A, b1)
48+
49+
cache = init(prob, alg)
50+
s1 = solve!(cache).u
51+
cache.b = b2
52+
s2 = solve!(cache).u
53+
norm(s1 + s2)
54+
end
55+
56+
A = rand(n, n);
57+
dA = zeros(n, n);
58+
b1 = rand(n);
59+
db1 = zeros(n);
60+
b2 = rand(n);
61+
db2 = zeros(n);
62+
63+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
64+
65+
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A))
66+
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1))
67+
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2))
68+
69+
@test dA dA2 atol=5e-5
70+
@test db1 db12
71+
@test db2 db22

0 commit comments

Comments
 (0)