Skip to content

Commit 7c53f4b

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Add rules and try to cover complex case
1 parent f0902e3 commit 7c53f4b

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,25 @@ function frule((_, Δx, Δy), ::typeof(kron), x, y)
403403
return kron(x, y), kron(Δx, y) + kron(x, Δy)
404404
end
405405

406-
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
407-
z = kron(x, y)
406+
function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector)
407+
function kron_pullback(z̄)
408+
= zero(x)
409+
= zero(y)
410+
m = firstindex(z̄)
411+
@inbounds for i in eachindex(x)
412+
xi = x[i]
413+
for k in eachindex(y)
414+
x̄[i] += y[k]' * z̄[m]
415+
ȳ[k] += xi' * z̄[m]
416+
m += 1
417+
end
418+
end
419+
NoTangent(), x̄, ȳ
420+
end
421+
kron(x, y), kron_pullback
422+
end
408423

424+
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
409425
function kron_pullback(z̄)
410426
= zero(x)
411427
= zero(y)
@@ -414,18 +430,16 @@ function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
414430
xij = x[i,j]
415431
for k in eachindex(y)
416432
x̄[i, j] += y[k]' * z̄[m]
417-
ȳ[k] += xij * z̄[m]
433+
ȳ[k] += xij' * z̄[m]
418434
m += 1
419435
end
420436
end
421437
NoTangent(), x̄, ȳ
422438
end
423-
z, kron_pullback
439+
kron(x, y), kron_pullback
424440
end
425441

426442
function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
427-
z = kron(x, y)
428-
429443
function kron_pullback(z̄)
430444
= zero(x)
431445
= zero(y)
@@ -434,11 +448,29 @@ function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
434448
xi = x[i]
435449
for k in axes(y,1)
436450
x̄[i] += y[k, l]' * z̄[m]
437-
ȳ[k, l] += xi * z̄[m]
451+
ȳ[k, l] += xi' * z̄[m]
452+
m += 1
453+
end
454+
end
455+
NoTangent(), x̄, ȳ
456+
end
457+
kron(x, y), kron_pullback
458+
end
459+
460+
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix)
461+
function kron_pullback(z̄)
462+
= zero(x)
463+
= zero(y)
464+
m = firstindex(z̄)
465+
@inbounds for l in axes(y,2), j in axes(x,2), i in axes(x,1)
466+
xij = x[i, j]
467+
for k in axes(y,1)
468+
x̄[i, j] += y[k, l]' * z̄[m]
469+
ȳ[k, l] += xij' * z̄[m]
438470
m += 1
439471
end
440472
end
441473
NoTangent(), x̄, ȳ
442474
end
443-
z, kron_pullback
475+
kron(x, y), kron_pullback
444476
end

0 commit comments

Comments
 (0)