@@ -8,31 +8,84 @@ using Enzyme
8
8
9
9
using EnzymeCore
10
10
11
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , prob:: EnzymeCore.Annotation{LP} , alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
12
+ res = func. val (prob. val, alg. val; kwargs... )
13
+ dres = if EnzymeRules. width (config) == 1
14
+ func. val (prob. dval, alg. val; kwargs... )
15
+ else
16
+ (func. val (dval, alg. val; kwargs... ) for dval in prob. dval)
17
+ end
18
+ return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, nothing )
19
+ end
20
+
21
+ function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
22
+ return (nothing , nothing )
23
+ end
24
+
11
25
# y=inv(A) B
12
26
# dA −= z y^T
13
27
# dB += z, where z = inv(A^T) dy
14
- function EnzymeCore. EnzymeRules. augmented_primal (config:: EnzymeCore.EnzymeRules.ConfigWidth{1} , func:: Const{typeof(LinearSolve.solve)} , :: Type{Duplicated{RT}} , prob:: Duplicated{LP} , alg:: Const ; kwargs... ) where {RT, LP <: LinearProblem }
15
- res = func. val (prob. val, alg. val; kwargs... )
16
- dres = deepcopy (res)
17
- dres. u .= 0
18
- cache = (copy (prob. val. A), res, dres. u)
19
- return EnzymeCore. EnzymeRules. AugmentedReturn {RT, RT, typeof(cache)} (res, dres, cache)
28
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: Const{typeof(LinearSolve.solve!)} , :: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ; kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
29
+ res = func. val (linsolve. val; kwargs... )
30
+ dres = if EnzymeRules. width (config) == 1
31
+ deepcopy (res)
32
+ else
33
+ (deepcopy (res) for dval in linsolve. dval)
34
+ end
35
+
36
+ if EnzymeRules. width (config) == 1
37
+ dres. u .= 0
38
+ else
39
+ for dr in dres
40
+ dr. u .= 0
41
+ end
42
+ end
43
+
44
+ resvals = if EnzymeRules. width (config) == 1
45
+ dres. u
46
+ else
47
+ (dr. u for dr in dres)
48
+ end
49
+
50
+ cache = (copy (linsolve. val. A), res, resvals)
51
+ return EnzymeCore. EnzymeRules. AugmentedReturn (res, dres, cache)
20
52
end
21
53
22
- function EnzymeCore. EnzymeRules. reverse (config:: EnzymeCore.EnzymeRules.ConfigWidth{1} , func:: Const{typeof(LinearSolve.solve)} , :: Type{Duplicated{ RT}} , cache, prob :: Duplicated {LP}, alg :: Const ; kwargs... ) where {RT, LP <: LinearProblem }
23
- A, y, dy = cache
54
+ function EnzymeCore. EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve! )} , :: Type{RT} , cache, linsolve :: EnzymeCore.Annotation {LP} ; kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
55
+ A, y, dys = cache
24
56
25
- dA = prob . dval . A
26
- db = prob . dval . b
57
+ @assert ! ( typeof (linsolve) <: Const )
58
+ @assert ! ( typeof (linsolve) <: Active )
27
59
28
- invprob = LinearProblem (transpose (A), dy)
60
+ if EnzymeRules. width (config) == 1
61
+ dys = (dys,)
62
+ end
29
63
30
- z = func. val (invprob, alg; kwargs... )
64
+ dAs = if EnzymeRules. width (config) == 1
65
+ (linsolve. dval. A,)
66
+ else
67
+ (dval. A for dval in linsolve. dval)
68
+ end
31
69
32
- dA .- = z * transpose (y)
33
- db .+ = z
34
- dy .= 0
35
- return (nothing , nothing )
70
+ dbs = if EnzymeRules. width (config) == 1
71
+ (linsolve. dval. b,)
72
+ else
73
+ (dval. b for dval in linsolve. dval)
74
+ end
75
+
76
+ for (dA, db, dy) in zip (dAs, dbs, dys)
77
+ invprob = LinearSolve. LinearProblem (transpose (A), dy)
78
+ z = solve (invprob;
79
+ abstol = linsolve. val. abstol,
80
+ reltol = linsolve. val. reltol,
81
+ verbose = linsolve. val. verbose)
82
+
83
+ dA .- = z * transpose (y)
84
+ db .+ = z
85
+ dy .= eltype (dy)(0 )
86
+ end
87
+
88
+ return (nothing ,)
36
89
end
37
90
38
91
end
0 commit comments