You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
14
14
@assert!(prob isa Const)
15
15
res = func.val(prob.val, alg.val; kwargs...)
16
16
if RT <:Const
17
-
return res
17
+
if EnzymeRules.needs_primal(config)
18
+
return res
19
+
else
20
+
returnnothing
21
+
end
18
22
end
19
23
dres = func.val(prob.dval, alg.val; kwargs...)
20
24
dres.b .= res.b == dres.b ?zero(dres.b) : dres.b
@@ -25,17 +29,31 @@ function EnzymeCore.EnzymeRules.forward(
25
29
returnDuplicated(res, dres)
26
30
end
27
31
error("Unsupported return type $RT")
32
+
33
+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
34
+
Duplicated(res, dres)
35
+
elseif EnzymeRules.needs_shadow(config)
36
+
dres
37
+
elseif EnzymeRules.needs_primal(config)
38
+
res
39
+
else
40
+
nothing
41
+
end
28
42
end
29
43
30
-
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
44
+
function EnzymeCore.EnzymeRules.forward(config::ConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
31
45
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
32
46
kwargs...) where {RT, LP <:LinearSolve.LinearCache}
33
47
@assert!(linsolve isa Const)
34
48
35
49
res = func.val(linsolve.val; kwargs...)
36
50
37
51
if RT <:Const
38
-
return res
52
+
if EnzymeRules.needs_primal(config)
53
+
return res
54
+
else
55
+
returnnothing
56
+
end
39
57
end
40
58
if linsolve.val.alg isa LinearSolve.AbstractKrylovSubspaceMethod
41
59
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
@@ -50,13 +68,15 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
50
68
51
69
linsolve.val.b = b
52
70
53
-
if RT <:DuplicatedNoNeed
54
-
return dres
55
-
elseif RT <:Duplicated
56
-
returnDuplicated(res, dres)
71
+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
0 commit comments