Skip to content

Commit 5bb9766

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Add rules and tests
1 parent b3f9f9b commit 5bb9766

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,51 @@ function rrule(
394394
end
395395
return Ω, lyap_pullback
396396
end
397+
398+
#####
399+
##### `kron`
400+
#####
401+
402+
function frule((_, Δx, Δy), ::typeof(kron), x, y)
403+
return kron(x, y), kron(Δx, y) + kron(x, Δy)
404+
end
405+
406+
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
407+
z = kron(x, y)
408+
409+
function kron_pullback(z̄)
410+
= zero(x)
411+
= zero(y)
412+
m = firstindex(z̄)
413+
@inbounds for j in axes(x,2), i in axes(x,1)
414+
xij = x[i,j]
415+
for k in eachindex(y)
416+
x̄[i, j] += y[k]' * z̄[m]
417+
ȳ[k] += xij * z̄[m]
418+
m += 1
419+
end
420+
end
421+
NoTangent(), x̄, ȳ
422+
end
423+
z, kron_pullback
424+
end
425+
426+
function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
427+
z = kron(x, y)
428+
429+
function kron_pullback(z̄)
430+
= zero(x)
431+
= zero(y)
432+
m = firstindex(z̄)
433+
@inbounds for l in axes(y,2), i in eachindex(x)
434+
xi = x[i]
435+
for k in axes(y,1)
436+
x̄[i] += y[k, l]' * z̄[m]
437+
ȳ[k, l] += xi * z̄[m]
438+
m += 1
439+
end
440+
end
441+
NoTangent(), x̄, ȳ
442+
end
443+
z, kron_pullback
444+
end

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,12 @@
159159
test_rrule(lyap, A, C)
160160
end
161161
end
162+
@testset "kron" begin
163+
@testset "AbstractVecOrMat{$T}" for T in (Float64, ComplexF64)
164+
test_frule(kron, randn(T, 3), randn(T, 3))
165+
test_frule(kron, randn(T, 3, 2), randn(T, 3))
166+
test_frule(kron, randn(T, 3), randn(T, 3, 4))
167+
test_frule(kron, randn(T, 3, 4), randn(T, 2, 2))
168+
end
169+
end
162170
end

0 commit comments

Comments
 (0)