
Commit 30171c9 (1 parent: 9e1f96a)

Add ParameterHandling


examples/train-kernel-parameters/script.jl

Lines changed: 70 additions & 4 deletions
@@ -1,5 +1,7 @@
 # # Kernel Ridge Regression

+# In this example we show the two main methods to perform regression with a kernel from KernelFunctions.jl.
+
 # ## We load KernelFunctions and some other packages

 using KernelFunctions
@@ -11,6 +13,7 @@ using Flux: Optimise
 using ForwardDiff
 using Random: seed!
 seed!(42)
+using ParameterHandling

 # ## Data Generation
 # We generated data in 1 dimension
@@ -28,9 +31,18 @@ x_test = range(xmin - 0.1, xmax + 0.1; length=300)
 scatter(x_train, y_train; lab="data")
 plot!(x_test, sinc; lab="true function")

-# ## Kernel training
+# ## Kernel training, Method 1
+# The first method is to rebuild the parametrized kernel from a vector of parameters
+# in each evaluation of the cost function. This is similar to the approach taken in
+# [Stheno.jl](https://github.com/JuliaGaussianProcesses/Stheno.jl).
+
+
+# ### Simplest Approach
+# A simple way to ensure that the kernel parameters are positive
+# is to optimize over the logarithm of the parameters.
+
 # To train the kernel parameters via ForwardDiff.jl
-# we need to create a function creating a kernel from an array
+# we need to create a function creating a kernel from an array.

 function kernelcall(θ)
     return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel())
@@ -54,6 +66,7 @@ scatter(x_train, y_train; lab="data")
 plot!(x_test, sinc; lab="true function")
 plot!(x_test, ŷ; lab="prediction")

+
 # We define the loss based on the L2 norm both
 # for the loss and the regularization

@@ -68,11 +81,13 @@ loss(log.(ones(4)))

 # ## Training the model

-θ = log.([1.0, 0.1, 0.01, 0.001]) # Initial vector
+θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
+# θ = positive.([1.1, 0.1, 0.01, 0.001])
 anim = Animation()
 opt = Optimise.ADAGrad(0.5)
 for i in 1:30
-    grads = ForwardDiff.gradient(loss, θ) # We compute the gradients given the kernel parameters and regularization
+    println(i)
+    grads = only(Zygote.gradient(loss, θ)) # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, θ, grads)
     scatter(
         x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
@@ -82,3 +97,54 @@ for i in 1:30
     frame(anim)
 end
 gif(anim)
+
+# ### ParameterHandling.jl
+# Alternatively, we can use the [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) package.
+
+raw_initial_θ = (
+    k1 = positive(1.1),
+    k2 = positive(0.1),
+    k3 = positive(0.01),
+    noise_var = positive(0.001),
+)
+
+flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ);
+
+function kernelcall(θ)
+    return (θ.k1 * SqExponentialKernel() + θ.k2 * Matern32Kernel()) ∘
+           ScaleTransform(θ.k3)
+end
+
+function f(x, x_train, y_train, θ)
+    k = kernelcall(θ)
+    return kernelmatrix(k, x, x_train) *
+           ((kernelmatrix(k, x_train) + θ.noise_var * I) \ y_train)
+end
+
+function loss(θ)
+    ŷ = f(x_train, x_train, y_train, θ)
+    return sum(abs2, y_train - ŷ) + θ.noise_var * norm(ŷ)
+end
+
+initial_θ = ParameterHandling.value(raw_initial_θ)
+
+# The loss with our starting point:
+
+loss(initial_θ)
+
+# ## Training the model
+
+anim = Animation()
+opt = Optimise.ADAGrad(0.5)
+for i in 1:30
+    println(i)
+    grads = only(Zygote.gradient(loss ∘ unflatten, flat_θ)) # We compute the gradients given the kernel parameters and regularization
+    Optimise.update!(opt, flat_θ, grads)
+    scatter(
+        x_train, y_train; lab="data", title="i = $(i), Loss = $(round((loss ∘ unflatten)(flat_θ), digits = 4))"
+    )
+    plot!(x_test, sinc; lab="true function")
+    plot!(x_test, f(x_test, x_train, y_train, unflatten(flat_θ)); lab="Prediction", lw=3.0)
+    frame(anim)
+end
+gif(anim)
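
For readers unfamiliar with ParameterHandling.jl, here is a minimal standalone sketch of the `positive` / `value_flatten` round trip that the new method relies on. It is not part of the committed script: `raw_θ`, `toy_loss`, and the explicit `using Zygote` are hypothetical stand-ins added for illustration (the loops above assume Zygote is already loaded).

```julia
using ParameterHandling
using Zygote

# Constrained parameters: `positive(x)` declares a value that must stay > 0.
raw_θ = (k1=positive(1.1), k2=positive(0.1), k3=positive(0.01), noise_var=positive(0.001))

# `value_flatten` returns an unconstrained Vector{Float64} together with a
# closure that maps such a vector back to a plain NamedTuple of values.
flat_θ, unflatten = ParameterHandling.value_flatten(raw_θ)

unflatten(flat_θ)  # ≈ (k1 = 1.1, k2 = 0.1, k3 = 0.01, noise_var = 0.001)

# A loss written against the NamedTuple is optimized over the flat vector by
# composing it with `unflatten`, as in the training loop above.
toy_loss(θ) = θ.k1^2 + θ.k2  # hypothetical stand-in for the script's loss
grads = only(Zygote.gradient(toy_loss ∘ unflatten, flat_θ))
```

Because the optimizer only ever sees the unconstrained flat vector, the positivity constraints are enforced automatically rather than via an explicit `log`/`exp` reparametrization as in the first method.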
