Skip to content

Commit fde509e

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Add projections
1 parent 2ad5473 commit fde509e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,12 @@ end
405405
end
406406

407407
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
408+
project_x = ProjectTo(x)
409+
project_y = ProjectTo(y)
408410
function kron_pullback(z̄)
409411
dz = reshape(unthunk(z̄), length(y), length(x))
410-
= @thunk conj.(dz' * y)
411-
ȳ = @thunk dz * conj.(x)
412+
= @thunk(project_x(conj.(dz' * y)))
413+
ȳ = @thunk(project_y(dz * conj.(x)))
412414
return NoTangent(), x̄, ȳ
413415
end
414416
return kron(x, y), kron_pullback

0 commit comments

Comments
 (0)