|
| 1 | +using OneHotArrays: OneHotVector |
| 2 | +import ForwardDiff as FD |
| 3 | +import LinearAlgebra as LA |
| 4 | + |
"""
    DiffPt(x; partial=())

Position `x` tagged with the partial derivatives that should be applied there.

For a covariance kernel `k` of a Gaussian process `Z`, i.e.
```julia
    k(x, y) # = Cov(Z(x), Z(y)),
```
a `DiffPt` allows the differentiation of `Z`, i.e.
```julia
    k(DiffPt(x, partial=1), y) # = Cov(∂₁Z(x), Z(y))
```
For higher order derivatives `partial` can be any iterable, i.e.
```julia
    k(DiffPt(x, partial=(1, 2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```

# Type parameters
- `Dim`: input dimension; the convenience constructor sets it to `length(x)`.
- `T`, `P`: concrete types of the `pos` and `partial` fields. They are inferred
  by the constructor and never have to be written out; dispatching on
  `DiffPt{Dim}` matches every `T` and `P`.
"""
struct DiffPt{Dim,T,P}
    pos::T      # the actual position
    partial::P  # iterable of coordinate indices to differentiate with respect to

    # `Dim` cannot be inferred from the field types, so it is the only parameter
    # that must be given explicitly; `T` and `P` are taken from the arguments.
    # Concretely typed fields avoid the boxing that untyped (`Any`) fields cause.
    function DiffPt{Dim}(pos::T, partial::P) where {Dim,T,P}
        return new{Dim,T,P}(pos, partial)
    end
end

# convenience constructor: the dimension is read off the position itself
DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial)
| 27 | + |
"""
    partial(fun, dim, partials)

Take partial derivatives of a function `fun` with input dimension `dim`.

`partials` is an iterable of coordinate indices. The derivatives are applied in
iteration order, so for `partials = (i, j)` the returned function is
`∂ⱼ(∂ᵢ fun)`. If `partials` is empty, `fun` itself is returned.

`partials` deliberately has no default value: a default would generate a
two-positional-argument method `partial(fun, dim)`, which has the same dispatch
signature as the keyword method `partial(k, dim; partials_x, partials_y)`
(keyword arguments do not take part in dispatch), so one definition would
overwrite the other.
"""
function partial(fun, dim, partials)
    next = iterate(partials)
    isnothing(next) && return fun  # nothing left to differentiate

    idx, state = next
    # derivative of `fun` along the `idx`-th unit vector, evaluated at `x`
    dfun = x -> FD.derivative(dx -> fun(x .+ dx * OneHotVector(idx, dim)), 0)
    return partial(dfun, dim, Base.rest(partials, state))
end
| 45 | + |
"""
    partial(k, dim; partials_x=(), partials_y=())

Take partial derivatives of a function `k(x, y)` with two `dim`-dimensional
inputs, i.e. a `2 * dim`-dimensional input in total.

`partials_x` lists the coordinate indices to differentiate the first argument
with respect to, `partials_y` those for the second argument. The result is
again a function of `(x, y)`.
"""
function partial(k, dim; partials_x=(), partials_y=())
    # differentiate in the first argument while the second one is held fixed
    function diff_first(x, y)
        k_at_y = t -> k(t, y)
        return partial(k_at_y, dim, partials_x)(x)
    end

    # then differentiate that result in the second argument
    return function (x, y)
        diff_first_at_x = t -> diff_first(x, t)
        return partial(diff_first_at_x, dim, partials_y)(y)
    end
end
| 54 | + |
| 55 | + |
| 56 | + |
| 57 | + |
"""
    _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}

Meant as the implementation behind `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for
all kernel types. The callable-object syntax does not accept a
`where {T<:Kernel}` clause, so the generic implementation lives in this
separate function.

Unpacks the partial-derivative instructions stored in `x` and `y`, applies them
to `k`, and evaluates the differentiated kernel at the positions `x.pos` and
`y.pos`.
"""
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel}
    differentiated = partial(k, Dim; partials_x=x.partial, partials_y=y.partial)
    return differentiated(x.pos, y.pos)
end
| 74 | + |
| 75 | + |
| 76 | + |
#=
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
    (::Kernel)(::DiffPt, ::DiffPt)
```
then julia would not know whether to use
`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`.

To avoid this hack, no kernel type T should implement
```julia
    (::T)(x,y)
```
and instead implement
```julia
    _evaluate(k::T, x, y)
```
Then there should be only a single
```julia
    (k::Kernel)(x,y) = _evaluate(k, x, y)
```
which all the kernels would fall back to.

This ensures that `_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim})` is always
more specialized and is therefore called first.

Until then, the callable methods are defined once per listed kernel type. They
forward to `_evaluate`; a plain (non-`DiffPt`) argument is wrapped into a
`DiffPt` without partials first.

NOTE(review): two `DiffPt`s with different `Dim` match only the two mixed
methods below, neither of which is more specific — presumably this ends in an
ambiguity error rather than a dimension-mismatch error; confirm.
=#
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
    (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
    (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
    (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
end
| 108 | + |
0 commit comments