
Commit 237f213

docs: use more of Lux official API for training and inference

1 parent cd5e927

2 files changed: +33, -31 lines

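For context: the commit replaces a hand-rolled `Zygote.withgradient` / `Optimisers.update` loop with Lux's official training loop, built from `Training.TrainState` and `Training.single_train_step!`, where the objective follows Lux's `(model, ps, st, data) -> (loss, state, stats)` contract. A minimal, self-contained sketch of that pattern (the toy model, data, and `mse_loss` here are illustrative, not from the example script):

```julia
using Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(1 => 1)                      # toy model for illustration
ps, st = Lux.setup(rng, model)
x, y = randn(rng, Float32, 1, 32), randn(rng, Float32, 1, 32)

# The objective must return (loss, updated_state, stats):
function mse_loss(model, ps, st, (x, y))
    ŷ, st = model(x, ps, st)
    return sum(abs2, ŷ .- y) / size(y, 2), st, (;)
end

let tstate = Training.TrainState(model, ps, st, Optimisers.Adam(0.01f0))
    for epoch in 1:100
        # single_train_step! returns (grads, loss, stats, updated_train_state)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), mse_loss, (x, y), tstate
        )
    end
end
```

The `let` block gives `tstate` a local binding that the loop body can rebind without top-level scope warnings, which is why the updated script wraps its training loop the same way.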

examples/2-deep-kernel-learning/Project.toml

Lines changed: 4 additions & 4 deletions
```diff
@@ -16,9 +16,9 @@ AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
 KernelFunctions = "0.10"
 Literate = "2"
-Lux = "0.5"
+Lux = "1"
 MLDataUtils = "0.5"
-Optimisers = "0.3"
+Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.6, 0.7"
-julia = "1.3"
+Zygote = "0.7"
+julia = "1.10"
```
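The compat bumps move the example onto Lux 1, where `StatefulLuxLayer` and the `Training` API used below are the documented entry points, with matching Optimisers, Zygote, and minimum-Julia releases. A sketch of the imports the updated script relies on (assumed; the diff only touches compat bounds, not the script's `using` block):

```julia
using Lux          # StatefulLuxLayer, the Training module, and AutoZygote
                   # (re-exported from ADTypes)
using Optimisers   # the Adam rule passed to Training.TrainState
using Zygote       # reverse-mode AD backend selected via AutoZygote()
```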

examples/2-deep-kernel-learning/script.jl

Lines changed: 29 additions & 27 deletions
```diff
@@ -53,13 +53,10 @@ neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, neuralnet)
 
-# Create a wrapper function for the neural network that will be updated during training
-function neural_transform(x, θ)
-    return first(neuralnet(x, θ, st))
-end
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
```
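`StatefulLuxLayer` packages `(model, ps, st)` into a single callable, so `smodel(x)` runs the forward pass and manages the state internally; that is what lets it replace the old closure inside `FunctionTransform`. A quick check of the pattern (reusing `rng`, `smodel`, and `k` from the hunk above; the two sample inputs are invented):

```julia
x1, x2 = randn(rng, Float32, 1), randn(rng, Float32, 1)  # two illustrative 1-d inputs
smodel(x1)   # the network's 5-dimensional embedding of x1
k(x1, x2)    # SqExponentialKernel applied to smodel(x1) and smodel(x2)
```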
```diff
@@ -81,36 +78,41 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Predicti
 
 # ## Training
 nmax = 200
-opt_state = Optimisers.setup(Optimisers.Adam(0.1), ps)
 
 # Create a wrapper function that updates the kernel with current parameters
-function update_kernel_and_loss(θ_current)
-    k_updated =
-        SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, θ_current))
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
     fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
-    return -logpdf(fx_updated, y_train)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
 end
 
 anim = Animation()
-for i in 1:nmax
-    loss_val, grads = Zygote.withgradient(update_kernel_and_loss, ps)
-    opt_state, ps = Optimisers.update(opt_state, ps, grads[1])
-    k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
-    fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
-
-    if i % 10 == 0
-        L = loss_val
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k = SqExponentialKernel() ∘ FunctionTransform(
+                StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+            )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
     end
 end
+
 gif(anim, "train-dkl.gif"; fps=3)
 nothing #hide
```
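Two details of the new loop worth noting: the step passes `()` as `data`, since `x_train` and `y_train` are captured from the enclosing scope rather than threaded through `single_train_step!`, and `tstate` stays local to the `let` block. Anything after the loop that needs the trained network would have to capture the final state (e.g. by having the `let` block return `tstate`); assuming it is available, inference follows the same recipe as the plotting branch:

```julia
# Sketch: posterior prediction with the trained parameters (assumes the
# final `tstate` was returned from the `let` block above)
smodel_trained = StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
k_trained = SqExponentialKernel() ∘ FunctionTransform(smodel_trained)
fx_trained = AbstractGPs.FiniteGP(GP(k_trained), x_train, noise_std^2)
post = posterior(fx_trained, y_train)
pred = marginals(post(x_test))   # posterior marginals at the test inputs
mean.(pred), std.(pred)          # predictive means and standard deviations
```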
