
Commit 0fb3d7f

Add Flux.destructure example
1 parent f512e47 commit 0fb3d7f


examples/train-kernel-parameters/script.jl

Lines changed: 89 additions & 98 deletions
@@ -2,7 +2,7 @@
 
 # In this example we show the two main methods to perform regression on a kernel from KernelFunctions.jl.
 
-# ## We load KernelFunctions and some other packages
+# We load KernelFunctions and some other packages
 
 using KernelFunctions
 using LinearAlgebra
@@ -14,8 +14,7 @@ using Flux
 using Flux: Optimise
 using Zygote
 using Random: seed!
-seed!(42)
-using ParameterHandling
+seed!(42);
 
 # ## Data Generation
 # We generate data in 1 dimension
@@ -26,30 +25,29 @@ N = 50 # Number of samples
 x_train = rand(Uniform(xmin, xmax), N) # We sample N random samples
 σ = 0.1
 y_train = sinc.(x_train) + randn(N) * σ # We create a function and add some noise
-x_test = range(xmin - 0.1, xmax + 0.1; length=300)
-
+x_test = range(xmin - 0.1, xmax + 0.1; length=300);
 # Plot the data
 
-scatter(x_train, y_train; lab="data")
-plot!(x_test, sinc; lab="true function")
+## scatter(x_train, y_train; lab="data")
+## plot!(x_test, sinc; lab="true function")
 
-# ## Method 1
-# The first method is to rebuild the parametrized kernel from a vector of parameters
-# in each evaluation of the cost function. This is similar to the approach taken in
-# [Stheno.jl](https://github.com/JuliaGaussianProcesses/Stheno.jl).
 
 
-# ### Base Approach
-# A simple way to ensure that the kernel parameters are positive
-# is to optimize over the logarithm of the parameters.
+
+# ## Base Approach
+# The first option is to rebuild the parametrized kernel from a vector of parameters
+# in each evaluation of the cost function. This is similar to the approach taken in
+# [Stheno.jl](https://github.com/JuliaGaussianProcesses/Stheno.jl).
 
 # To train the kernel parameters via ForwardDiff.jl
 # we need to create a function that builds a kernel from an array.
+# A simple way to ensure that the kernel parameters are positive
+# is to optimize over the logarithm of the parameters.
 
 function kernelcall(θ)
     return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel()) ∘
            ScaleTransform(exp(θ[3]))
-end
+end;
 
 # From theory we know the prediction for a test set x given
 # the kernel parameters and the regularization constant
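
For context, the `f` defined in the next hunk implements the standard kernel ridge regression predictor; in the script's notation, with `exp(θ[4])` playing the role of the noise/regularization variance, this is

$$
\hat{y}(x_*) = K_\theta(x_*, x_{\mathrm{train}}) \bigl( K_\theta(x_{\mathrm{train}}, x_{\mathrm{train}}) + \exp(\theta_4) I \bigr)^{-1} y_{\mathrm{train}}.
$$
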
@@ -58,63 +56,66 @@ function f(x, x_train, y_train, θ)
     k = kernelcall(θ[1:3])
     return kernelmatrix(k, x, x_train) *
            ((kernelmatrix(k, x_train) + exp(θ[4]) * I) \ y_train)
-end
+end;
 
 # Let us see what the prediction looks like
 # with starting parameters [1.0, 1.0, 1.0, 1.0] we get:
 
-ŷ = f(x_test, x_train, y_train, log.(ones(4)))
-scatter(x_train, y_train; lab="data")
-plot!(x_test, sinc; lab="true function")
-plot!(x_test, ŷ; lab="prediction")
-
+ŷ = f(x_test, x_train, y_train, log.(ones(4)));
+## scatter(x_train, y_train; lab="data")
+## plot!(x_test, sinc; lab="true function")
+## plot!(x_test, ŷ; lab="prediction")
 
 # We define the loss using the L2 norm for both
 # the data-fit term and the regularization
 
 function loss(θ)
     ŷ = f(x_train, x_train, y_train, θ)
     return sum(abs2, y_train - ŷ) + exp(θ[4]) * norm(ŷ)
-end
-
-# ## Training the model
+end;
 
+# ### Training
+# Setting an initial value and initializing the optimizer:
 θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
-opt = Optimise.ADAGrad(0.5)
+opt = Optimise.ADAGrad(0.5);
 
-# The loss with our starting point :
+# The loss with our starting point:
 
 loss(θ)
 
-# Cost for one step
+# Computational cost for one step
 
 @benchmark let θt = θ[:], optt = Optimise.ADAGrad(0.5)
-    grads = only((Zygote.gradient(loss, θt))) # We compute the gradients given the kernel parameters and regularization
+    grads = only((Zygote.gradient(loss, θt)))
     Optimise.update!(optt, θt, grads)
 end
 
-# The optimization
+# Optimizing
 
-anim = Animation()
+## anim = Animation()
 for i in 1:25
-    grads = only((Zygote.gradient(loss, θ))) # We compute the gradients given the kernel parameters and regularization
+    grads = only((Zygote.gradient(loss, θ)))
     Optimise.update!(opt, θ, grads)
-    scatter(
-        x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
-    )
-    plot!(x_test, sinc; lab="true function")
-    plot!(x_test, f(x_test, x_train, y_train, θ); lab="Prediction", lw=3.0)
-    frame(anim)
-end
-gif(anim)
+end;
+## scatter(
+##     x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
+## )
+## plot!(x_test, sinc; lab="true function")
+## plot!(x_test, f(x_test, x_train, y_train, θ); lab="Prediction", lw=3.0)
+## frame(anim)
+## end
+## gif(anim)
 
 # Final loss
 loss(θ)
 
-# ### ParameterHandling.jl
+
+# ## Using ParameterHandling.jl
 # Alternatively, we can use the [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) package
 # to handle the requirement that all kernel parameters should be positive.
 
+using ParameterHandling
+
 raw_initial_θ = (
     k1 = positive(1.1),
     k2 = positive(0.1),
@@ -127,20 +128,20 @@ flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ);
 function kernelcall(θ)
     return (θ.k1 * SqExponentialKernel() + θ.k2 * Matern32Kernel()) ∘
            ScaleTransform(θ.k3)
-end
+end;
 
 function f(x, x_train, y_train, θ)
     k = kernelcall(θ)
     return kernelmatrix(k, x, x_train) *
            ((kernelmatrix(k, x_train) + θ.noise_var * I) \ y_train)
-end
+end;
 
 function loss(θ)
     ŷ = f(x_train, x_train, y_train, θ)
     return sum(abs2, y_train - ŷ) + θ.noise_var * norm(ŷ)
-end
+end;
 
-initial_θ = ParameterHandling.value(raw_initial_θ)
+initial_θ = ParameterHandling.value(raw_initial_θ);
 
 # The loss with our starting point:
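
As a minimal standalone sketch of the ParameterHandling.jl round trip used above (illustrative parameter names, not part of the script): `value_flatten` returns an unconstrained flat vector together with an `unflatten` closure that maps any real vector back to a NamedTuple of positive values, which is why the gradients below are taken of `loss ∘ unflatten`.

```julia
using ParameterHandling

# Hypothetical parameters, analogous to raw_initial_θ above
raw = (k1 = positive(1.1), noise_var = positive(0.001))

flat, unflatten = ParameterHandling.value_flatten(raw)

unflatten(flat)         # recovers (k1 ≈ 1.1, noise_var ≈ 0.001)
unflatten(flat .- 5.0)  # any real-valued input still maps to strictly positive values
```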

@@ -151,86 +152,76 @@ initial_θ = ParameterHandling.value(raw_initial_θ)
 # ### Cost per step
 
 @benchmark let θt = flat_θ[:], optt = Optimise.ADAGrad(0.5)
-    grads = (Zygote.gradient(loss ∘ unflatten, θt))[1] # We compute the gradients given the kernel parameters and regularization
+    grads = (Zygote.gradient(loss ∘ unflatten, θt))[1]
     Optimise.update!(optt, θt, grads)
 end
 
+# ### Complete optimization
+
 opt = Optimise.ADAGrad(0.5)
 for i in 1:25
-    grads = (Zygote.gradient(loss ∘ unflatten, flat_θ))[1] # We compute the gradients given the kernel parameters and regularization
+    grads = (Zygote.gradient(loss ∘ unflatten, flat_θ))[1]
    Optimise.update!(opt, flat_θ, grads)
-end
+end;
 
 # Final loss
 
 (loss ∘ unflatten)(flat_θ)
 
 
-# ## Method 2: Functor
-# An alternative method is to use tools from Flux.jl.
+# ## Flux.destructure
+# If we don't want to write an explicit function to construct the kernel, we can alternatively use the `Flux.destructure` function.
+# Again, we need to ensure that the parameters are positive. Note that the `exp` function now has to be in a different position.
 
-# raw_initial_θ = (
-#     k1 = positive(1.1),
-#     k2 = positive(0.1),
-#     k3 = positive(0.01),
-#     noise_var=positive(0.001),
-# )
-k1 = [1.1]
-k2 = [0.1]
-k3 = [0.01]
-noise_var = log.([0.001])
 
-kernel = (ScaledKernel(SqExponentialKernel(), relu.(k1)) + ScaledKernel(Matern32Kernel(), k2)) ∘
-    ScaleTransform(map(exp, k3))
+θ = [1.1, 0.1, 0.01, 0.001]
 
-θ = Flux.params(k1, k2, k3, noise_var)
+kernel = (θ[1] * SqExponentialKernel() + θ[2] * Matern32Kernel()) ∘
+    ScaleTransform(θ[3])
 
-# kernel = (ScaledKernel(SqExponentialKernel(), softplus(θ[1])) + ScaledKernel(Matern32Kernel(), θ[2])) ∘
-#     ScaleTransform(θ[3])
+p, kernelc = Flux.destructure(kernel);
 
-# This next
+# From theory we know the prediction for a test set x given
+# the kernel parameters and the regularization constant
 
-# function loss2()
-#     ŷ = kernelmatrix(kernel, x_train, x_train) * ((kernelmatrix(kernel, x_train) + θ[4][1] * I) \ y_train)
-#     return sum(abs2, y_train - ŷ) + θ[4][1] * norm(ŷ)
-# end
+function f(x, x_train, y_train, θ)
+    k = kernelc(θ[1:3])
+    return kernelmatrix(k, x, x_train) *
+           ((kernelmatrix(k, x_train) + θ[4] * I) \ y_train)
+end;
 
-function loss()
-    ŷ = kernelmatrix(kernel, x_train, x_train) * ((kernelmatrix(kernel, x_train)) \ y_train)
-    return sum(abs2, y_train - ŷ) + only(exp.(noise_var) .* norm(ŷ))
-end
 
-function f(x, x_train, y_train)
-    return kernelmatrix(kernel, x, x_train) *
-           ((kernelmatrix(kernel, x_train) + only(exp.(noise_var)) * I) \ y_train)
-end
+# We define the loss using the L2 norm for both
+# the data-fit term and the regularization
 
+function loss(θ)
+    ŷ = f(x_train, x_train, y_train, exp.(θ))
+    return sum(abs2, y_train - ŷ) + exp(θ[4]) * norm(ŷ)
+end;
 
-grads = Flux.gradient(loss, θ)
-for p in θ
-    println(grads[p])
-end
+# ## Training the model
+
+# The loss with our starting point:
+θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
+loss(θ)
 
+# Initialize optimizer
 
-grads = Flux.gradient(loss, θ)
+opt = Optimise.ADAGrad(0.5)
 
-η = 0.1 # Learning Rate
-opt = Optimise.ADAGrad(η)
-# for p in θ
-#     update!(p, η * grads[p])
-# end
+# Cost for one step
+
+@benchmark let θt = θ[:], optt = Optimise.ADAGrad(0.5)
+    grads = only((Zygote.gradient(loss, θt))) # We compute the gradients given the kernel parameters and regularization
+    Optimise.update!(optt, θt, grads)
+end
+
+# The optimization
 
-anim = Animation()
 for i in 1:25
+    grads = only((Zygote.gradient(loss, θ))) # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, θ, grads)
-    println(θ)
-
-    scatter(
-        x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(), digits = 4))"
-    )
-    plot!(x_test, sinc; lab="true function")
-    plot!(x_test, f(x_test, x_train, y_train); lab="Prediction", lw=3.0)
-    frame(anim)
-end
+end;
 
-gif(anim)
+# Final loss
+loss(θ)
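
As a minimal standalone sketch of the `Flux.destructure` pattern this commit introduces (illustrative values, separate from the script above): `destructure` flattens the kernel's trainable parameters into a vector and returns a closure that rebuilds a kernel of the same structure from any such vector.

```julia
using Flux, KernelFunctions

k = (1.1 * SqExponentialKernel() + 0.1 * Matern32Kernel()) ∘ ScaleTransform(0.01)

# p holds the flattened trainable parameters (the two kernel variances and the
# transform scale); rebuild(v) returns the same kernel structure with values v.
p, rebuild = Flux.destructure(k)

k_new = rebuild(exp.(p))      # e.g. enforce positivity outside the kernel constructor
kernelmatrix(k_new, rand(5))  # the rebuilt kernel is used like any other
```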
