KernelFunctions.jl already contains the most popular kernels, but you might want to make your own!

Here are a few ways, depending on how complicated your kernel is:

### SimpleKernel for kernel functions depending on a metric

If your kernel function is of the form `k(x, y) = f(d(x, y))`, where `d(x, y)` is a `PreMetric`,
you can construct your custom kernel by defining `kappa` and `metric` for it.
Here is, for example, how one can define the `SqExponentialKernel` again:

```julia
using KernelFunctions
using Distances: SqEuclidean

struct MyKernel <: KernelFunctions.SimpleKernel end

KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
KernelFunctions.metric(::MyKernel) = SqEuclidean()
```
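
With just these two definitions, the kernel can be used like any other kernel in the package. A quick sketch (the inputs below are arbitrary toy data):

```julia
k = MyKernel()

x = [rand(3) for _ in 1:5]  # a toy collection of 5 three-dimensional inputs
k(x[1], x[2])               # pointwise evaluation: exp(-‖x[1] - x[2]‖²)
kernelmatrix(k, x)          # the 5×5 kernel matrix
```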

### Kernel for more complex kernels

If your kernel does not admit such a representation, all you need to do is define `(k::MyKernel)(x, y)` and inherit from `Kernel`.
For example, here we recreate the `NeuralNetworkKernel`:

```julia
using KernelFunctions
using LinearAlgebra: dot

struct MyKernel <: KernelFunctions.Kernel end

(::MyKernel)(x, y) = asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
```

Note that kernels defined this way do not use `Distances.jl` and can therefore be a bit slower.
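
Roughly speaking, without a `metric` the kernel matrix ends up being computed by evaluating the kernel on every pair of inputs, rather than through `Distances.jl`'s optimized pairwise routines. A quick check, reusing the `MyKernel` definition just above:

```julia
k = MyKernel()
x = [rand(2) for _ in 1:4]

K = kernelmatrix(k, x)
K ≈ [k(xi, xj) for xi in x, xj in x]  # true: the same pairwise evaluations
```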

### Additional Options

Finally, there are additional functions you can define to bring in more features (a small sketch combining some of them follows this list):
 - Define the trainable parameters of your kernel with `KernelFunctions.trainable(k)`, which should return a `Tuple` of your parameters.
These parameters will then be passed to the `Flux.params` function. We recommend wrapping all parameters in arrays to allow them to be mutable.
 - `KernelFunctions.iskroncompatible(k)`: if your kernel factorizes over the input dimensions, you can declare this by defining `iskroncompatible(k) = true`.
 - `KernelFunctions.dim`: by default, the dimension of the inputs is only checked for vectors of `AbstractVector{<:Real}`.
If you want to check the dimensions of your inputs, dispatch the `dim` function on your kernel. Note that `0` is the default.
 - You can also redefine the `kernelmatrix(k, x, y)` (and related) functions for your kernel to optimize how its kernel matrices are computed.
 - `Base.print(io::IO, k::Kernel)`: define this if you want to specialize the printing of your kernel.
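
For illustration, here is a small hypothetical kernel with a single lengthscale parameter that opts into a few of these extras. The name `MyScaledKernel`, its field, and the printed format are made up for this sketch:

```julia
using KernelFunctions
using Distances: SqEuclidean

# A hypothetical squared-exponential kernel with a lengthscale parameter.
# The parameter is wrapped in a Vector so that it stays mutable (and hence trainable).
struct MyScaledKernel{T<:Real} <: KernelFunctions.SimpleKernel
    lengthscale::Vector{T}
end

KernelFunctions.kappa(k::MyScaledKernel, d2::Real) = exp(-d2 / first(k.lengthscale)^2)
KernelFunctions.metric(::MyScaledKernel) = SqEuclidean()

# Expose the lengthscale so it ends up in Flux.params.
KernelFunctions.trainable(k::MyScaledKernel) = (k.lengthscale,)

# The kernel factorizes over input dimensions:
# exp(-‖x - y‖² / ℓ²) = ∏ᵢ exp(-(xᵢ - yᵢ)² / ℓ²).
KernelFunctions.iskroncompatible(::MyScaledKernel) = true

# Specialized printing.
Base.print(io::IO, k::MyScaledKernel) = print(io, "My Scaled Kernel (ℓ = ", first(k.lengthscale), ")")

k = MyScaledKernel([2.0])
print(k)  # My Scaled Kernel (ℓ = 2.0)
```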