Skip to content

Commit 31b758a

Browse files
author
Simone Carlo Surace
committed
Add projections and remove redundant conj calls
1 parent 649e797 commit 31b758a

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,30 +417,36 @@ end
417417
end
418418

419419
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
420+
project_x = ProjectTo(x)
421+
project_y = ProjectTo(y)
420422
function kron_pullback(z̄)
421423
dz = reshape(unthunk(z̄), length(y), size(x)...)
422-
= @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y)))
423-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
424+
= @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 3)))))
425+
ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = 1))))
424426
return NoTangent(), x̄, ȳ
425427
end
426428
return kron(x, y), kron_pullback
427429
end
428430

429431
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
432+
project_x = ProjectTo(x)
433+
project_y = ProjectTo(y)
430434
function kron_pullback(z̄)
431435
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
432-
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
433-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
436+
= @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = 2))))
437+
ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3)))))
434438
return NoTangent(), x̄, ȳ
435439
end
436440
return kron(x, y), kron_pullback
437441
end
438442

439443
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
444+
project_x = ProjectTo(x)
445+
project_y = ProjectTo(y)
440446
function kron_pullback(z̄)
441447
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
442-
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
443-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
448+
= @thunk(project_x(dot.(Ref(y), eachslice(dz; dims = (2, 4)))))
449+
ȳ = @thunk(project_y(dot.(Ref(x), eachslice(dz; dims = (1, 3)))))
444450
return NoTangent(), x̄, ȳ
445451
end
446452
return kron(x, y), kron_pullback

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,18 +160,24 @@
160160
end
161161
end
162162
@testset "kron" begin
163-
@testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64)
163+
@testset "AbstractVecOrMat{$T1}, AbstractVecOrMat{$T2}" for T1 in (Float64, ComplexF64), T2 in (Float64, ComplexF64)
164164
@testset "frule" begin
165-
test_frule(kron, randn(T, 2), randn(T, 3))
166-
test_frule(kron, randn(T, 2, 3), randn(T, 5))
167-
test_frule(kron, randn(T, 2), randn(T, 3, 5))
168-
test_frule(kron, randn(T, 2, 3), randn(T, 5, 7))
165+
test_frule(kron, randn(T1, 2), randn(T2, 3))
166+
test_frule(kron, randn(T1, 2, 3), randn(T2, 5))
167+
test_frule(kron, randn(T1, 2), randn(T2, 3, 5))
168+
test_frule(kron, randn(T1, 2, 3), randn(T2, 5, 7))
169169
end
170170
@testset "rrule" begin
171-
test_rrule(kron, randn(T, 2), randn(T, 3))
172-
test_rrule(kron, randn(T, 2, 3), randn(T, 5))
173-
test_rrule(kron, randn(T, 2), randn(T, 3, 5))
174-
test_rrule(kron, randn(T, 2, 3), randn(T, 5, 7))
171+
test_rrule(kron, randn(T1, 2), randn(T2, 3))
172+
173+
test_rrule(kron, Diagonal(randn(T1, 2)), randn(T2, 3))
174+
test_rrule(kron, randn(T1, 2, 3), randn(T2, 5))
175+
176+
test_rrule(kron, randn(T1, 2), randn(T2, 3, 5))
177+
test_rrule(kron, randn(T1, 2), Diagonal(randn(T2, 3)))
178+
179+
test_rrule(kron, randn(T1, 2, 3), randn(T2, 5, 7))
180+
test_rrule(kron, Diagonal(randn(T1, 2)), Diagonal(randn(T2, 3)))
175181
end
176182
end
177183
end

0 commit comments

Comments
 (0)