Skip to content

Commit b0d228d

Browse files
getting very close
1 parent 9630121 commit b0d228d

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,12 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
8383
(dr.u for dr in dres)
8484
end
8585

86-
cache = (res, resvals)
86+
cache = (res, resvals, deepcopy(linsolve.val))
8787
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
8888
end
8989

9090
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
91-
y, dys = cache
92-
_linsolve = linsolve.val
91+
y, dys, _linsolve = cache
9392

9493
@assert !(typeof(linsolve) <: Const)
9594
@assert !(typeof(linsolve) <: Active)
@@ -113,9 +112,9 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
113112
for (dA, db, dy) in zip(dAs, dbs, dys)
114113
z = if _linsolve.cacheval isa Factorization
115114
_linsolve.cacheval' \ dy
116-
elseif linsolve.cacheval isa Tuple && linsolve.cacheval[1] isa Factorization
115+
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
117116
_linsolve.cacheval[1]' \ dy
118-
elseif linsolve.alg isa AbstractKrylovSubspaceMethod
117+
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod
119118
# Doesn't modify `A`, so it's safe to just reuse it
120119
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
121120
solve(invprob;

test/enzyme.jl

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Enzyme, FiniteDiff
1+
using Enzyme, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
33

44
n = 4
@@ -20,8 +20,8 @@ f(A, b1) # Uses BLAS
2020

2121
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
2222

23-
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
24-
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
23+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
24+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
2525

2626
@test dA dA2
2727
@test db1 db12
@@ -35,8 +35,8 @@ db12 = zeros(n);
3535

3636
@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
3737

38-
dA_2 = FiniteDiff.finite_difference_gradient(x->f(x,b1), copy(A))
39-
db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
38+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
39+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
4040

4141
@test_broken dA dA_2
4242
@test_broken dA2 dA_2
@@ -45,9 +45,8 @@ db1_2 = FiniteDiff.finite_difference_gradient(x->f(A,x), copy(b1))
4545

4646
function f(A, b1, b2; alg = LUFactorization())
4747
prob = LinearProblem(A, b1)
48-
4948
cache = init(prob, alg)
50-
s1 = solve!(cache).u
49+
s1 = copy(solve!(cache).u)
5150
cache.b = b2
5251
s2 = solve!(cache).u
5352
norm(s1 + s2)
@@ -60,11 +59,46 @@ db1 = zeros(n);
6059
b2 = rand(n);
6160
db2 = zeros(n);
6261

62+
f(A, b1, b2)
6363
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
6464

65-
dA2 = FiniteDiff.finite_difference_gradient(x->f(x,b1,b2), copy(A))
66-
db12 = FiniteDiff.finite_difference_gradient(x->f(A,x,b2), copy(b1))
67-
db22 = FiniteDiff.finite_difference_gradient(x->f(A,b1,x), copy(b2))
65+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A))
66+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1))
67+
db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2))
68+
69+
@test dA dA2
70+
@test db1 db12
71+
@test db2 db22
72+
73+
function f2(A, b1, b2; alg = RFLUFactorization())
74+
prob = LinearProblem(A, b1)
75+
cache = init(prob, alg)
76+
s1 = copy(solve!(cache).u)
77+
cache.b = b2
78+
s2 = solve!(cache).u
79+
norm(s1 + s2)
80+
end
81+
82+
f2(A, b1, b2)
83+
dA = zeros(n, n);
84+
db1 = zeros(n);
85+
db2 = zeros(n);
86+
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
87+
88+
@test dA dA2
89+
@test db1 db12
90+
@test db2 db22
91+
92+
function f3(A, b1, b2; alg = KrylovJL_GMRES())
93+
prob = LinearProblem(A, b1)
94+
cache = init(prob, alg)
95+
s1 = solve!(cache).u
96+
cache.b = b2
97+
s2 = solve!(cache).u
98+
norm(s1 + s2)
99+
end
100+
101+
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
68102

69103
@test dA dA2 atol=5e-5
70104
@test db1 db12

0 commit comments

Comments
 (0)