
Commit 96086dd

Merge pull request #157 from devmotion/functors
2 parents: a92fcc2 + 1bba715

42 files changed: +133 / -196 lines

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,10 +1,11 @@
 name = "KernelFunctions"
 uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
-version = "0.6.1"
+version = "0.7.0"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [compat]
 Compat = "2.2, 3"
 Distances = "0.9"
+Functors = "0.1"
 Requires = "1.0.1"
 SpecialFunctions = "0.8, 0.9, 0.10"
 StatsBase = "0.32, 0.33"

docs/src/create_kernel.md

Lines changed: 35 additions & 2 deletions
@@ -33,10 +33,43 @@ Note that `BaseKernel` do not use `Distances.jl` and can therefore be a bit slow
 ### Additional Options
 
 Finally there are additional functions you can define to bring in more features:
-- `KernelFunctions.trainable(k::MyKernel)`: it defines the trainable parameters of your kernel, it should return a `Tuple` of your parameters.
-These parameters will be passed to the `Flux.params` function. For some examples see the `trainable.jl` file in `src/`
 - `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes in dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
 - `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
 - `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
 - `kernelmatrix(k::MyKernel, ...)`: you can redefine the diverse `kernelmatrix` functions to eventually optimize the computations.
 - `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel
+
+KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters
+in a way that is compatible with the [Flux ML framework](https://github.com/FluxML/Flux.jl).
+You can use `Functors.@functor` if all fields of your kernel struct are trainable. Note that optimization algorithms
+in Flux are not compatible with scalar parameters (yet), and hence vector-valued parameters should be preferred.
+
+```julia
+import Functors
+
+struct MyKernel{T} <: KernelFunctions.Kernel
+    a::Vector{T}
+end
+
+Functors.@functor MyKernel
+```
+
+If only a subset of the fields are trainable, you have to specify explicitly how to (re)construct the kernel with
+modified parameter values by [implementing `Functors.functor(::Type{<:MyKernel}, x)` for your kernel struct](https://github.com/FluxML/Functors.jl/issues/3):
+
+```julia
+import Functors
+
+struct MyKernel{T} <: KernelFunctions.Kernel
+    n::Int
+    a::Vector{T}
+end
+
+function Functors.functor(::Type{<:MyKernel}, x::MyKernel)
+    function reconstruct_mykernel(xs)
+        # keep field `n` of the original kernel and set `a` to (possibly different) `xs.a`
+        return MyKernel(x.n, xs.a)
+    end
+    return (a = x.a,), reconstruct_mykernel
+end
+```
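
For readers new to Functors.jl, here is a minimal REPL-style sketch (commentary, not part of the diff; `MyKernel` and the values are the illustrative names from the documentation above) of what this machinery does: `Functors.functor` splits a kernel into a named tuple of trainable fields plus a reconstruction closure, and `Functors.fmap` uses that pair to rebuild the struct after transforming every trainable leaf.

```julia
import Functors
import KernelFunctions

# The hypothetical kernel from the documentation above:
struct MyKernel{T} <: KernelFunctions.Kernel
    a::Vector{T}
end
Functors.@functor MyKernel

k = MyKernel([1.0, 2.0])

# Decompose into (trainable fields, reconstructor):
fields, reconstruct = Functors.functor(k)
fields                               # (a = [1.0, 2.0],)

# Rebuild the kernel with new parameter values:
k2 = reconstruct((a = [3.0, 4.0],))  # MyKernel([3.0, 4.0])

# `fmap` combines both steps: transform each leaf, then reconstruct.
k3 = Functors.fmap(x -> 2 .* x, k)   # MyKernel([2.0, 4.0])
```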

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
@@ -37,6 +37,7 @@ export IndependentMOKernel
 using Compat
 using Requires
 using Distances, LinearAlgebra
+using Functors
 using SpecialFunctions: loggamma, besselk, polygamma
 using ZygoteRules: @adjoint, pullback
 using StatsFuns: logtwo
@@ -79,7 +80,6 @@ include("zygote_adjoints.jl")
 function __init__()
     @require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
     @require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
-    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
 end
 
 end
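
The net effect of this hunk is that the Flux-specific `trainable.jl` code path disappears and `Functors` is loaded unconditionally. A hedged sketch of the downstream consequence, assuming a Flux version that itself discovers parameters through Functors.jl (0.11 or later):

```julia
using Flux             # assumption: Flux >= 0.11, which collects parameters via Functors.jl
using KernelFunctions

k = ConstantKernel(c = 2.0)  # stores `c` as the one-element vector [2.0]
ps = Flux.params(k)          # picks up k.c through the kernel's @functor definition
# A Flux optimiser can now update the kernel parameter in place,
# without any Flux-specific glue code inside KernelFunctions.
```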

src/basekernels/constant.jl

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,6 @@ metric(::ZeroKernel) = Delta()
 
 Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
 
-
 """
     WhiteKernel()
@@ -55,6 +54,8 @@ struct ConstantKernel{Tc<:Real} <: SimpleKernel
     end
 end
 
+@functor ConstantKernel
+
 kappa(κ::ConstantKernel, x::Real) = first(κ.c)*one(x)
 
 metric(::ConstantKernel) = Delta()
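
Concretely, the one-line `@functor ConstantKernel` is what exposes the parameter vector `c` to Functors-aware code; an illustrative sketch (values hypothetical):

```julia
import Functors
using KernelFunctions

k = ConstantKernel(c = 2.0)
fields, _ = Functors.functor(k)
fields  # (c = [2.0],) — the parameter vector is now visible as a trainable leaf
```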

src/basekernels/exponential.jl

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
     end
 end
 
+@functor GammaExponentialKernel
+
 kappa(κ::GammaExponentialKernel, d²::Real) = exp(-d²^first(κ.γ))
 
 metric(::GammaExponentialKernel) = SqEuclidean()

src/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ struct FBMKernel{T<:Real} <: Kernel
     end
 end
 
+@functor FBMKernel
+
 function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
     modX = sum(abs2, x)
     modY = sum(abs2, y)

src/basekernels/gabor.jl

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ struct GaborKernel{K<:Kernel} <: Kernel
     end
 end
 
+@functor GaborKernel
+
 (κ::GaborKernel)(x, y) = κ.kernel(x, y)
 
 function _gabor(; ell = nothing, p = nothing)

src/basekernels/maha.jl

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
     end
 end
 
+@functor MahalanobisKernel
+
 kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)
 
 metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)

src/basekernels/matern.jl

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ struct MaternKernel{Tν<:Real} <: SimpleKernel
     end
 end
 
+@functor MaternKernel
+
 @inline function kappa(κ::MaternKernel, d::Real)
     result = _matern(first(κ.ν), d)
     return ifelse(iszero(d), one(result), result)

src/basekernels/periodic.jl

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)
 
 PeriodicKernel(T::DataType, dims::Int = 1) = PeriodicKernel(r = ones(T, dims))
 
+@functor PeriodicKernel
+
 metric(κ::PeriodicKernel) = Sinus(κ.r)
 
 kappa(κ::PeriodicKernel, d::Real) = exp(- 0.5d)
