|
| 1 | +module LaplaceApproximationModule |
| 2 | + |
| 3 | +using ..API |
| 4 | + |
| 5 | +export LaplaceApproximation |
| 6 | +export build_laplace_objective, build_laplace_objective! |
| 7 | + |
| 8 | +using ForwardDiff: ForwardDiff |
| 9 | +using Distributions |
| 10 | +using LinearAlgebra |
| 11 | +using Statistics |
| 12 | +using StatsBase |
| 13 | + |
| 14 | +using ChainRulesCore: ignore_derivatives, NoTangent, @thunk |
| 15 | +using ChainRulesCore: ChainRulesCore |
| 16 | + |
| 17 | +using AbstractGPs: AbstractGPs |
| 18 | +using AbstractGPs: LatentFiniteGP, ApproxPosteriorGP |
| 19 | + |
1 | 20 | # Implementation follows Rasmussen & Williams, Gaussian Processes for Machine |
2 | 21 | # Learning, the MIT Press, 2006. In the following referred to as 'RW'. |
3 | 22 | # Online text: |
@@ -36,7 +55,7 @@ Compute an approximation to the log of the marginal likelihood (also known as |
36 | 55 |
|
37 | 56 | This should become part of the AbstractGPs API (see JuliaGaussianProcesses/AbstractGPs.jl#221). |
38 | 57 | """ |
39 | | -function approx_lml(la::LaplaceApproximation, lfx::LatentFiniteGP, ys) |
| 58 | +function API.approx_lml(la::LaplaceApproximation, lfx::LatentFiniteGP, ys) |
40 | 59 | return laplace_lml(lfx, ys; la.newton_kwargs...) |
41 | 60 | end |
42 | 61 |
|
@@ -309,11 +328,13 @@ function ChainRulesCore.rrule(::typeof(newton_inner_loop), dist_y_given_f, ys, K |
309 | 328 | function newton_pullback(Δf_opt) |
310 | 329 | ∂self = NoTangent() |
311 | 330 |
|
312 | | - ∂dist_y_given_f = @not_implemented( |
| 331 | + ∂dist_y_given_f = ChainRulesCore.@not_implemented( |
313 | 332 | "gradient of Newton's method w.r.t. likelihood parameters" |
314 | 333 | ) |
315 | 334 |
|
316 | | - ∂ys = @not_implemented("gradient of Newton's method w.r.t. observations") |
| 335 | + ∂ys = ChainRulesCore.@not_implemented( |
| 336 | + "gradient of Newton's method w.r.t. observations" |
| 337 | + ) |
317 | 338 |
|
318 | 339 | # ∂K = df/dK Δf |
319 | 340 | ∂K = @thunk(cache.Wsqrt * (cache.B_ch \ (cache.Wsqrt \ Δf_opt)) * cache.d_loglik') |
@@ -417,3 +438,5 @@ function Statistics.cov(f::LaplacePosteriorGP, x::AbstractVector, y::AbstractVec |
417 | 438 | vy = L \ (f.data.Wsqrt * cov(f.prior.f, f.prior.x, y)) |
418 | 439 | return cov(f.prior.f, x, y) - vx' * vy |
419 | 440 | end |
| 441 | + |
| 442 | +end |
0 commit comments