@@ -403,9 +403,25 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y)
403
403
return kron (x, y), kron (Δx, y) + kron (x, Δy)
404
404
end
405
405
406
- function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
407
- z = kron (x, y)
406
+ function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractVector )
407
+ function kron_pullback (z̄)
408
+ x̄ = zero (x)
409
+ ȳ = zero (y)
410
+ m = firstindex (z̄)
411
+ @inbounds for i in eachindex (x)
412
+ xi = x[i]
413
+ for k in eachindex (y)
414
+ x̄[i] += y[k]' * z̄[m]
415
+ ȳ[k] += xi' * z̄[m]
416
+ m += 1
417
+ end
418
+ end
419
+ NoTangent (), x̄, ȳ
420
+ end
421
+ kron (x, y), kron_pullback
422
+ end
408
423
424
+ function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractVector )
409
425
function kron_pullback (z̄)
410
426
x̄ = zero (x)
411
427
ȳ = zero (y)
@@ -414,18 +430,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
414
430
xij = x[i,j]
415
431
for k in eachindex (y)
416
432
x̄[i, j] += y[k]' * z̄[m]
417
- ȳ[k] += xij * z̄[m]
433
+ ȳ[k] += xij' * z̄[m]
418
434
m += 1
419
435
end
420
436
end
421
437
NoTangent (), x̄, ȳ
422
438
end
423
- z , kron_pullback
439
+ kron (x, y) , kron_pullback
424
440
end
425
441
426
442
function rrule (:: typeof (kron), x:: AbstractVector , y:: AbstractMatrix )
427
- z = kron (x, y)
428
-
429
443
function kron_pullback (z̄)
430
444
x̄ = zero (x)
431
445
ȳ = zero (y)
@@ -434,11 +448,29 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
434
448
xi = x[i]
435
449
for k in axes (y,1 )
436
450
x̄[i] += y[k, l]' * z̄[m]
437
- ȳ[k, l] += xi * z̄[m]
451
+ ȳ[k, l] += xi' * z̄[m]
452
+ m += 1
453
+ end
454
+ end
455
+ NoTangent (), x̄, ȳ
456
+ end
457
+ kron (x, y), kron_pullback
458
+ end
459
+
460
+ function rrule (:: typeof (kron), x:: AbstractMatrix , y:: AbstractMatrix )
461
+ function kron_pullback (z̄)
462
+ x̄ = zero (x)
463
+ ȳ = zero (y)
464
+ m = firstindex (z̄)
465
+ @inbounds for l in axes (y,2 ), j in axes (x,2 ), i in axes (x,1 )
466
+ xij = x[i, j]
467
+ for k in axes (y,1 )
468
+ x̄[i, j] += y[k, l]' * z̄[m]
469
+ ȳ[k, l] += xij' * z̄[m]
438
470
m += 1
439
471
end
440
472
end
441
473
NoTangent (), x̄, ȳ
442
474
end
443
- z , kron_pullback
475
+ kron (x, y) , kron_pullback
444
476
end
0 commit comments