416
416
function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
417
417
function kron_pullback (z̄)
418
418
dz = reshape (unthunk (z̄), length (y), size (x)... )
419
- x̄ = @thunk Ref (y ' ) .* eachslice (dz; dims = (2 , 3 ))
419
+ x̄ = @thunk conj .( dot .( eachslice (dz; dims = (2 , 3 )), Ref (y) ))
420
420
ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
421
421
return NoTangent (), x̄, ȳ
422
422
end
@@ -427,7 +427,7 @@ function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:
427
427
function kron_pullback (z̄)
428
428
dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
429
429
x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430
- ȳ = @thunk Ref (x ' ) .* eachslice (dz; dims = (1 , 3 ))
430
+ ȳ = @thunk conj .( dot .( eachslice (dz; dims = (1 , 3 )), Ref (x) ))
431
431
return NoTangent (), x̄, ȳ
432
432
end
433
433
return kron (x, y), kron_pullback
@@ -437,7 +437,7 @@ function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:
437
437
function kron_pullback (z̄)
438
438
dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439
439
x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440
- ȳ = @thunk dot .(eachslice (conj .(dz) ; dims = (1 , 3 )), Ref ( conj . (x)))
440
+ ȳ = @thunk conj .( dot .(eachslice (dz ; dims = (1 , 3 )), Ref (x)))
441
441
return NoTangent (), x̄, ȳ
442
442
end
443
443
return kron (x, y), kron_pullback
0 commit comments