
Commit f512e47

Failed attempts with default Flux opt
1 parent a6d6e5b

1 file changed (+94 -27 lines)

examples/train-kernel-parameters/script.jl

Lines changed: 94 additions & 27 deletions
@@ -9,8 +9,10 @@ using LinearAlgebra
 using Distributions
 using Plots;
 default(; lw=2.0, legendfontsize=15.0);
+using BenchmarkTools
+using Flux
 using Flux: Optimise
-using ForwardDiff
+using Zygote
 using Random: seed!
 seed!(42)
 using ParameterHandling
@@ -75,18 +77,27 @@ function loss(θ)
     return sum(abs2, y_train - ŷ) + exp(θ[4]) * norm(ŷ)
 end

+# ## Training the model
+
+θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
+opt = Optimise.ADAGrad(0.5)
+
 # The loss with our starting point :

-loss(log.(ones(4)))
+loss(θ)

-# ## Training the model
+# Cost for one step
+
+@benchmark let θt = θ[:], optt = Optimise.ADAGrad(0.5)
+    grads = only((Zygote.gradient(loss, θt))) # We compute the gradients given the kernel parameters and regularization
+    Optimise.update!(optt, θt, grads)
+end
+
+# The optimization

-θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
-# θ = positive.([1.1, 0.1, 0.01, 0.001])
 anim = Animation()
-opt = Optimise.ADAGrad(0.5)
-for i in 1:30
-    grads = only(Zygote.gradient(loss, θ)) # We compute the gradients given the kernel parameters and regularization
+for i in 1:25
+    grads = only((Zygote.gradient(loss, θ))) # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, θ, grads)
     scatter(
         x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
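
A note on the one-step @benchmark added above: θ and loss are global variables here, and BenchmarkTools generally recommends interpolating globals with $ so that their lookup cost is not part of the measurement. A minimal variant of the same benchmark under that assumption (a sketch, not part of this commit):

@benchmark let θt = copy($θ), optt = Optimise.ADAGrad(0.5)
    # time a single gradient evaluation plus one ADAGrad update,
    # with the global θ and loss interpolated into the benchmarked expression
    grads = only(Zygote.gradient($loss, θt))
    Optimise.update!(optt, θt, grads)
end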
@@ -97,6 +108,9 @@ for i in 1:30
 end
 gif(anim)

+# Final loss
+loss(θ)
+
 # ### ParameterHandling.jl
 # Alternatively, we can use the [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) package
 # to handle the requirement that all kernel parameters should be positive.
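
For context, the hunk below calls (loss ∘ unflatten)(flat_θ); flat_θ, unflatten, and raw_initial_θ are defined earlier in the script, outside this diff. A minimal sketch of how ParameterHandling.jl typically provides them, assuming value_flatten is used as in the package's API:

raw_initial_θ = (
    k1=positive(1.1), k2=positive(0.1), k3=positive(0.01), noise_var=positive(0.001)
)
# value_flatten returns an unconstrained flat vector together with a function
# that maps it back to the constrained (positive) parameter values.
flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ)
initial_θ = ParameterHandling.value(raw_initial_θ)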
@@ -130,40 +144,93 @@ initial_θ = ParameterHandling.value(raw_initial_θ)

 # The loss with our starting point :

-loss(initial_θ)
+(loss ∘ unflatten)(flat_θ)

 # ## Training the model

-anim = Animation()
+# ### Cost per step
+
+@benchmark let θt = flat_θ[:], optt = Optimise.ADAGrad(0.5)
+    grads = (Zygote.gradient(loss ∘ unflatten, θt))[1] # We compute the gradients given the kernel parameters and regularization
+    Optimise.update!(optt, θt, grads)
+end
+
 opt = Optimise.ADAGrad(0.5)
-for i in 1:30
-    grads = only(Zygote.gradient(loss ∘ unflatten, flat_θ)) # We compute the gradients given the kernel parameters and regularization
+for i in 1:25
+    grads = (Zygote.gradient(loss ∘ unflatten, flat_θ))[1] # We compute the gradients given the kernel parameters and regularization
     Optimise.update!(opt, flat_θ, grads)
-    scatter(
-        x_train, y_train; lab="data", title="i = $(i), Loss = $(round((loss ∘ unflatten)(flat_θ), digits = 4))"
-    )
-    plot!(x_test, sinc; lab="true function")
-    plot!(x_test, f(x_test, x_train, y_train, unflatten(flat_θ)); lab="Prediction", lw=3.0)
-    frame(anim)
 end
-gif(anim)
+
+# Final loss
+
+(loss ∘ unflatten)(flat_θ)


 # ## Method 2: Functor
-# An alternative method is to use tools from Flux.jl, which is a fairly heavy package.
+# An alternative method is to use tools from Flux.jl.

 # raw_initial_θ = (
 #     k1 = positive(1.1),
 #     k2 = positive(0.1),
 #     k3 = positive(0.01),
 #     noise_var=positive(0.001),
 # )
-k1 = 1.1
-k2 = 0.1
-k3 = 0.01
-noise_var = 0.001
+k1 = [1.1]
+k2 = [0.1]
+k3 = [0.01]
+noise_var = log.([0.001])

-kernel = (k1 * SqExponentialKernel() + k2 * Matern32Kernel()) ∘
-    ScaleTransform(k3)
+kernel = (ScaledKernel(SqExponentialKernel(), relu.(k1)) + ScaledKernel(Matern32Kernel(), k2)) ∘
+    ScaleTransform(map(exp, k3))
+
+θ = Flux.params(k1, k2, k3, noise_var)
+
+# kernel = (ScaledKernel(SqExponentialKernel(), softplus(θ[1])) + ScaledKernel(Matern32Kernel(), θ[2])) ∘
+#     ScaleTransform(θ[3])
+
+# This next, commented-out loss keeps the noise variance inside the linear solve:
+
+# function loss2()
+#     ŷ = kernelmatrix(kernel, x_train, x_train) * ((kernelmatrix(kernel, x_train) + θ[4][1] * I) \ y_train)
+#     return sum(abs2, y_train - ŷ) + θ[4][1] * norm(ŷ)
+# end
+
+function loss()
+    ŷ = kernelmatrix(kernel, x_train, x_train) * ((kernelmatrix(kernel, x_train)) \ y_train)
+    return sum(abs2, y_train - ŷ) + only(exp.(noise_var) .* norm(ŷ))
+end
+
+function f(x, x_train, y_train)
+    return kernelmatrix(kernel, x, x_train) *
+           ((kernelmatrix(kernel, x_train) + only(exp.(noise_var)) * I) \ y_train)
+end
+
+grads = Flux.gradient(loss, θ)
+for p in θ
+    println(grads[p])
+end
+
+grads = Flux.gradient(loss, θ)
+
+η = 0.1 # Learning Rate
+opt = Optimise.ADAGrad(η)
+# for p in θ
+#     update!(p, η * grads[p])
+# end
+
+anim = Animation()
+for i in 1:25
+    Optimise.update!(opt, θ, grads)
+    println(θ)
+
+    scatter(
+        x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(), digits = 4))"
+    )
+    plot!(x_test, sinc; lab="true function")
+    plot!(x_test, f(x_test, x_train, y_train); lab="Prediction", lw=3.0)
+    frame(anim)
+end

-Θ = Flux.params(k1, k2, k3)
+gif(anim)
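
Two remarks on the Functor attempt above. First, k1, k2, k3, and noise_var become one-element arrays because Flux.params tracks mutable arrays that Optimise.update! can modify in place. Second, the animation loop reuses the grads computed once before the loop, so every ADAGrad step applies the same gradient. A minimal sketch of the usual implicit-parameters pattern, recomputing the gradient on each iteration (an assumption about the intended fix, not part of this commit):

opt = Optimise.ADAGrad(0.1)
for i in 1:25
    # gradient of the zero-argument loss with respect to the arrays in θ = Flux.params(...)
    grads = Flux.gradient(loss, θ)
    # in-place ADAGrad update of every tracked parameter array
    Optimise.update!(opt, θ, grads)
end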
