8 changes: 4 additions & 4 deletions examples/2-deep-kernel-learning/Project.toml
@@ -16,9 +16,9 @@ AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
 KernelFunctions = "0.10"
 Literate = "2"
-Lux = "0.5"
+Lux = "1"
 MLDataUtils = "0.5"
-Optimisers = "0.3"
+Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.6, 0.7"
-julia = "1.3"
+Zygote = "0.7"
+julia = "1.10"
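For reference, the `[compat]` section after this bump reads as follows (reconstructed from the hunk above; the `[compat]` heading and any entries outside the hunk are assumed unchanged):

[compat]
AbstractGPs = "0.3,0.4,0.5"
Distributions = "0.25"
KernelFunctions = "0.10"
Literate = "2"
Lux = "1"
MLDataUtils = "0.5"
Optimisers = "0.4"
Plots = "1"
Zygote = "0.7"
julia = "1.10"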
56 changes: 29 additions & 27 deletions examples/2-deep-kernel-learning/script.jl
@@ -53,13 +53,10 @@ neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, neuralnet)
 
-# Create a wrapper function for the neural network that will be updated during training
-function neural_transform(x, θ)
-    return first(neuralnet(x, θ, st))
-end
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
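The hunk above drops the hand-written `neural_transform` closure in favour of Lux's `StatefulLuxLayer`, a callable wrapper that carries the parameters `ps` and state `st` with it, so it can be handed straight to `FunctionTransform`. Composing with `∘` makes the kernel evaluate on transformed inputs, i.e. k(x, y) = k_SE(ϕ(x), ϕ(y)) with ϕ the network. A minimal sketch of the wrapper's behaviour (illustrative, not part of the PR; constructor form copied from the diff):

using Lux, Random

model = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatefulLuxLayer(model, ps, st)

# Callable like a plain function: returns only the output array,
# while any updated layer state is stored back in smodel.st.
y = smodel(rand(Float32, 1, 10))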
@@ -81,36 +78,41 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Prediction")
 
 # ## Training
 nmax = 200
-opt_state = Optimisers.setup(Optimisers.Adam(0.1), ps)
 
 # Create a wrapper function that updates the kernel with current parameters
-function update_kernel_and_loss(θ_current)
-    k_updated =
-        SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, θ_current))
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
     fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
-    return -logpdf(fx_updated, y_train)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
 end
 
 anim = Animation()
-for i in 1:nmax
-    loss_val, grads = Zygote.withgradient(update_kernel_and_loss, ps)
-    opt_state, ps = Optimisers.update(opt_state, ps, grads[1])
-    k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
-    fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
-
-    if i % 10 == 0
-        L = loss_val
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
-    end
-end
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k = SqExponentialKernel() ∘ FunctionTransform(
+                StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+            )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
+    end
+end
 
 gif(anim, "train-dkl.gif"; fps=3)
 nothing #hide
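The rewritten loop replaces the manual `Zygote.withgradient` / `Optimisers.update` pair with Lux's Training API. The objective passed to `Training.single_train_step!` must have the signature `(model, ps, st, data) -> (loss, state, stats)`, which is why `update_kernel_and_loss` now returns a triple, and each step returns `(grads, loss, stats, train_state)`. A minimal self-contained sketch of that pattern (toy model and data are illustrative, not from the PR; `AutoZygote` lives in ADTypes, which Lux depends on):

using ADTypes: AutoZygote
using Lux, Optimisers, Random, Zygote

model = Dense(1 => 1)
ps, st = Lux.setup(Random.default_rng(), model)

# Required objective signature: (model, ps, st, data) -> (loss, state, stats)
function toy_loss(model, ps, st, data)
    x, y = data
    ŷ, st = model(x, ps, st)
    return sum(abs2, ŷ .- y), st, (;)
end

x, y = rand(Float32, 1, 32), rand(Float32, 1, 32)
let tstate = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(0.01))
    for _ in 1:100
        # returns (grads, loss, stats, updated train state)
        _, loss, _, tstate = Lux.Training.single_train_step!(
            AutoZygote(), toy_loss, (x, y), tstate
        )
    end
end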
