
Commit d2b5a06

Adding more docs and explanations
1 parent 0526c5a commit d2b5a06

3 files changed: +33, −11 lines

docs/src/create_kernel.md

Lines changed: 30 additions & 8 deletions
````diff
@@ -2,19 +2,41 @@
 
 KernelFunctions.jl contains the most popular kernels already but you might want to make your own!
 
-Here is for example how one can define the Squared Exponential Kernel again :
+Here are a few ways, depending on how complicated your kernel is:
+
+### SimpleKernel for kernel functions depending on a metric
+
+If your kernel function is of the form `k(x, y) = f(d(x, y))` where `d(x, y)` is a `PreMetric`,
+you can construct your custom kernel by defining `kappa` and `metric` for your kernel.
+Here is, for example, how one can define the `SqExponentialKernel` again:
 
 ```julia
-struct MyKernel <: Kernel end
+struct MyKernel <: KernelFunctions.SimpleKernel end
 
 KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
 KernelFunctions.metric(::MyKernel) = SqEuclidean()
 ```
 
-For a "Base" kernel, where the kernel function is simply a function applied on some metric between two vectors of real, you only need to:
-- Define your struct inheriting from `Kernel`.
-- Define a `kappa` function.
-- Define the metric used `SqEuclidean`, `DotProduct` etc. Note that the term "metric" is here overabused.
-- Optional : Define any parameter of your kernel as `trainable` by Flux.jl if you want to perform optimization on the parameters. We recommend wrapping all parameters in arrays to allow them to be mutable.
+### Kernel for more complex kernels
+
+If your kernel does not admit such a representation, all you need to do is define `(k::MyKernel)(x, y)` and inherit from `Kernel`.
+For example, here we recreate the `NeuralNetworkKernel`:
+
+```julia
+struct MyKernel <: KernelFunctions.Kernel end
+
+(::MyKernel)(x, y) = asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
+```
+
+Note that kernels defined this way do not use `Distances.jl` and can therefore be a bit slower.
+
+### Additional Options
 
-Once these functions are defined, you can use all the wrapping functions of KernelFuntions.jl
+Finally, there are additional functions you can define to bring in more features:
+- Define the trainable parameters of your kernel with `KernelFunctions.trainable(k)`, which should return a `Tuple` of your parameters.
+These parameters will then be passed to the `Flux.params` function.
+- `KernelFunctions.iskroncompatible(k)`: if your kernel factorizes over the input dimensions, you can declare it as `iskroncompatible(k) = true`.
+- `KernelFunctions.dim`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`.
+If you want to check the dimensions of your inputs, dispatch the `dim` function on your input type. Note that `0` is the default.
+- You can also redefine the `kernelmatrix(k, x, y)` functions for your kernel to optimize the computation of your kernel matrices.
+- `Base.print(io::IO, k::Kernel)`, if you want to specialize the printing of your kernel.
````
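The new "Additional Options" list lends itself to a short worked example. The following is a minimal sketch, not part of the committed docs: `MyScaledKernel` and its parameter field `ρ` are hypothetical names, and the optional hooks (`trainable`, `iskroncompatible`, `Base.print`) are taken verbatim from the list above.

```julia
using KernelFunctions
using Distances: SqEuclidean

# A squared-exponential kernel with an inverse-lengthscale parameter ρ,
# wrapped in a one-element vector so it stays mutable during optimization.
struct MyScaledKernel{T<:Real} <: KernelFunctions.SimpleKernel
    ρ::Vector{T}
end

KernelFunctions.kappa(k::MyScaledKernel, d2::Real) = exp(-first(k.ρ) * d2)
KernelFunctions.metric(::MyScaledKernel) = SqEuclidean()

# Optional hooks from the list above:
KernelFunctions.trainable(k::MyScaledKernel) = (k.ρ,)     # parameters handed to Flux.params
KernelFunctions.iskroncompatible(::MyScaledKernel) = true  # the kernel factorizes over dimensions
Base.print(io::IO, k::MyScaledKernel) = print(io, "Scaled squared-exponential kernel (ρ = ", first(k.ρ), ")")

# Evaluate the kernel on two inputs, using the generic SimpleKernel fallback
# (kappa applied to the metric between the inputs).
k = MyScaledKernel([0.5])
k(rand(3), rand(3))
```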

docs/src/metrics.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -3,9 +3,9 @@
 KernelFunctions.jl relies on [Distances.jl](https://github.com/JuliaStats/Distances.jl) for computing the pairwise matrix.
 To do so, a distance measure is needed for each kernel. Two very common ones can already be used: `SqEuclidean` and `Euclidean`.
 However, not all kernels rely on metrics respecting all the definitions. That's why additional metrics come with the package, such as `DotProduct` (`<x,y>`) and `Delta` (`δ(x,y)`).
-Note that every `BaseKernel` must have a defined metric defined as :
+Note that every `SimpleKernel` must have its metric defined as:
 ```julia
-metric(::CustomKernel) = SqEuclidean()
+KernelFunctions.metric(::CustomKernel) = SqEuclidean()
 ```
 
 ## Adding a new metric
````
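To illustrate the pseudo-metrics mentioned above, here is a hedged sketch of a linear kernel built on the package's `DotProduct`; `MyLinearKernel` is a hypothetical name, and `DotProduct` is accessed fully qualified on the assumption that it is not exported.

```julia
using KernelFunctions

# A linear kernel k(x, y) = <x, y> expressed through the DotProduct pseudo-metric.
struct MyLinearKernel <: KernelFunctions.SimpleKernel end

KernelFunctions.kappa(::MyLinearKernel, xᵀy::Real) = xᵀy
KernelFunctions.metric(::MyLinearKernel) = KernelFunctions.DotProduct()
```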

src/utils.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -91,7 +91,7 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
 """
 #params
 
-dim(x) = 0
+dim(x) = 0 # Default fallback: the dimension check passes trivially. For a proper check, implement `KernelFunctions.dim` for your datatype.
 dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
 dim(x::AbstractVector{<:Real}) = 1
 dim(x::AbstractVector{Tuple{Any,Int}}) = 1
````
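Following the new comment, here is a sketch of dispatching `KernelFunctions.dim` on a custom input type; `MyColumnInputs` is purely illustrative and not part of the package.

```julia
using KernelFunctions

# Hypothetical container storing one input per column of a matrix.
struct MyColumnInputs
    X::Matrix{Float64}
end

# Without this method, `dim` falls back to the catch-all `dim(x) = 0`.
KernelFunctions.dim(x::MyColumnInputs) = size(x.X, 1)
```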
