11 changes: 7 additions & 4 deletions examples/2-deep-kernel-learning/Project.toml
@@ -1,21 +1,24 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.3,0.4,0.5"
Distributions = "0.25"
Flux = "0.12, 0.13, 0.14"
KernelFunctions = "0.10"
Literate = "2"
Lux = "1"
MLDataUtils = "0.5"
Optimisers = "0.4"
Plots = "1"
Zygote = "0.6, 0.7"
julia = "1.3"
Zygote = "0.7"
julia = "1.10"
78 changes: 50 additions & 28 deletions examples/2-deep-kernel-learning/script.jl
@@ -1,10 +1,10 @@
# # Deep Kernel Learning with Flux
# # Deep Kernel Learning with Lux

## Background

# This example trains a GP whose inputs are passed through a neural network.
# This kind of model has been considered previously [^Calandra] [^Wilson], although it has been shown that some care is needed to avoid substantial overfitting [^Ober].
# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Flux.jl inside a standard kernel.
# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Lux.jl inside a standard kernel.

# [^Calandra]: Calandra, R., Peters, J., Rasmussen, C. E., & Deisenroth, M. P. (2016, July). [Manifold Gaussian processes for regression.](https://ieeexplore.ieee.org/abstract/document/7727626) In 2016 International Joint Conference on Neural Networks (IJCNN) (pp. 3338-3345). IEEE.

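# As a minimal sketch of the idea (independent of the diff below, using a hypothetical
# toy feature map `g` instead of the neural network defined later), composing a base
# kernel with `FunctionTransform(g)` yields k(x, x′) = k_base(g(x), g(x′)):

using KernelFunctions
g(x) = abs.(x) # hypothetical feature map, for illustration only
k_sketch = SqExponentialKernel() ∘ FunctionTransform(g)
k_sketch([1.0], [-1.0]) # == SqExponentialKernel()(g([1.0]), g([-1.0])) == 1.0
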
@@ -17,35 +17,46 @@
# the different hyper-parameters
using AbstractGPs
using Distributions
using Flux
using KernelFunctions
using LinearAlgebra
using Lux
using Optimisers
using Plots
using Random
using Zygote
default(; legendfontsize=15.0, linewidth=3.0);

Random.seed!(42) # for reproducibility

# ## Data creation
# We create a simple 1D problem whose target function varies at very different rates across the domain

xmin, xmax = (-3, 3) # Limits
N = 150
noise_std = 0.01
x_train_vec = rand(Uniform(xmin, xmax), N) # Training dataset
x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for Flux compatibility
x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for neural network compatibility
target_f(x) = sinc(abs(x)^abs(x)) # sinc applied to a rapidly varying argument
y_train = target_f.(x_train_vec) + randn(N) * noise_std
x_test_vec = range(xmin, xmax; length=200) # Testing dataset
x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for Flux compatibility
x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for neural network compatibility

plot(xmin:0.01:xmax, target_f; label="ground truth")
scatter!(x_train_vec, y_train; label="training data")

# ## Model definition
# We create a neural net with three dense layers (20, 30, and 5 units).
# The data is passed through the NN before being used in the kernel.
neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))

# Initialize the neural network parameters
rng = Random.default_rng()
ps, st = Lux.setup(rng, neuralnet)

smodel = StatefulLuxLayer(neuralnet, ps, st)
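
# `StatefulLuxLayer` bundles the network together with its parameters and state,
# giving a plain callable `smodel(x)` that `FunctionTransform` can wrap directly.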

# We use the Squared Exponential Kernel:
k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
k = SqExponentialKernel() ∘ FunctionTransform(smodel)

# We now define our model:
gpprior = GP(k) # GP Prior
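# (The lines collapsed in this diff presumably construct the finite projection of the
# prior at the training inputs and the negative log marginal likelihood loss, e.g.
# `fx = AbstractGPs.FiniteGP(gpprior, x_train, noise_std^2)` and `loss(y) = -logpdf(fx, y)`;
# this is a sketch, not the exact collapsed code.)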
@@ -58,9 +69,6 @@ loss(y) = -logpdf(fx, y)

@info "Initial loss = $(loss(y_train))"

# Flux will automatically extract all the parameters of the kernel
ps = Flux.params(k)

# We show the initial prediction with the untrained model
p_init = plot(; title="Loss = $(round(loss(y_train); sigdigits=6))")
plot!(vcat(x_test...), target_f; label="true f")
@@ -70,28 +78,42 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Predicti

# ## Training
nmax = 200
opt = Flux.Adam(0.1)

# Create a wrapper function that updates the kernel with current parameters
function update_kernel_and_loss(model, ps, st, data)
    smodel = StatefulLuxLayer(model, ps, st)
    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
    fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
    return -logpdf(fx_updated, y_train), smodel.st, (;)
end
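
# Note: `Training.single_train_step!` (used in the loop below) calls this objective as
# `objective(model, ps, st, data)` and expects `(loss, state, stats)` back, which is why
# the kernel and GP are rebuilt from the current parameters on every call.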

anim = Animation()
for i in 1:nmax
    grads = gradient(ps) do
        loss(y_train)
    end
    Flux.Optimise.update!(opt, ps, grads)

    if i % 10 == 0
        L = loss(y_train)
        @info "iteration $i/$nmax: loss = $L"

        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
        plot!(vcat(x_test...), target_f; label="true f")
        scatter!(vcat(x_train...), y_train; label="data")
        pred = marginals(posterior(fx, y_train)(x_test))
        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
        frame(anim)
        display(p)
let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
    for i in 1:nmax
        _, loss_val, _, tstate = Training.single_train_step!(
            AutoZygote(), update_kernel_and_loss, (), tstate
        )

        if i % 10 == 0
            k =
                SqExponentialKernel() ∘ FunctionTransform(
                    StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
                )
            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)

            @info "iteration $i/$nmax: loss = $loss_val"

            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
            plot!(vcat(x_test...), target_f; label="true f")
            scatter!(vcat(x_train...), y_train; label="data")
            pred = marginals(posterior(fx, y_train)(x_test))
            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
            frame(anim)
            display(p)
        end
    end
end

gif(anim, "train-dkl.gif"; fps=3)
nothing #hide
