Slow mode for kerneldiagmatrix #203

@theogf

Zygote currently fails to differentiate through kerneldiagmatrix when the input is a RowVecs or a ColVecs. The pullback errors with "In slow method" (see the stacktrace below).

MWE:

using KernelFunctions, Zygote

X = KernelFunctions.RowVecs(rand(3, 3))
k = transform(SqExponentialKernel(), 2.0)
Zygote.gradient(k) do k
	sum(kerneldiagmatrix(k, X))
end

ERROR: In slow method
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::KernelFunctions.var"#back#186")(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/KernelFunctions/V02nz/src/zygote_adjoints.jl:75
 [3] (::KernelFunctions.var"#171#back#187"{KernelFunctions.var"#back#186"})(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] _map at /home/theo/.julia/packages/KernelFunctions/V02nz/src/transform/scaletransform.jl:26 [inlined]
 [5] (::typeof((_map)))(::Array{Array{Float64,1},1}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0
 [6] kerneldiagmatrix at /home/theo/.julia/packages/KernelFunctions/V02nz/src/kernels/transformedkernel.jl:85 [inlined]
 [7] (::typeof((kerneldiagmatrix)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /home/theo/.julia/packages/Zygote/nK6sg/src/compiler/interface2.jl:0
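For reference, the gradient Zygote should return in this MWE can be worked out by hand. The derivation below rests on two assumptions. First, transform(k, s) wraps the kernel in a scale transform, so that k_s(x, x') = k(s x, s x'). Second, SqExponentialKernel uses the exp(-d²/2) convention. A different constant in the exponent would not change the conclusion. X has 3 rows, and every diagonal entry evaluates to 1 regardless of s, so the expected gradient with respect to the scale is zero:

$$
k_s(x_i, x_i) = \exp\!\left(-\tfrac{1}{2}\,\lVert s\,x_i - s\,x_i \rVert^2\right) = \exp(0) = 1,
\qquad
\frac{\partial}{\partial s} \sum_{i=1}^{3} k_s(x_i, x_i) = \frac{\partial}{\partial s}\, 3 = 0.
$$

This is consistent with the cotangent in frame [7] being a FillArrays.Fill: the diagonal is constant, so a correct adjoint only needs to propagate a zero (or `nothing`) gradient back through the RowVecs/ColVecs wrapper instead of erroring.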
