Commit 1bba715

Update documentation
1 parent 472a9c3 commit 1bba715

1 file changed: docs/src/create_kernel.md (+35 −2 lines)

Note that `BaseKernel` does not use `Distances.jl` and can therefore be a bit slow.

### Additional Options

Finally, there are additional functions you can define to bring in more features (a sketch of some possible overloads follows the list):

- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes over dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
- `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs is only checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` functions to optimize the computations.
- `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
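
For concreteness, here is a minimal sketch of what a few of these overloads could look like. `MyKernel`, `MyDataType`, and the placeholder kernel evaluation are hypothetical, not part of the package:

```julia
import KernelFunctions

# Hypothetical kernel used only to illustrate the overloads above.
struct MyKernel <: KernelFunctions.Kernel end

# Placeholder evaluation so the kernel is usable (squared-exponential form).
(k::MyKernel)(x, y) = exp(-sum(abs2, x .- y))

# Declare that the kernel factorizes over dimensions.
KernelFunctions.iskroncompatible(::MyKernel) = true

# Hypothetical input type whose matrix columns are observations.
struct MyDataType
    X::Matrix{Float64}
end

# Report the input dimension for this datatype (the default is 0).
KernelFunctions.dim(x::MyDataType) = size(x.X, 1)

# Naive quadratic `kernelmatrix`, shown only as a stand-in for a real optimization.
KernelFunctions.kernelmatrix(k::MyKernel, x::AbstractVector) = [k(xi, xj) for xi in x, xj in x]

# Specialize how the kernel is printed.
Base.print(io::IO, ::MyKernel) = print(io, "MyKernel()")
```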

KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters in a way that is compatible with the [Flux ML framework](https://github.com/FluxML/Flux.jl). You can use `Functors.@functor` if all fields of your kernel struct are trainable. Note that optimization algorithms in Flux are not compatible with scalar parameters (yet), and hence vector-valued parameters should be preferred.
```julia
import Functors

struct MyKernel{T} <: KernelFunctions.Kernel
    a::Vector{T}
end

Functors.@functor MyKernel
```
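
As a quick, hypothetical usage check: `Functors.functor` then exposes the fields as trainable children together with a reconstruction function, and `Functors.fmap` rebuilds the kernel with transformed parameters:

```julia
k = MyKernel([1.0, 2.0])

# Decompose the kernel into its trainable fields and a reconstruction function.
xs, re = Functors.functor(k)
xs        # (a = [1.0, 2.0],)
re(xs)    # MyKernel{Float64}([1.0, 2.0])

# Rebuild the kernel with transformed parameters, e.g. after a gradient step.
Functors.fmap(a -> 2 .* a, k)  # MyKernel{Float64}([2.0, 4.0])
```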

If only a subset of the fields are trainable, you have to specify explicitly how to (re)construct the kernel with modified parameter values by [implementing `Functors.functor(::Type{<:MyKernel}, x)` for your kernel struct](https://github.com/FluxML/Functors.jl/issues/3):
```julia
import Functors

struct MyKernel{T} <: KernelFunctions.Kernel
    n::Int
    a::Vector{T}
end

function Functors.functor(::Type{<:MyKernel}, x::MyKernel)
    function reconstruct_mykernel(xs)
        # Keep field `n` of the original kernel and set `a` to the (possibly different) `xs.a`.
        return MyKernel(x.n, xs.a)
    end
    return (a = x.a,), reconstruct_mykernel
end
```
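
With this definition only `a` is exposed for training, while `n` is kept fixed by the reconstruction function, as this hypothetical check shows:

```julia
k = MyKernel(3, [1.0, 2.0, 3.0])

# Only `a` appears among the trainable children; `n` stays fixed.
xs, re = Functors.functor(k)
xs                            # (a = [1.0, 2.0, 3.0],)
re((a = [0.5, 0.5, 0.5],))    # MyKernel(3, [0.5, 0.5, 0.5])
```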
