@@ -399,46 +399,48 @@ end
399
399
# #### `kron`
400
400
# ####
401
401
402
- function frule ((_, Δx, Δy), :: typeof (kron), x:: AbstractVecOrMat{<:Number} , y:: AbstractVecOrMat{<:Number} )
403
- return kron (x, y), kron (Δx, y) + kron (x, Δy)
404
- end
402
+ @static if VERSION ≥ v " 1.9.0"
403
+ function frule ((_, Δx, Δy), :: typeof (kron), x:: AbstractVecOrMat{<:Number} , y:: AbstractVecOrMat{<:Number} )
404
+ return kron (x, y), kron (Δx, y) + kron (x, Δy)
405
+ end
405
406
406
- function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407
- function kron_pullback (z̄)
408
- dz = reshape (unthunk (z̄), length (y), length (x))
409
- x̄ = @thunk conj .(dz' * y)
410
- ȳ = @thunk dz * conj .(x)
411
- return NoTangent (), x̄, ȳ
407
+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
408
+ function kron_pullback (z̄)
409
+ dz = reshape (unthunk (z̄), length (y), length (x))
410
+ x̄ = @thunk conj .(dz' * y)
411
+ ȳ = @thunk dz * conj .(x)
412
+ return NoTangent (), x̄, ȳ
413
+ end
414
+ return kron (x, y), kron_pullback
412
415
end
413
- return kron (x, y), kron_pullback
414
- end
415
416
416
- function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
417
- function kron_pullback (z̄)
418
- dz = reshape (unthunk (z̄), length (y), size (x)... )
419
- x̄ = @thunk conj .(dot .(eachslice (dz; dims = (2 , 3 )), Ref (y)))
420
- ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
421
- return NoTangent (), x̄, ȳ
417
+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
418
+ function kron_pullback (z̄)
419
+ dz = reshape (unthunk (z̄), length (y), size (x)... )
420
+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = (2 , 3 )), Ref (y)))
421
+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
422
+ return NoTangent (), x̄, ȳ
423
+ end
424
+ return kron (x, y), kron_pullback
422
425
end
423
- return kron (x, y), kron_pullback
424
- end
425
426
426
- function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
427
- function kron_pullback (z̄)
428
- dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
429
- x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
430
- ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
431
- return NoTangent (), x̄, ȳ
427
+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
428
+ function kron_pullback (z̄)
429
+ dz = reshape (unthunk (z̄), size (y, 1 ), length (x), size (y, 2 ))
430
+ x̄ = @thunk conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
431
+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
432
+ return NoTangent (), x̄, ȳ
433
+ end
434
+ return kron (x, y), kron_pullback
432
435
end
433
- return kron (x, y), kron_pullback
434
- end
435
436
436
- function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
437
- function kron_pullback (z̄)
438
- dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
439
- x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
440
- ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
441
- return NoTangent (), x̄, ȳ
437
+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
438
+ function kron_pullback (z̄)
439
+ dz = reshape (unthunk (z̄), size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
440
+ x̄ = @thunk conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
441
+ ȳ = @thunk conj .(dot .(eachslice (dz; dims = (1 , 3 )), Ref (x)))
442
+ return NoTangent (), x̄, ȳ
443
+ end
444
+ return kron (x, y), kron_pullback
442
445
end
443
- return kron (x, y), kron_pullback
444
446
end
0 commit comments