@@ -405,37 +405,39 @@ end
405
405
406
406
function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407
407
function kron_pullback (z̄)
408
- dz = reshape (z̄, length (y), length (x))
409
- return NoTangent (), conj .(dz' * y), dz * conj .(x)
408
+ dz = reshape (unthunk (z̄), length (y), length (x))
409
+ x̄ = @thunk conj .(dz' * y)
410
+ ȳ = @thunk dz * conj .(x)
411
+ return NoTangent (), x̄, ȳ
410
412
end
411
413
return kron (x, y), kron_pullback
412
414
end
413
415
414
416
function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
415
417
function kron_pullback (z̄)
416
- dz = reshape (z̄ , length (y), size (x)... )
417
- x̄ = Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
418
- ȳ = conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
418
+ dz = reshape (unthunk (z̄) , length (y), size (x)... )
419
+ x̄ = @thunk Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
420
+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
419
421
return NoTangent (), x̄, ȳ
420
422
end
421
423
return kron (x, y), kron_pullback
422
424
end
423
425
424
426
function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
425
427
function kron_pullback (z̄)
426
- dz = reshape (z̄ , size (y, 1 ), length (x), size (y, 2 ))
427
- x̄ = conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
428
- ȳ = Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
428
+ dz = reshape (unthunk (z̄) , size (y, 1 ), length (x), size (y, 2 ))
429
+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430
+ ȳ = @thunk Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
429
431
return NoTangent (), x̄, ȳ
430
432
end
431
433
return kron (x, y), kron_pullback
432
434
end
433
435
434
436
function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
435
437
function kron_pullback (z̄)
436
- dz = reshape (z̄ , size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
437
- x̄ = conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
438
- ȳ = dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
438
+ dz = reshape (unthunk (z̄) , size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439
+ x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440
+ ȳ = @thunk dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
439
441
return NoTangent (), x̄, ȳ
440
442
end
441
443
return kron (x, y), kron_pullback
0 commit comments