Skip to content

Commit 7e2b8dc

Browse files
committed
first draft
1 parent ef6d459 commit 7e2b8dc

File tree

5 files changed

+137
-0
lines changed

5 files changed

+137
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
99
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
1010
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
11+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1112
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1213
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
16+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
1517
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1618
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1719
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ include("chainrules.jl")
125125
include("zygoterules.jl")
126126

127127
include("TestUtils.jl")
128+
include("diffKernel.jl")
128129

129130
function __init__()
130131
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin

src/diffKernel.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using OneHotArrays: OneHotVector
2+
import ForwardDiff as FD
3+
import LinearAlgebra as LA
4+
5+
"""
    DiffPt(x; partial=())

Wrap a position `x` together with the partial derivatives that should be taken at
that position.

For a covariance kernel `k` of a GP `Z`, i.e.
```julia
k(x, y) # = Cov(Z(x), Z(y)),
```
a `DiffPt` allows the differentiation of `Z`, i.e.
```julia
k(DiffPt(x; partial=1), y) # = Cov(∂₁Z(x), Z(y))
```
For higher-order derivatives `partial` can be any iterable, i.e.
```julia
k(DiffPt(x; partial=(1, 2)), y) # = Cov(∂₁∂₂Z(x), Z(y))
```

# Fields
- `pos`: the actual position.
- `partial`: iterable of coordinate indices to differentiate with respect to; the
  default `()` means "no derivative".

The type parameter `Dim` is set to `length(x)` by the keyword constructor. The field
types are type parameters as well, so that field access is type-stable.
"""
struct DiffPt{Dim,P,I}
    pos::P # the actual position
    partial::I
end

# Keeps `DiffPt{Dim}(pos, partial)` working: the field types are inferred from the
# arguments, only `Dim` has to be given explicitly.
DiffPt{Dim}(pos::P, partial::I) where {Dim,P,I} = DiffPt{Dim,P,I}(pos, partial)

DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
27+
28+
"""
    partial(fun, dim, partials)

Take partial derivatives of a function `fun` with input dimension `dim`.

`partials` is an iterable of coordinate indices: if `partials = (i, j)`, then
`∂ᵢ∂ⱼ fun` is returned. If `partials` is empty, `fun` itself is returned.

`partials` is deliberately a required argument. A default value would generate a
two-positional-argument method `partial(fun, dim)`, which has the same signature as
the keyword-based method `partial(k, dim; partials_x, partials_y)` in this file and
would therefore be overwritten by it.
"""
function partial(fun, dim, partials)
    next = iterate(partials)
    isnothing(next) && return fun # nothing left to differentiate
    idx, state = next
    # Derivative of `fun` in the direction of the `idx`-th unit vector, taken at
    # offset 0, i.e. the `idx`-th partial derivative of `fun` at `x`.
    dfun(x) = FD.derivative(dx -> fun(x .+ dx * OneHotVector(idx, dim)), 0)
    # Apply the remaining partial derivatives to the differentiated function.
    return partial(dfun, dim, Base.rest(partials, state))
end
45+
46+
"""
    partial(k, dim; partials_x=(), partials_y=())

Take partial derivatives of a function `k` with two `dim`-dimensional inputs,
i.e. a `2 * dim`-dimensional input.

Returns a function of `(x, y)` in which `partials_x` has been applied to the first
argument of `k` and `partials_y` to the second.
"""
function partial(k, dim; partials_x=(), partials_y=())
    # Differentiate with respect to the first argument while `y` is held fixed.
    function diff_first(x, y)
        return partial(t -> k(t, y), dim, partials_x)(x)
    end
    # Differentiate the result with respect to the second argument while `x` is fixed.
    function diff_both(x, y)
        return partial(t -> diff_first(x, t), dim, partials_y)(y)
    end
    return diff_both
end
54+
55+
56+
57+
58+
"""
    _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}

Implements `(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim})` for all kernel types. Since the
dispatch system does not allow generics in the call syntax above, this redirection
over `_evaluate` is necessary.

Unboxes the partial-derivative instructions from both `DiffPt`s, applies them to `k`
and evaluates the resulting function at the positions stored in the `DiffPt`s.
"""
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel}
    # `k` differentiated as requested in each of its two arguments
    differentiated = partial(k, Dim; partials_x=x.partial, partials_y=y.partial)
    return differentiated(x.pos, y.pos)
end
74+
75+
76+
77+
#=
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
```julia
(::Kernel)(::DiffPt, ::DiffPt)
```
then Julia would not know whether to use
`(::SpecialKernel)(x, y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`.

To avoid this hack, no kernel type `T` should implement
```julia
(::T)(x, y)
```
and instead implement
```julia
_evaluate(k::T, x, y)
```
Then there should be only a single
```julia
(k::Kernel)(x, y) = _evaluate(k, x, y)
```
which all the kernels would fall back to.

This ensures that `_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim})` is always
more specialized and is called first.
=#
# Forward every call that involves a `DiffPt` to `_evaluate`, the function defined in
# this file for kernels with `DiffPt` arguments. A plain (non-`DiffPt`) argument is
# wrapped in a `DiffPt` without partial derivatives first.
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
    (k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
    (k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
    (k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
end
108+

test/diffKernel.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testset "diffKernel" begin
    # Only checks that the calls below run without throwing.
    @testset "smoke test" begin
        k = MaternKernel()
        k(1, 1)
        k(1, DiffPt(1; partial=(1, 1))) # Cov(Z(x), ∂₁∂₁Z(y)) where x=1, y=1
        k(DiffPt([1]; partial=1), [2]) # Cov(∂₁Z(x), Z(y)) where x=[1], y=[2]
        # Cov(∂₁Z(x), ∂₂Z(y)) where x=[1,2], y=[1,2]
        k(DiffPt([1, 2]; partial=1), DiffPt([1, 2]; partial=2))
    end

    @testset "Sanity Checks with $k" for k in [MaternKernel()]
        for x in [0, 1, -1, 42]
            # for stationary kernels Cov(∂Z(x), Z(x)) = 0
            # (an absolute tolerance is needed: the default relative tolerance of `≈`
            # can never be met when comparing against exactly zero)
            @test k(DiffPt(x; partial=1), x) ≈ 0 atol = 1e-10

            # the slope should be positively correlated with a point further down
            @test k(
                DiffPt(x; partial=1), # slope
                x + 1e-10, # point further down
            ) > 0

            # correlation with self should be positive
            @test k(DiffPt(x; partial=1), DiffPt(x; partial=1)) > 0
        end
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ include("test_utils.jl")
176176
include("generic.jl")
177177
include("chainrules.jl")
178178
include("zygoterules.jl")
179+
include("diffKernel.jl")
179180

180181
@testset "doctests" begin
181182
DocMeta.setdocmeta!(

0 commit comments

Comments
 (0)