@@ -53,13 +53,10 @@ neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, neuralnet)
 
-# Create a wrapper function for the neural network that will be updated during training
-function neural_transform(x, θ)
-    return first(neuralnet(x, θ, st))
-end
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
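The new `StatefulLuxLayer` wrapper is what lets `FunctionTransform` take a plain callable instead of a closure over `ps`: it bundles the model, its parameters, and its states into one object that behaves like an ordinary function of the input. A minimal sketch of the behaviour this change relies on (assuming Lux's `StatefulLuxLayer`; the toy input `x` is illustrative only):

```julia
using Lux, Random

model = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatefulLuxLayer(model, ps, st)

x = rand(Float32, 1, 4)
# Calling the wrapper runs the usual Lux forward pass and keeps the
# updated layer states in smodel.st, so no closure over ps/st is needed:
smodel(x) ≈ first(model(x, ps, st))
```

In the training loop below, a fresh wrapper is built from the current parameters on every step, so gradients still flow through `ps` as usual.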
@@ -81,36 +78,41 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Prediction")
 
 # ## Training
 nmax = 200
-opt_state = Optimisers.setup(Optimisers.Adam(0.1), ps)
 
 # Create a wrapper function that updates the kernel with current parameters
-function update_kernel_and_loss(θ_current)
-    k_updated =
-        SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, θ_current))
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
     fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
-    return -logpdf(fx_updated, y_train)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
 end
 
 anim = Animation()
-for i in 1:nmax
-    loss_val, grads = Zygote.withgradient(update_kernel_and_loss, ps)
-    opt_state, ps = Optimisers.update(opt_state, ps, grads[1])
-    k = SqExponentialKernel() ∘ FunctionTransform(x -> neural_transform(x, ps))
-    fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
-
-    if i % 10 == 0
-        L = loss_val
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k = SqExponentialKernel() ∘ FunctionTransform(
+                StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+            )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
     end
 end
+
 gif(anim, "train-dkl.gif"; fps=3)
 nothing # hide
 
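For reference, the rewritten loop follows the `Lux.Training` contract: the objective handed to `Training.single_train_step!` must accept `(model, ps, st, data)` and return `(loss, updated_states, stats)`, and each step returns `(grads, loss, stats, updated_train_state)`. A minimal self-contained sketch of that contract, assuming only Lux, Optimisers, and Zygote are available (the `toy_loss` name and the random data are illustrative, not part of the diff):

```julia
using Lux, Optimisers, Random, Zygote

function toy_loss(model, ps, st, data)
    x, y = data
    ŷ, st = model(x, ps, st)           # forward pass through the wrapped model
    return sum(abs2, ŷ .- y), st, (;)  # (loss, updated states, empty stats)
end

model = Dense(2 => 1)
ps, st = Lux.setup(Random.default_rng(), model)
tstate = Training.TrainState(model, ps, st, Optimisers.Adam(0.01))

data = (rand(Float32, 2, 8), rand(Float32, 1, 8))
# single_train_step! differentiates toy_loss with the chosen AD backend and
# applies one optimiser update, returning (grads, loss, stats, tstate):
grads, loss, stats, tstate = Training.single_train_step!(
    AutoZygote(), toy_loss, data, tstate
)
```

In the diff above, `x_train` and `y_train` are captured from the surrounding scope inside `update_kernel_and_loss`, which is why the `data` argument passed to `single_train_step!` is just the empty tuple `()`.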