Closed
Description
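Zygote currently fails to differentiate through kerneldiagmatrix when the input is a RowVecs or a ColVecs. According to the stack trace, the error is raised by the custom Zygote adjoint in KernelFunctions' src/zygote_adjoints.jl. That adjoint is reached through _map for the scale transform in src/transform/scaletransform.jl.
MWE: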
using KernelFunctions, Zygote
X = KernelFunctions.RowVecs(rand(3, 3))
k = transform(SqExponentialKernel(), 2.0)
Zygote.gradient(k) do k
    sum(kerneldiagmatrix(k, X))
end
ERROR: In slow method
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::KernelFunctions.var"#back#186")(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/KernelFunctions/V02nz/src/zygote_adjoints.jl:75
[3] (::KernelFunctions.var"#171#back#187"{KernelFunctions.var"#back#186"})(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] _map at /home/theo/.julia/packages/KernelFunctions/V02nz/src/transform/scaletransform.jl:26 [inlined]
[5] (::typeof(∂(_map)))(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0
[6] kerneldiagmatrix at /home/theo/.julia/packages/KernelFunctions/V02nz/src/kernels/transformedkernel.jl:85 [inlined]
[7] (::typeof(∂(kerneldiagmatrix)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0