Skip to content

Commit dc9a3b4

Browse files
Merge pull request #406 from SciML/enzyme_default
handle the default algorithm with enzyme adjoints
2 parents 6f3f2cd + 0263399 commit dc9a3b4

File tree

3 files changed

+78
-1
lines changed

3 files changed

+78
-1
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
147147
elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
148148
# Doesn't modify `A`, so it's safe to just reuse it
149149
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
150-
solve(invprob;
150+
solve(invprob, _linearsolve.alg;
151151
abstol = _linsolve.val.abstol,
152152
reltol = _linsolve.val.reltol,
153153
verbose = _linsolve.val.verbose)
154+
elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver
155+
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
154156
else
155157
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
156158
end

src/default.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,61 @@ end
362362
end
363363
ex = Expr(:if, ex.args...)
364364
end
365+
366+
"""
367+
```
368+
elseif DefaultAlgorithmChoice.LUFactorization === cache.alg
369+
(cache.cacheval.LUFactorization)' \\ dy
370+
else
371+
...
372+
end
373+
```
374+
"""
375+
@generated function defaultalg_adjoint_eval(cache::LinearCache, dy)
376+
ex = :()
377+
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
378+
newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization,
379+
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
380+
DefaultAlgorithmChoice.RFLUFactorization))
381+
quote
382+
getproperty(cache.cacheval,$(Meta.quot(alg)))[1]' \ dy
383+
end
384+
elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
385+
DefaultAlgorithmChoice.QRFactorization,
386+
DefaultAlgorithmChoice.KLUFactorization,
387+
DefaultAlgorithmChoice.UMFPACKFactorization,
388+
DefaultAlgorithmChoice.LDLtFactorization,
389+
DefaultAlgorithmChoice.SparspakFactorization,
390+
DefaultAlgorithmChoice.BunchKaufmanFactorization,
391+
DefaultAlgorithmChoice.CHOLMODFactorization,
392+
DefaultAlgorithmChoice.SVDFactorization,
393+
DefaultAlgorithmChoice.CholeskyFactorization,
394+
DefaultAlgorithmChoice.NormalCholeskyFactorization,
395+
DefaultAlgorithmChoice.QRFactorizationPivoted,
396+
DefaultAlgorithmChoice.GenericLUFactorization))
397+
quote
398+
getproperty(cache.cacheval,$(Meta.quot(alg)))' \ dy
399+
end
400+
elseif alg in Symbol.((DefaultAlgorithmChoice.KrylovJL_GMRES,))
401+
quote
402+
invprob = LinearSolve.LinearProblem(transpose(cache.A), dy)
403+
solve(invprob, cache.alg;
404+
abstol = cache.val.abstol,
405+
reltol = cache.val.reltol,
406+
verbose = cache.val.verbose)
407+
end
408+
else
409+
quote
410+
error("Default linear solver with algorithm $(alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
411+
end
412+
end
413+
414+
ex = if ex == :()
415+
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex,
416+
:(error("Algorithm Choice not Allowed")))
417+
else
418+
Expr(:elseif, :(getproperty(DefaultAlgorithmChoice, $(Meta.quot(alg))) === cache.alg.alg), newex, ex)
419+
end
420+
end
421+
ex = Expr(:if, ex.args...)
422+
end

test/enzyme.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,30 @@ db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
2626
@test dA dA2
2727
@test db1 db12
2828

29+
A = rand(n, n);
30+
dA = zeros(n, n);
31+
b1 = rand(n);
32+
db1 = zeros(n);
33+
34+
_ff = (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization))
35+
_ff(copy(A), copy(b1))
36+
37+
Enzyme.autodiff(Reverse, (x,y) -> f(x,y; alg = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)), Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
38+
39+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
40+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
41+
42+
@test dA dA2
43+
@test db1 db12
44+
2945
A = rand(n, n);
3046
dA = zeros(n, n);
3147
dA2 = zeros(n, n);
3248
b1 = rand(n);
3349
db1 = zeros(n);
3450
db12 = zeros(n);
3551

52+
3653
# Batch test
3754
n = 4
3855
A = rand(n, n);

0 commit comments

Comments
 (0)