Skip to content

Commit a60a494

Browse files
committed
Use Functors
1 parent 7d3f47f commit a60a494

24 files changed

+59
-54
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.6.0"
55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
8+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1617
[compat]
1718
Compat = "2.2, 3"
1819
Distances = "0.9"
20+
Functors = "0.1"
1921
Requires = "1.0.1"
2022
SpecialFunctions = "0.8, 0.9, 0.10"
2123
StatsBase = "0.32, 0.33"

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export IndependentMOKernel
3737
using Compat
3838
using Requires
3939
using Distances, LinearAlgebra
40+
using Functors
4041
using SpecialFunctions: loggamma, besselk, polygamma
4142
using ZygoteRules: @adjoint, pullback
4243
using StatsFuns: logtwo
@@ -79,7 +80,6 @@ include("zygote_adjoints.jl")
7980
function __init__()
8081
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
8182
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
82-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
8383
end
8484

8585
end

src/basekernels/constant.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ metric(::ZeroKernel) = Delta()
1515

1616
Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
1717

18-
1918
"""
2019
WhiteKernel()
2120
@@ -55,6 +54,8 @@ struct ConstantKernel{Tc<:Real} <: SimpleKernel
5554
end
5655
end
5756

57+
@functor ConstantKernel
58+
5859
kappa::ConstantKernel,x::Real) = first.c)*one(x)
5960

6061
metric(::ConstantKernel) = Delta()

src/basekernels/exponential.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
6363
end
6464
end
6565

66+
@functor GammaExponentialKernel
67+
6668
kappa::GammaExponentialKernel, d²::Real) = exp(-^first.γ))
6769

6870
metric(::GammaExponentialKernel) = SqEuclidean()

src/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ struct FBMKernel{T<:Real} <: Kernel
1717
end
1818
end
1919

20+
@functor FBMKernel
21+
2022
function::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
2123
modX = sum(abs2, x)
2224
modY = sum(abs2, y)

src/basekernels/gabor.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct GaborKernel{K<:Kernel} <: Kernel
1515
end
1616
end
1717

18+
@functor GaborKernel
19+
1820
::GaborKernel)(x, y) = κ.kernel(x ,y)
1921

2022
function _gabor(; ell = nothing, p = nothing)

src/basekernels/maha.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
1616
end
1717
end
1818

19+
@functor MahalanobisKernel
20+
1921
kappa::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)
2022

2123
metric::MahalanobisKernel) = SqMahalanobis.P)

src/basekernels/matern.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ struct MaternKernel{Tν<:Real} <: SimpleKernel
1515
end
1616
end
1717

18+
@functor MaternKernel
19+
1820
@inline function kappa::MaternKernel, d::Real)
1921
result = _matern(first.ν), d)
2022
return ifelse(iszero(d), one(result), result)

src/basekernels/periodic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)
2020

2121
PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
2222

23+
@functor PeriodicKernel
24+
2325
metric::PeriodicKernel) = Sinus.r)
2426

2527
kappa::PeriodicKernel, d::Real) = exp(- 0.5d)

src/basekernels/piecewisepolynomial.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ function PiecewisePolynomialKernel(;v::Integer=0, maha::AbstractMatrix{<:Real})
2525
return PiecewisePolynomialKernel{v}(maha)
2626
end
2727

28+
# Have to reconstruct the type parameter
29+
# See also https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663
30+
function Functors.functor(::Type{<:PiecewisePolynomialKernel{V}}, x) where V
31+
function reconstruct_kernel(xs)
32+
return PiecewisePolynomialKernel{V}(xs.maha)
33+
end
34+
return (maha = x.maha,), reconstruct_kernel
35+
end
36+
2837
_f::PiecewisePolynomialKernel{0}, r, j) = 1
2938
_f::PiecewisePolynomialKernel{1}, r, j) = 1 + (j + 1) * r
3039
_f::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r.^2

0 commit comments

Comments
 (0)