Skip to content

Commit 2ad5473

Browse files
Simone Carlo Suracesimsurace
authored andcommitted
Only define rules for Julia 1.9 onwards
1 parent b71b8ef commit 2ad5473

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -399,46 +399,48 @@ end
399399
##### `kron`
400400
#####
401401

402-
function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number})
403-
return kron(x, y), kron(Δx, y) + kron(x, Δy)
404-
end
402+
@static if VERSION v"1.9.0"
403+
function frule((_, Δx, Δy), ::typeof(kron), x::AbstractVecOrMat{<:Number}, y::AbstractVecOrMat{<:Number})
404+
return kron(x, y), kron(Δx, y) + kron(x, Δy)
405+
end
405406

406-
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
407-
function kron_pullback(z̄)
408-
dz = reshape(unthunk(z̄), length(y), length(x))
409-
= @thunk conj.(dz' * y)
410-
ȳ = @thunk dz * conj.(x)
411-
return NoTangent(), x̄, ȳ
407+
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})
408+
function kron_pullback(z̄)
409+
dz = reshape(unthunk(z̄), length(y), length(x))
410+
= @thunk conj.(dz' * y)
411+
ȳ = @thunk dz * conj.(x)
412+
return NoTangent(), x̄, ȳ
413+
end
414+
return kron(x, y), kron_pullback
412415
end
413-
return kron(x, y), kron_pullback
414-
end
415416

416-
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
417-
function kron_pullback(z̄)
418-
dz = reshape(unthunk(z̄), length(y), size(x)...)
419-
= @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y)))
420-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
421-
return NoTangent(), x̄, ȳ
417+
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
418+
function kron_pullback(z̄)
419+
dz = reshape(unthunk(z̄), length(y), size(x)...)
420+
= @thunk conj.(dot.(eachslice(dz; dims = (2, 3)), Ref(y)))
421+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = 1), Ref(x)))
422+
return NoTangent(), x̄, ȳ
423+
end
424+
return kron(x, y), kron_pullback
422425
end
423-
return kron(x, y), kron_pullback
424-
end
425426

426-
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
427-
function kron_pullback(z̄)
428-
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
429-
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
430-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
431-
return NoTangent(), x̄, ȳ
427+
function rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})
428+
function kron_pullback(z̄)
429+
dz = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
430+
= @thunk conj.(dot.(eachslice(dz; dims = 2), Ref(y)))
431+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
432+
return NoTangent(), x̄, ȳ
433+
end
434+
return kron(x, y), kron_pullback
432435
end
433-
return kron(x, y), kron_pullback
434-
end
435436

436-
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
437-
function kron_pullback(z̄)
438-
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
439-
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
440-
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
441-
return NoTangent(), x̄, ȳ
437+
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
438+
function kron_pullback(z̄)
439+
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
440+
= @thunk conj.(dot.(eachslice(dz, dims = (2, 4)), Ref(y)))
441+
ȳ = @thunk conj.(dot.(eachslice(dz; dims = (1, 3)), Ref(x)))
442+
return NoTangent(), x̄, ȳ
443+
end
444+
return kron(x, y), kron_pullback
442445
end
443-
return kron(x, y), kron_pullback
444446
end

0 commit comments

Comments
 (0)