You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
13
14
@assert!(prob isa Const)
14
15
res = func.val(prob.val, alg.val; kwargs...)
15
16
if RT <:Const
@@ -26,11 +27,13 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, :
26
27
error("Unsupported return type $RT")
27
28
end
28
29
29
-
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <:LinearSolve.LinearCache}
30
+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
31
+
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
32
+
kwargs...) where {RT, LP <:LinearSolve.LinearCache}
30
33
@assert!(linsolve isa Const)
31
34
32
35
res = func.val(linsolve.val; kwargs...)
33
-
36
+
34
37
if RT <:Const
35
38
return res
36
39
end
@@ -56,7 +59,10 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
56
59
returnDuplicated(res, dres)
57
60
end
58
61
59
-
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
kwargs...) where {RT, LP <:LinearSolve.LinearProblem}
96
104
d_A, d_b, prob_d_A, prob_d_b = cache
97
105
98
106
if EnzymeRules.width(config) ==1
@@ -105,7 +113,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
105
113
d_b .=0
106
114
end
107
115
else
108
-
for (_prob_d_A,_d_A,_prob_d_b, _d_b) inzip(prob_d_A, d_A, prob_d_b, d_b)
116
+
for (_prob_d_A,_d_A,_prob_d_b, _d_b) inzip(prob_d_A, d_A, prob_d_b, d_b)
109
117
if _d_A !== _prob_d_A
110
118
_prob_d_A .+= _d_A
111
119
_d_A .=0
@@ -123,7 +131,10 @@ end
123
131
# y=inv(A) B
124
132
# dA −= z y^T
125
133
# dB += z, where z = inv(A^T) dy
126
-
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <:LinearSolve.LinearCache}
134
+
function EnzymeCore.EnzymeRules.augmented_primal(
135
+
config, func::Const{typeof(LinearSolve.solve!)},
136
+
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
137
+
kwargs...) where {RT, LP <:LinearSolve.LinearCache}
127
138
res = func.val(linsolve.val; kwargs...)
128
139
129
140
dres =if EnzymeRules.width(config) ==1
@@ -176,7 +187,9 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
0 commit comments