Skip to content

Commit c79f177

Browse files
AlexRobsonmzgubic
andauthored
Update src/chainrules.jl
Co-authored-by: Miha Zgubic <[email protected]>
1 parent bbfe9b2 commit c79f177

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

src/chainrules.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,9 @@ function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
44
project_A = ProjectTo(A)
55
project_B = ProjectTo(B)
66
function times_pullback(ȳ::AbstractMatrix)
7-
= unthunk(ȳ)
8-
= dot(Ȳ, B)
9-
= A' *
10-
return (
11-
NoTangent(),
12-
@thunk(project_A(Ā')),
13-
@thunk(project_B(B̄)),
14-
)
7+
Ā = @thunk(project_A(dot(Ȳ, B)'))
8+
= @thunk(project_B(A' * Ȳ))
9+
return (NoTangent(), Ā, B̄)
1510
end
1611

1712
function times_pullback(ȳ::Tangent{<:WoodburyPDMat})
@@ -35,4 +30,4 @@ function ChainRulesCore.ProjectTo(W::WoodburyPDMat)
3530
return Tangent{typeof(W)}(; A = (W̄), D = (W̄), S = (W̄))
3631
end
3732
return dW
38-
end
33+
end

0 commit comments

Comments
 (0)