`docs/src/create_kernel.md`

Note that `BaseKernel` does not use `Distances.jl` and can therefore be a bit slower.

### Additional Options

Finally, there are additional functions you can define to bring in more features:
- `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes over dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
- `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs is only checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
- `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
- `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` functions to optimize the computations (see the sketch below).
- `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
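
To make this concrete, here is a minimal, hypothetical sketch of a few of these hooks; the kernel `MyExpKernel` and its parameter `c` are made up for illustration and are not part of KernelFunctions:

```julia
using KernelFunctions

# Hypothetical kernel, for illustration only: k(x, y) = exp(-c * sum(abs2, x - y))
struct MyExpKernel <: KernelFunctions.Kernel
    c::Float64
end

# Pointwise evaluation on a pair of inputs
(k::MyExpKernel)(x, y) = exp(-k.c * sum(abs2, x .- y))

# The kernel factorizes over input dimensions, so Kronecker methods apply
KernelFunctions.iskroncompatible(::MyExpKernel) = true

# Specialized printing
Base.print(io::IO, k::MyExpKernel) = print(io, "MyExpKernel(c = ", k.c, ")")

# Specialized kernel matrix for scalar-valued inputs
function KernelFunctions.kernelmatrix(k::MyExpKernel, x::AbstractVector{<:Real})
    return exp.(-k.c .* abs2.(x .- x'))
end
```

With these definitions, `kernelmatrix(MyExpKernel(0.5), rand(10))` hits the specialized method, while the generic fallbacks remain available for other input types.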

KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters
in a way that is compatible with the [Flux ML framework](https://github.com/FluxML/Flux.jl).
You can use `Functors.@functor` if all fields of your kernel struct are trainable. Note that optimization algorithms
in Flux are not compatible with scalar parameters (yet), and hence vector-valued parameters should be preferred.

```julia
import Functors

struct MyKernel{T} <: KernelFunctions.Kernel
    a::Vector{T}
end

Functors.@functor MyKernel
```
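
As a quick usage sketch (assuming the `MyKernel` definition above), `Functors.functor` now exposes `a` as the kernel's trainable content, together with a function that rebuilds the kernel from modified values:

```julia
k = MyKernel([1.0, 2.0])

# children holds the trainable fields; rebuild reconstructs a kernel from them
children, rebuild = Functors.functor(k)
children                    # (a = [1.0, 2.0],)
rebuild((a = [3.0, 4.0],))  # MyKernel([3.0, 4.0])
```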

If only a subset of the fields are trainable, you have to specify explicitly how to (re)construct the kernel with
modified parameter values by [implementing `Functors.functor(::Type{<:MyKernel}, x)` for your kernel struct](https://github.com/FluxML/Functors.jl/issues/3):

```julia
import Functors

struct MyKernel{T} <: KernelFunctions.Kernel
    n::Int
    a::Vector{T}
end

function Functors.functor(::Type{<:MyKernel}, x::MyKernel)
    function reconstruct_mykernel(xs)
        # keep field `n` of the original kernel and set `a` to (possibly different) `xs.a`
        return MyKernel(x.n, xs.a)
    end
    return (a = x.a,), reconstruct_mykernel
end
```
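
Again as a sketch of the resulting behaviour: only `a` is exposed for optimization, and `n` is taken from the original kernel when rebuilding:

```julia
k = MyKernel(3, [1.0, 2.0])

children, rebuild = Functors.functor(k)
children                    # (a = [1.0, 2.0],): `n` is not exposed as trainable
rebuild((a = [5.0, 6.0],))  # MyKernel(3, [5.0, 6.0]): `n` is kept from `k`
```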