Commit 9d9c34b

Tweak Examples to use DI (#139)

* Tweak examples
* Capitalisation
* Update examples CI

1 parent: fb1d758

File tree: 5 files changed (+11, −21 lines)
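Every Julia example changes in the same way: the hand-written gradient helper built on Mooncake.build_rrule is deleted, and Optim is instead handed an ADTypes backend descriptor so that differentiation runs through DifferentiationInterface (the "DI" of the title). A minimal sketch of the new calling pattern, assuming a recent Optim.jl that accepts ADTypes backends via its autodiff keyword; objective and flat_initial_params here are toy stand-ins, not the examples' real definitions:

using ADTypes   # provides backend descriptors such as AutoMooncake
using Optim     # optimisation; resolves the backend through DifferentiationInterface
import Mooncake # the AD engine itself only needs to be loaded

objective(x) = sum(abs2, x)      # toy stand-in for the examples' objectives
flat_initial_params = randn(4)   # toy stand-in for the flattened parameters

training_results = Optim.optimize(
    objective,
    flat_initial_params,
    BFGS(),
    Optim.Options(show_trace = false);
    inplace = false,                              # the objective is out-of-place
    autodiff = AutoMooncake(; config = nothing),  # "use Mooncake"; no hand-written gradient
)

After this change Mooncake only needs to be loaded; no Mooncake function is called by hand anywhere in the examples.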

.github/workflows/examples.yml (1 addition, 1 deletion)

@@ -32,7 +32,7 @@ jobs:
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: actions/cache@v1
+      - uses: actions/cache@v3
         env:
           cache-name: cache-artifacts
         with:

examples/Project.toml (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
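The only new dependency is ADTypes, a lightweight package that just defines backend descriptor types, AutoMooncake among them. A one-line sketch of the object the updated examples construct with it (the call is taken from the diffs below):

using ADTypes

backend = AutoMooncake(; config = nothing)  # "differentiate with Mooncake, default settings"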

examples/approx_space_time_learning.jl (3 additions, 7 deletions)

@@ -11,9 +11,10 @@ using TemporalGPs
 using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime
 
 # Load standard packages from the Julia ecosystem
+using ADTypes # Way to specify algorithmic differentiation backend.
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Mooncake # Algorithmic Differentiation
+import Mooncake # Algorithmic differentiation.
 
 using ParameterHandling: flatten
 
@@ -60,24 +61,19 @@ function objective(flat_params)
     return -elbo(f(x, params.var_noise), y, z_r)
 end
 
-# Optimise using Optim.
-function objective_grad(rule, flat_params)
-    return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
-end
-
 @info "running objective"
 @show objective(flat_initial_params)
 
 training_results = Optim.optimize(
     objective,
-    Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)),
     flat_initial_params + randn(4), # Add some noise to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
         linesearch = Optim.LineSearches.BackTracking(),
     ),
     Optim.Options(show_trace = true);
     inplace=false,
+    autodiff=AutoMooncake(; config=nothing),
 );
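For comparison, the deleted objective_grad helper and the new autodiff keyword compute the same gradient; the keyword route simply delegates to DifferentiationInterface. A sketch of both styles side by side, assuming DifferentiationInterface.jl is available to call directly (the examples never import it themselves; Optim uses it internally); objective and x0 are toy stand-ins:

using ADTypes
import DifferentiationInterface as DI
import Mooncake

objective(x) = sum(abs2, x)  # toy stand-in
x0 = randn(4)

# Old style (as deleted above): build a Mooncake rule once, reuse it per call.
rule = Mooncake.build_rrule(objective, x0)
grad_old = Mooncake.value_and_gradient!!(rule, objective, x0)[2][2]  # [2][2]: gradient w.r.t. x0

# New style: a backend descriptor plus generic DI calls; DI owns rule construction and caching.
backend = AutoMooncake(; config = nothing)
prep = DI.prepare_gradient(objective, backend, x0)
grad_new = DI.gradient(objective, prep, backend, x0)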

examples/exact_space_time_learning.jl (3 additions, 6 deletions)

@@ -11,9 +11,10 @@ using TemporalGPs
 using TemporalGPs: Separable, RectilinearGrid
 
 # Load standard packages from the Julia ecosystem
+using ADTypes # Way to specify algorithmic differentiation backend.
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Mooncake # Algorithmic Differentiation
+import Mooncake # Algorithmic differentiation.
 
 # Declare model parameters using `ParameterHandling.jl` types.
 flat_initial_params, unflatten = ParameterHandling.flatten((
@@ -53,21 +54,17 @@ function objective(flat_params)
     return -logpdf(f(x, params.var_noise), y)
 end
 
-function objective_grad(rule, flat_params)
-    return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
-end
-
 # Optimise using Optim.
 training_results = Optim.optimize(
     objective,
-    Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)),
     flat_initial_params + randn(4), # Add some noise to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
         linesearch = Optim.LineSearches.BackTracking(),
     ),
     Optim.Options(show_trace = true);
     inplace=false,
+    autodiff=AutoMooncake(; config=nothing),
 );
 
 # Extracting the final values of the parameters.

examples/exact_time_learning.jl (3 additions, 7 deletions)

@@ -7,9 +7,10 @@ using AbstractGPs
 using TemporalGPs
 
 # Load standard packages from the Julia ecosystem
+using ADTypes # Way to specify algorithmic differentiation backend.
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Mooncake # Algorithmic Differentiation
+import Mooncake # Algorithmic differentiation.
 
 # Declare model parameters using `ParameterHandling.jl` types.
 # var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
@@ -48,22 +49,17 @@ function objective(flat_params)
     return -logpdf(f(x, params.var_noise), y)
 end
 
-# A helper function to get the gradient.
-function objective_grad(rule, flat_params)
-    return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
-end
-
 # Optimise using Optim. Mooncake takes a little while to compile.
 training_results = Optim.optimize(
     objective,
-    Base.Fix1(objective_grad, Mooncake.build_rrule(objective, flat_initial_params)),
     flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
         linesearch = Optim.LineSearches.BackTracking(),
     ),
     Optim.Options(show_trace = true);
     inplace=false,
+    autodiff=AutoMooncake(; config=nothing),
 );
 
 # Extracting the final values of the parameters. Should be moderately close to truth.
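The "extracting the final values" step referenced in the trailing context typically looks like the sketch below, assuming the unflatten closure returned by ParameterHandling.flatten near the top of each example:

# Recover structured, constrained parameters from the optimiser's flat vector.
final_params = ParameterHandling.value(unflatten(Optim.minimizer(training_results)))

Optim.minimizer returns the flat vector, unflatten restores the named structure, and ParameterHandling.value strips the positivity transforms.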
