@@ -403,74 +403,40 @@ function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::
403
403
return kron (x, y), kron (Δx, y) + kron (x, Δy)
404
404
end
405
405
406
- function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractVector )
406
+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractVector{<:Number} )
407
407
function kron_pullback (z̄)
408
- x̄ = zero (x)
409
- ȳ = zero (y)
410
- m = firstindex (z̄)
411
- @inbounds for i in eachindex (x)
412
- xi = x[i]
413
- for k in eachindex (y)
414
- x̄[i] += y[k]' * z̄[m]
415
- ȳ[k] += xi' * z̄[m]
416
- m += 1
417
- end
418
- end
419
- NoTangent (), x̄, ȳ
408
+ dz = reshape (z̄, length (y), length (x))
409
+ return NoTangent (), conj .(dz' * y), dz * conj .(x)
420
410
end
421
- kron (x, y), kron_pullback
411
+ return kron (x, y), kron_pullback
422
412
end
423
413
424
- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
414
+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractVector{<:Number} )
425
415
function kron_pullback (z̄)
426
- x̄ = zero (x)
427
- ȳ = zero (y)
428
- m = firstindex (z̄)
429
- @inbounds for j in axes (x,2 ), i in axes (x,1 )
430
- xij = x[i,j]
431
- for k in eachindex (y)
432
- x̄[i, j] += y[k]' * z̄[m]
433
- ȳ[k] += xij' * z̄[m]
434
- m += 1
435
- end
436
- end
437
- NoTangent (), x̄, ȳ
416
+ dz = reshape (z̄, length (y), size (x)... )
417
+ x̄ = Ref (y' ) .* eachslice (dz; dims = (2 , 3 ))
418
+ ȳ = conj .(dot .(eachslice (dz; dims = 1 ), Ref (x)))
419
+ return NoTangent (), x̄, ȳ
438
420
end
439
- kron (x, y), kron_pullback
421
+ return kron (x, y), kron_pullback
440
422
end
441
423
442
- function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractMatrix )
424
+ function rrule (:: typeof (kron), x:: AbstractVector{<:Number} , y:: AbstractMatrix{<:Number} )
443
425
function kron_pullback (z̄)
444
- x̄ = zero (x)
445
- ȳ = zero (y)
446
- m = firstindex (z̄)
447
- @inbounds for l in axes (y,2 ), i in eachindex (x)
448
- xi = x[i]
449
- for k in axes (y,1 )
450
- x̄[i] += y[k, l]' * z̄[m]
451
- ȳ[k, l] += xi' * z̄[m]
452
- m += 1
453
- end
454
- end
455
- NoTangent (), x̄, ȳ
426
+ dz = reshape (z̄, size (y, 1 ), length (x), size (y, 2 ))
427
+ x̄ = conj .(dot .(eachslice (dz; dims = 2 ), Ref (y)))
428
+ ȳ = Ref (x' ) .* eachslice (dz; dims = (1 , 3 ))
429
+ return NoTangent (), x̄, ȳ
456
430
end
457
- kron (x, y), kron_pullback
431
+ return kron (x, y), kron_pullback
458
432
end
459
433
460
- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractMatrix )
434
+ function rrule (:: typeof (kron), x:: AbstractMatrix{<:Number} , y:: AbstractMatrix{<:Number} )
461
435
function kron_pullback (z̄)
462
- x̄ = zero (x)
463
- ȳ = zero (y)
464
- m = firstindex (z̄)
465
- @inbounds for l in axes (y,2 ), j in axes (x,2 ), i in axes (x,1 )
466
- xij = x[i, j]
467
- for k in axes (y,1 )
468
- x̄[i, j] += y[k, l]' * z̄[m]
469
- ȳ[k, l] += xij' * z̄[m]
470
- m += 1
471
- end
472
- end
473
- NoTangent (), x̄, ȳ
436
+ dz = reshape (z̄, size (y, 1 ), size (x, 1 ), size (y, 2 ), size (x, 2 ))
437
+ x̄ = conj .(dot .(eachslice (dz, dims = (2 , 4 )), Ref (y)))
438
+ ȳ = dot .(eachslice (conj .(dz); dims = (1 , 3 )), Ref (conj .(x)))
439
+ return NoTangent (), x̄, ȳ
474
440
end
475
- kron (x, y), kron_pullback
441
+ return kron (x, y), kron_pullback
476
442
end
0 commit comments