Skip to content

Commit b71b8ef

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Further simplify rules
1 parent 8b94cfc commit b71b8ef

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ end
416416
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
417417
function kron_pullback(z̄)
418418
dz = reshape(unthunk(z̄), length(y), size(x)...)
419-
= @thunk Ref(y') .* eachslice(dz; dims = (2, 3))
419+
= @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y)))
420420
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
421421
return NoTangent(), x̄, ȳ
422422
end
@@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:
427427
function kron_pullback(z̄)
428428
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
429429
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
430-
ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3))
430+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
431431
return NoTangent(), x̄, ȳ
432432
end
433433
return kron(x, y), kron_pullback
@@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:
437437
function kron_pullback(z̄)
438438
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
439439
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
440-
ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x)))
440+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
441441
return NoTangent(), x̄, ȳ
442442
end
443443
return kron(x, y), kron_pullback

0 commit comments

Comments
 (0)