Skip to content

Commit f6db1da

Browse files
author
Alex Robson
committed
Use Functor for ProjectTo. Update test
1 parent b3ff0c2 commit f6db1da

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

src/chainrules.jl

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,25 @@
33
function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
44
project_A = ProjectTo(A)
55
project_B = ProjectTo(B)
6-
function times_pullback::AbstractMatrix)
7-
Ā = @thunk(project_A(dot(Ȳ, B)'))
8-
= @thunk(project_B(A' * Ȳ))
9-
return (NoTangent(), Ā, B̄)
10-
end
11-
12-
function times_pullback::Tangent{<:WoodburyPDMat})
13-
= dot(Ȳ.A *.D *.A' +.S, B)
14-
=.A * (A' *.D) *.A' + A' *.S
15-
return (
16-
NoTangent(),
17-
@thunk(project_A(Ā')),
18-
@thunk(project_B(B̄)),
19-
)
20-
end
6+
times_pullback(ȳ) = _times_pullback(ȳ, A, B, (;A=project_A, B=project_B))
217
return A * B, times_pullback
228
end
239

24-
function ChainRulesCore.ProjectTo(W::WoodburyPDMat)
25-
function dW(W̄)
26-
(W̄) = ProjectTo(W.A)(collect((W.D * W.A' *' + W.D * W.A' * W̄)'))
27-
(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A)
28-
(W̄) = ProjectTo(W.S)(W̄)
29-
return Tangent{typeof(W)}(; A = (W̄), D = (W̄), S = (W̄))
30-
end
31-
return dW
10+
function _times_pullback::AbstractMatrix, A, B, proj)
11+
Ā = proj.A(dot(Ȳ, B)')
12+
= proj.B(A' * Ȳ)
13+
return (NoTangent(), Ā, B̄)
14+
end
15+
_times_pullback::AbstractThunk, A, B, proj) = _times_pullback(unthunk(ȳ), A, B, proj)
16+
17+
function ChainRulesCore.ProjectTo(W::T) where {T<:WoodburyPDMat}
18+
fields = (A = W.A, D = W.D, S = W.S)
19+
ProjectTo{T}(; fields...)
20+
end
21+
22+
function (W::ProjectTo{T})(W̄) where {T<:WoodburyPDMat}
23+
(W̄) = ProjectTo(W.A)((W̄ +') * (W.A * W.D))
24+
(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A)
25+
(W̄) = ProjectTo(W.S)(W̄)
26+
return Tangent{T}(; A = (W̄), D = (W̄), S = (W̄))
3227
end

test/chainrules.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ LinearAlgebra.dot(A, B::WoodburyLike) = dot(A, WoodburyPDMat(B))
7171
# Cannot get this to work. Here the T will be
7272
# T = rand_tangent(primal::WoodburyPDMat) which breaks.
7373
# test_rrule(*, 5.0, W)
74+
75+
#####################################################################################################################
76+
77+
primal = R * W
78+
79+
# Generate the Tangent as ChainRulesTestUtils would do
80+
∂primal = rand_tangent(Random.GLOBAL_RNG, collect(primal))
81+
T = ProjectTo(primal)(∂primal)
82+
83+
f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> (*(x...)), T, (R, W))[1]
84+
85+
# Expected
86+
= ProjectTo(R)(dot(∂primal, W'))
87+
= ProjectTo(W)(conj(R) * ∂primal)
88+
89+
@test res[1] == primal
90+
@test f_jvp[1]
91+
@test.A f_jvp[2].A
92+
@test.D f_jvp[2].D
93+
@test.S f_jvp[2].S
94+
7495
end
7596
end
7697
end

0 commit comments

Comments
 (0)