|
3 | 3 | function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat) |
4 | 4 | project_A = ProjectTo(A) |
5 | 5 | project_B = ProjectTo(B) |
6 | | - function times_pullback(Ȳ::AbstractMatrix) |
7 | | - Ā = @thunk(project_A(dot(Ȳ, B)')) |
8 | | - B̄ = @thunk(project_B(A' * Ȳ)) |
9 | | - return (NoTangent(), Ā, B̄) |
10 | | - end |
11 | | - |
12 | | - function times_pullback(Ȳ::Tangent{<:WoodburyPDMat}) |
13 | | - Ā = dot(Ȳ.A * Ȳ.D * Ȳ.A' + Ȳ.S, B) |
14 | | - B̄ = Ȳ.A * (A' * Ȳ.D) * Ȳ.A' + A' * Ȳ.S |
15 | | - return ( |
16 | | - NoTangent(), |
17 | | - @thunk(project_A(Ā')), |
18 | | - @thunk(project_B(B̄)), |
19 | | - ) |
20 | | - end |
| 6 | + times_pullback(ȳ) = _times_pullback(ȳ, A, B, (;A=project_A, B=project_B)) |
21 | 7 | return A * B, times_pullback |
22 | 8 | end |
23 | 9 |
|
24 | | -function ChainRulesCore.ProjectTo(W::WoodburyPDMat) |
25 | | - function dW(W̄) |
26 | | - Ā(W̄) = ProjectTo(W.A)(collect((W.D * W.A' * W̄' + W.D * W.A' * W̄)')) |
27 | | - D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A) |
28 | | - S̄(W̄) = ProjectTo(W.S)(W̄) |
29 | | - return Tangent{typeof(W)}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄)) |
30 | | - end |
31 | | - return dW |
| 10 | +function _times_pullback(Ȳ::AbstractMatrix, A, B, proj) |
| 11 | + Ā = proj.A(dot(Ȳ, B)') |
| 12 | + B̄ = proj.B(A' * Ȳ) |
| 13 | + return (NoTangent(), Ā, B̄) |
| 14 | +end |
| 15 | +_times_pullback(ȳ::AbstractThunk, A, B, proj) = _times_pullback(unthunk(ȳ), A, B, proj) |
| 16 | + |
| 17 | +function ChainRulesCore.ProjectTo(W::T) where {T<:WoodburyPDMat} |
| 18 | + fields = (A = W.A, D = W.D, S = W.S) |
| 19 | + ProjectTo{T}(; fields...) |
| 20 | +end |
| 21 | + |
| 22 | +function (W::ProjectTo{T})(W̄) where {T<:WoodburyPDMat} |
| 23 | + Ā(W̄) = ProjectTo(W.A)((W̄ + W̄') * (W.A * W.D)) |
| 24 | + D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A) |
| 25 | + S̄(W̄) = ProjectTo(W.S)(W̄) |
| 26 | + return Tangent{T}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄)) |
32 | 27 | end |
0 commit comments