Skip to content

Commit b2d4f4a

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Add unthunk and @thunk
1 parent 236daf1 commit b2d4f4a

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -405,37 +405,39 @@ end
405405

406406
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
407407
function kron_pullback(z̄)
408-
dz = reshape(z̄, length(y), length(x))
409-
return NoTangent(), conj.(dz' * y), dz * conj.(x)
408+
dz = reshape(unthunk(z̄), length(y), length(x))
409+
= @thunk conj.(dz' * y)
410+
ȳ = @thunk dz * conj.(x)
411+
return NoTangent(), x̄, ȳ
410412
end
411413
return kron(x, y), kron_pullback
412414
end
413415

414416
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
415417
function kron_pullback(z̄)
416-
dz = reshape(, length(y), size(x)...)
417-
= Ref(y') .* eachslice(dz; dims = (2, 3))
418-
ȳ = conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
418+
dz = reshape(unthunk(z̄), length(y), size(x)...)
419+
= @thunk Ref(y') .* eachslice(dz; dims = (2, 3))
420+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
419421
return NoTangent(), x̄, ȳ
420422
end
421423
return kron(x, y), kron_pullback
422424
end
423425

424426
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
425427
function kron_pullback(z̄)
426-
dz = reshape(, size(y, 1), length(x), size(y, 2))
427-
= conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
428-
ȳ = Ref(x') .* eachslice(dz; dims = (1, 3))
428+
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
429+
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
430+
ȳ = @thunk Ref(x') .* eachslice(dz; dims = (1, 3))
429431
return NoTangent(), x̄, ȳ
430432
end
431433
return kron(x, y), kron_pullback
432434
end
433435

434436
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
435437
function kron_pullback(z̄)
436-
dz = reshape(, size(y, 1), size(x, 1), size(y, 2), size(x, 2))
437-
= conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
438-
ȳ = dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x)))
438+
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
439+
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
440+
ȳ = @thunk dot.(eachslice(conj.(dz); dims = (1, 3)), Ref(conj.(x)))
439441
return NoTangent(), x̄, ȳ
440442
end
441443
return kron(x, y), kron_pullback

0 commit comments

Comments
 (0)