Skip to content

Commit 236daf1

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Write rules functionally and fix them
1 parent c1226eb commit 236daf1

File tree

1 file changed

+22
-56
lines changed

1 file changed

+22
-56
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -403,74 +403,40 @@ function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::
403403
return kron(x, y), kron(Δx, y) + kron(x, Δy)
404404
end
405405

406-
function rrule(::typeof(kron), x::AbstractVector, y::AbstractVector)
406+
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
407407
function kron_pullback(z̄)
408-
= zero(x)
409-
= zero(y)
410-
m = firstindex(z̄)
411-
@inbounds for i in eachindex(x)
412-
xi = x[i]
413-
for k in eachindex(y)
414-
x̄[i] += y[k]' * z̄[m]
415-
ȳ[k] += xi' * z̄[m]
416-
m += 1
417-
end
418-
end
419-
NoTangent(), x̄, ȳ
408+
dz = reshape(z̄, length(y), length(x))
409+
return NoTangent(), conj.(dz' * y), dz * conj.(x)
420410
end
421-
kron(x, y), kron_pullback
411+
return kron(x, y), kron_pullback
422412
end
423413

424-
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractVector)
414+
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
425415
function kron_pullback(z̄)
426-
= zero(x)
427-
= zero(y)
428-
m = firstindex(z̄)
429-
@inbounds for j in axes(x,2), i in axes(x,1)
430-
xij = x[i,j]
431-
for k in eachindex(y)
432-
x̄[i, j] += y[k]' * z̄[m]
433-
ȳ[k] += xij' * z̄[m]
434-
m += 1
435-
end
436-
end
437-
NoTangent(), x̄, ȳ
416+
dz = reshape(z̄, length(y), size(x)...)
417+
= Ref(y') .* eachslice(dz; dims = (2, 3))
418+
ȳ = conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
419+
return NoTangent(), x̄, ȳ
438420
end
439-
kron(x, y), kron_pullback
421+
return kron(x, y), kron_pullback
440422
end
441423

442-
function rrule(::typeof(kron), x::AbstractVector, y::AbstractMatrix)
424+
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
443425
function kron_pullback(z̄)
444-
= zero(x)
445-
= zero(y)
446-
m = firstindex(z̄)
447-
@inbounds for l in axes(y,2), i in eachindex(x)
448-
xi = x[i]
449-
for k in axes(y,1)
450-
x̄[i] += y[k, l]' * z̄[m]
451-
ȳ[k, l] += xi' * z̄[m]
452-
m += 1
453-
end
454-
end
455-
NoTangent(), x̄, ȳ
426+
dz = reshape(z̄, size(y, 1), length(x), size(y, 2))
427+
= conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
428+
ȳ = Ref(x') .* eachslice(dz; dims = (1, 3))
429+
return NoTangent(), x̄, ȳ
456430
end
457-
kron(x, y), kron_pullback
431+
return kron(x, y), kron_pullback
458432
end
459433

460-
function rrule(::typeof(kron), x::AbstractMatrix, y::AbstractMatrix)
434+
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
461435
function kron_pullback(z̄)
462-
= zero(x)
463-
= zero(y)
464-
m = firstindex(z̄)
465-
@inbounds for l in axes(y,2), j in axes(x,2), i in axes(x,1)
466-
xij = x[i, j]
467-
for k in axes(y,1)
468-
x̄[i, j] += y[k, l]' * z̄[m]
469-
ȳ[k, l] += xij' * z̄[m]
470-
m += 1
471-
end
472-
end
473-
NoTangent(), x̄, ȳ
436+
dz = reshape(z̄, size(y, 1), size(x, 1), size(y, 2), size(x, 2))
437+
= conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
438+
ȳ = dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x)))
439+
return NoTangent(), x̄, ȳ
474440
end
475-
kron(x, y), kron_pullback
441+
return kron(x, y), kron_pullback
476442
end

0 commit comments

Comments
 (0)