Skip to content

Commit 490ece8

Browse files
authored
Fix build_laplace_objective behaviour (#115)
Fixes #109. Breaking change (required for the bugfix): when using `objective = build_laplace_objective(...)` and wanting to access the final `f`, you now need to access `objective.cache.f` instead of `objective.f`.
1 parent c751c05 commit 490ece8

File tree

4 files changed

+37
-13
lines changed

4 files changed

+37
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ApproximateGPs"
22
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.3.3"
4+
version = "0.3.4"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

examples/c-comparisons/script.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ lf2.f.kernel
127127
# Finally, we need to construct again the (approximate) posterior given the
128128
# observations for the latent GP with optimised hyperparameters:
129129

130-
f_post2 = posterior(LaplaceApproximation(; f_init=objective.f), lf2(X), Y)
130+
f_post2 = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf2(X), Y)
131131

132-
# By passing `f_init=objective.f` we let the Laplace approximation "warm-start"
133-
# at the last point of the inner-loop Newton optimisation; `objective.f` is a
132+
# By passing `f_init=objective.cache.f` we let the Laplace approximation "warm-start"
133+
# at the last point of the inner-loop Newton optimisation; `objective.cache` is a
134134
# field on the `objective` closure.
135135

136136
# Let's plot samples from the approximate posterior for the optimised hyperparameters:

src/LaplaceApproximationModule.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,25 @@ closure passes its arguments to `build_latent_gp`, which must return the
7575
- `newton_maxiter=100`: maximum number of Newton steps.
7676
"""
7777
function build_laplace_objective(build_latent_gp, xs, ys; kwargs...)
78-
# TODO assumes type of `xs` will be same as `mean(lfx.fx)`
79-
f = similar(xs, length(xs)) # will be mutated in-place to "warm-start" the Newton steps
80-
return build_laplace_objective!(f, build_latent_gp, xs, ys; kwargs...)
78+
cache = LaplaceObjectiveCache(nothing)
79+
# cache.f will be mutated in-place to "warm-start" the Newton steps
80+
# f should be similar(mean(lfx.fx)), but to construct lfx we would need the arguments
81+
# so we set it to `nothing` initially, and set it to mean(lfx.fx) within the objective
82+
return build_laplace_objective!(cache, build_latent_gp, xs, ys; kwargs...)
83+
end
84+
85+
function build_laplace_objective!(f_init::Vector, build_latent_gp, xs, ys; kwargs...)
86+
return build_laplace_objective!(
87+
LaplaceObjectiveCache(f_init), build_latent_gp, xs, ys; kwargs...
88+
)
89+
end
90+
91+
mutable struct LaplaceObjectiveCache
92+
f::Union{Nothing,Vector}
8193
end
8294

8395
function build_laplace_objective!(
84-
f,
96+
cache::LaplaceObjectiveCache,
8597
build_latent_gp,
8698
xs,
8799
ys;
@@ -98,16 +110,18 @@ function build_laplace_objective!(
98110
# Zygote does not like the try/catch within @info etc.
99111
@debug "Objective arguments: $args"
100112
# Zygote does not like in-place assignments either
101-
if initialize_f
102-
f .= mean(lfx.fx)
113+
if cache.f === nothing
114+
cache.f = mean(lfx.fx)
115+
elseif initialize_f
116+
cache.f .= mean(lfx.fx)
103117
end
104118
end
105119
f_opt, lml = laplace_f_and_lml(
106-
lfx, ys; f_init=f, maxiter=newton_maxiter, callback=newton_callback
120+
lfx, ys; f_init=cache.f, maxiter=newton_maxiter, callback=newton_callback
107121
)
108122
ignore_derivatives() do
109123
if newton_warmstart
110-
f .= f_opt
124+
cache.f .= f_opt
111125
initialize_f = false
112126
end
113127
end

test/LaplaceApproximationModule.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424

2525
lf = build_latent_gp(training_results.minimizer)
26-
f_post = posterior(LaplaceApproximation(; f_init=objective.f), lf(xs), ys)
26+
f_post = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf(xs), ys)
2727
return f_post, training_results
2828
end
2929

@@ -208,4 +208,14 @@
208208
res = res_array[end]
209209
@test res.q isa MvNormal
210210
end
211+
212+
@testset "GitHub issue #109" begin
213+
build_latent_gp() = LatentGP(GP(SEKernel()), BernoulliLikelihood(), 1e-8)
214+
215+
x = ColVecs(randn(2, 5))
216+
_, y = rand(build_latent_gp()(x))
217+
218+
objective = build_laplace_objective(build_latent_gp, x, y)
219+
_ = objective() # check that it works
220+
end
211221
end

0 commit comments

Comments (0)