diff --git a/examples/2-deep-kernel-learning/Project.toml b/examples/2-deep-kernel-learning/Project.toml
index 2db67df5..1c205098 100644
--- a/examples/2-deep-kernel-learning/Project.toml
+++ b/examples/2-deep-kernel-learning/Project.toml
@@ -16,9 +16,9 @@ AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
 KernelFunctions = "0.10"
 Literate = "2"
-Lux = "0.5"
+Lux = "1"
 MLDataUtils = "0.5"
-Optimisers = "0.3"
+Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.6, 0.7"
-julia = "1.3"
+Zygote = "0.7"
+julia = "1.10"
diff --git a/examples/2-deep-kernel-learning/script.jl b/examples/2-deep-kernel-learning/script.jl
index e9180c35..a5ddcdc5 100644
--- a/examples/2-deep-kernel-learning/script.jl
+++ b/examples/2-deep-kernel-learning/script.jl
@@ -53,13 +53,10 @@ neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
 
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, neuralnet)
-# Create a wrapper function for the neural network that will be updated during training
-function neural_transform(x, θ)
-    return first(neuralnet(x, θ, st))
-end
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
@@ -81,36 +78,41 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Predicti
 
 # ## Training
 nmax = 200
-opt_state = Optimisers.setup(Optimisers.Adam(0.1), ps)
 
 # Create a wrapper function that updates the kernel with current parameters
-function update_kernel_and_loss(θ_current)
-    k_updated =
-        SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, θ_current))
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
     fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
-    return -logpdf(fx_updated, y_train)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
 end
 
 anim = Animation()
-for i in 1:nmax
-    loss_val, grads = Zygote.withgradient(update_kernel_and_loss, ps)
-    opt_state, ps = Optimisers.update(opt_state, ps, grads[1])
-    k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
-    fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
-
-    if i % 10 == 0
-        L = loss_val
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k = SqExponentialKernel() ∘ FunctionTransform(
+                StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+            )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
     end
 end
 
 gif(anim, "train-dkl.gif"; fps=3)
 nothing #hide