
Commit a6d6e5b

Extend example
1 parent 30171c9 commit a6d6e5b

1 file changed: +25 -6 lines changed


examples/train-kernel-parameters/script.jl

Lines changed: 25 additions & 6 deletions
@@ -31,13 +31,13 @@ x_test = range(xmin - 0.1, xmax + 0.1; length=300)
 scatter(x_train, y_train; lab="data")
 plot!(x_test, sinc; lab="true function")
 
-# ## Kernel training, Method 1
+# ## Method 1
 # The first method is to rebuild the parametrized kernel from a vector of parameters
 # in each evaluation of the cost function. This is similar to the approach taken in
 # [Stheno.jl](https://github.com/JuliaGaussianProcesses/Stheno.jl).
 
 
-# ### Simplest Approach
+# ### Base Approach
 # A simple way to ensure that the kernel parameters are positive
 # is to optimize over the logarithm of the parameters.
 
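For reference, the log-parameterization idea described in this hunk can be sketched as follows; the function name kernel_from_logparams and the starting values are illustrative and not taken from this commit:

using KernelFunctions

# Rebuild the kernel from unconstrained log-parameters; exp() maps them back to
# the positive values the kernel constructors expect.
function kernel_from_logparams(θ)
    return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel()) ∘
           ScaleTransform(exp(θ[3]))
end

θ0 = log.([1.1, 0.1, 0.01])    # optimize over these unconstrained values
k = kernel_from_logparams(θ0)  # always yields positive kernel parameters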
@@ -86,7 +86,6 @@ loss(log.(ones(4)))
 anim = Animation()
 opt = Optimise.ADAGrad(0.5)
 for i in 1:30
-    println(i)
     grads = only(Zygote.gradient(loss, θ)) # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, θ, grads)
     scatter(
@@ -99,7 +98,8 @@ end
 gif(anim)
 
 # ### ParameterHandling.jl
-# Alternatively, we can use the [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) package.
+# Alternatively, we can use the [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) package
+# to handle the requirement that all kernel parameters should be positive.
 
 raw_initial_θ = (
     k1 = positive(1.1),
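As background for this hunk, here is a minimal sketch of the ParameterHandling.jl workflow, using the same field names as the diff; the rest of the script's surrounding code is not shown in this commit:

using ParameterHandling

raw_initial_θ = (
    k1 = positive(1.1),
    k2 = positive(0.1),
    k3 = positive(0.01),
    noise_var = positive(0.001),
)

# value_flatten returns an unconstrained flat vector together with a function that
# maps it back to a NamedTuple of positive values, so the optimizer never has to
# deal with the positivity constraint directly.
flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ)
θ = unflatten(flat_θ)  # NamedTuple with all-positive values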
@@ -137,7 +137,6 @@ loss(initial_θ)
 anim = Animation()
 opt = Optimise.ADAGrad(0.5)
 for i in 1:30
-    println(i)
     grads = only(Zygote.gradient(loss ∘ unflatten, flat_θ)) # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, flat_θ, grads)
     scatter(
@@ -147,4 +146,24 @@ for i in 1:30
     plot!(x_test, f(x_test, x_train, y_train, unflatten(flat_θ)); lab="Prediction", lw=3.0)
     frame(anim)
 end
-gif(anim)
+gif(anim)
+
+
+# ## Method 2: Functor
+# An alternative method is to use tools from Flux.jl, which is a fairly heavy package.
+
+# raw_initial_θ = (
+#     k1 = positive(1.1),
+#     k2 = positive(0.1),
+#     k3 = positive(0.01),
+#     noise_var = positive(0.001),
+# )
+k1 = 1.1
+k2 = 0.1
+k3 = 0.01
+noise_var = 0.001
+
+kernel = (k1 * SqExponentialKernel() + k2 * Matern32Kernel()) ∘
+    ScaleTransform(k3)
+
+Θ = Flux.params(k1, k2, k3)
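The new Method 2 section stops after collecting the scalars with Flux.params. Note that Zygote's implicit parameters only track mutable containers such as arrays, so plain Float64 scalars passed to Flux.params would not receive gradients. The sketch below shows one way the section could be completed, with each parameter wrapped in a one-element vector; the placeholder data and the stand-in loss are assumptions and not part of this commit:

using KernelFunctions, Flux, Zygote, LinearAlgebra
using Flux: Optimise

x_train = range(0, 1; length=20)   # placeholder data, not the script's sinc data
y_train = sinc.(10 .* x_train)

# Wrap each scalar in a one-element vector so Zygote's implicit params can track it.
k1, k2, k3, noise_var = [1.1], [0.1], [0.01], [0.001]

# Rebuild the kernel from the current parameter values on every call.
build_kernel() = (k1[1] * SqExponentialKernel() + k2[1] * Matern32Kernel()) ∘
                 ScaleTransform(k3[1])

# Stand-in regularized kernel ridge loss (hypothetical; the script defines its own).
function loss()
    K = kernelmatrix(build_kernel(), x_train)
    α = (K + noise_var[1] * I) \ y_train
    ŷ = K * α
    return sum(abs2, y_train - ŷ) + noise_var[1] * norm(α)
end

Θ = Flux.params(k1, k2, k3, noise_var)
opt = Optimise.ADAGrad(0.5)
for i in 1:10
    grads = Zygote.gradient(loss, Θ)  # gradients with respect to the implicit params
    Optimise.update!(opt, Θ, grads)   # in-place ADAGrad step on k1, k2, k3, noise_var
end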
