diff --git a/examples/2-deep-kernel-learning/Project.toml b/examples/2-deep-kernel-learning/Project.toml
index 94f146e4..1c205098 100644
--- a/examples/2-deep-kernel-learning/Project.toml
+++ b/examples/2-deep-kernel-learning/Project.toml
@@ -1,21 +1,24 @@
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
-Flux = "0.12, 0.13, 0.14"
 KernelFunctions = "0.10"
 Literate = "2"
+Lux = "1"
 MLDataUtils = "0.5"
+Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.6, 0.7"
-julia = "1.3"
+Zygote = "0.7"
+julia = "1.10"
diff --git a/examples/2-deep-kernel-learning/script.jl b/examples/2-deep-kernel-learning/script.jl
index 9ce11164..67f3b09e 100644
--- a/examples/2-deep-kernel-learning/script.jl
+++ b/examples/2-deep-kernel-learning/script.jl
@@ -1,10 +1,10 @@
-# # Deep Kernel Learning with Flux
+# # Deep Kernel Learning with Lux
 ## Background
 # This example trains a GP whose inputs are passed through a neural network.
 # This kind of model has been considered previously [^Calandra] [^Wilson], although it has been shown that some care is needed to avoid substantial overfitting [^Ober].
-# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Flux.jl inside a standard kernel.
+# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Lux.jl inside a standard kernel.
 # [^Calandra]: Calandra, R., Peters, J., Rasmussen, C. E., & Deisenroth, M. P. (2016, July). [Manifold Gaussian processes for regression.](https://ieeexplore.ieee.org/abstract/document/7727626) In 2016 International Joint Conference on Neural Networks (IJCNN) (pp. 3338-3345). IEEE.
@@ -17,12 +17,17 @@
 # the different hyper-parameters
 using AbstractGPs
 using Distributions
-using Flux
 using KernelFunctions
 using LinearAlgebra
+using Lux
+using Optimisers
 using Plots
+using Random
+using Zygote
 
 default(; legendfontsize=15.0, linewidth=3.0);
 
+Random.seed!(42) # for reproducibility
+
 # ## Data creation
 # We create a simple 1D Problem with very different variations
@@ -30,11 +35,11 @@ xmin, xmax = (-3, 3) # Limits
 N = 150
 noise_std = 0.01
 x_train_vec = rand(Uniform(xmin, xmax), N) # Training dataset
-x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for Flux compatibility
+x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for neural network compatibility
 target_f(x) = sinc(abs(x)^abs(x)) # We use sinc with a highly varying value
 y_train = target_f.(x_train_vec) + randn(N) * noise_std
 x_test_vec = range(xmin, xmax; length=200) # Testing dataset
-x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for Flux compatibility
+x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for neural network compatibility
 
 plot(xmin:0.01:xmax, target_f; label="ground truth")
 scatter!(x_train_vec, y_train; label="training data")
@@ -42,10 +47,16 @@ scatter!(x_train_vec, y_train; label="training data")
 # ## Model definition
 # We create a neural net with 2 layers and 10 units each.
 # The data is passed through the NN before being used in the kernel.
-neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
+neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
+
+# Initialize the neural network parameters
+rng = Random.default_rng()
+ps, st = Lux.setup(rng, neuralnet)
+
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
@@ -58,9 +69,6 @@ loss(y) = -logpdf(fx, y)
 
 @info "Initial loss = $(loss(y_train))"
 
-# Flux will automatically extract all the parameters of the kernel
-ps = Flux.params(k)
-
 # We show the initial prediction with the untrained model
 p_init = plot(; title="Loss = $(round(loss(y_train); sigdigits=6))")
 plot!(vcat(x_test...), target_f; label="true f")
@@ -70,28 +78,42 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Prediction")
 
 # ## Training
 nmax = 200
-opt = Flux.Adam(0.1)
+
+# Create a wrapper function that updates the kernel with current parameters
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
+    fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
+end
 
 anim = Animation()
-for i in 1:nmax
-    grads = gradient(ps) do
-        loss(y_train)
-    end
-    Flux.Optimise.update!(opt, ps, grads)
-
-    if i % 10 == 0
-        L = loss(y_train)
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k =
+                SqExponentialKernel() ∘ FunctionTransform(
+                    StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+                )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
     end
 end
+
 gif(anim, "train-dkl.gif"; fps=3)
 nothing #hide
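
Note for reviewers, not part of the patch: the change relies on `FunctionTransform` accepting any callable, so the Lux network (wrapped in a `StatefulLuxLayer` to fix its parameters and states) becomes the kernel's feature map. Below is a minimal standalone sketch of that pattern, assuming Lux 1.x and KernelFunctions 0.10; the layer sizes, seed, and data here are illustrative only and do not appear in the example.

using KernelFunctions, Lux, Random

rng = Random.default_rng()
nn = Chain(Dense(1 => 4, tanh), Dense(4 => 2))  # illustrative sizes, not the example's
ps, st = Lux.setup(rng, nn)

# Any callable works inside FunctionTransform; this closure applies the Lux model
# with fixed parameters and states (the patch wraps the same idea in a StatefulLuxLayer).
feature_map(x) = first(nn(x, ps, st))
k = SqExponentialKernel() ∘ FunctionTransform(feature_map)

x = [rand(rng, 1) for _ in 1:5]  # vector-of-vectors input, as in the example
kernelmatrix(k, x)               # 5×5 kernel matrix computed in the learned feature space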