Skip to content

Commit afe33b0

Browse files
sharanrydevmotion
andauthored
Add Heteroscedastic Gaussian Likelihood (#17)
* Add Heteroscedastic Gaussian Likelihood * Address code review * Usa isa * Update src/likelihoods/gaussian.jl Co-authored-by: David Widmann <[email protected]> * Edit docs, patch bump, style fix Co-authored-by: David Widmann <[email protected]>
1 parent 2b89c26 commit afe33b0

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GPLikelihoods"
22
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
33
authors = ["willtebbutt <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

src/GPLikelihoods.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ using Functors
77

88
import Distributions
99

10-
export GaussianLikelihood, PoissonLikelihood
10+
export GaussianLikelihood,
11+
HeteroscedasticGaussianLikelihood,
12+
PoissonLikelihood
1113

1214
# Likelihoods
1315
include("likelihoods/gaussian.jl")

src/likelihoods/gaussian.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,21 @@ GaussianLikelihood() = GaussianLikelihood(1e-6)
2020
(l::GaussianLikelihood)(f::Real) = Normal(f, sqrt(l.σ²))
2121

2222
(l::GaussianLikelihood)(fs::AbstractVector{<:Real}) = MvNormal(fs, sqrt(l.σ²))
23+
24+
"""
25+
HeteroscedasticGaussianLikelihood(σ²)
26+
27+
Heteroscedastic Gaussian likelihood.
28+
This is a Gaussian likelihood whose mean and the log of whose variance are functions of the
29+
latent process.
30+
31+
```math
32+
p(y|[f, g]) = Normal(y | f, exp(g))
33+
```
34+
On calling, this would return a normal distribution with mean `f` and variance `exp(g)`.
35+
"""
36+
struct HeteroscedasticGaussianLikelihood end
37+
38+
(::HeteroscedasticGaussianLikelihood)(f::AbstractVector{<:Real}) = Normal(f[1], exp(f[2]))
39+
40+
(::HeteroscedasticGaussianLikelihood)(fs::AbstractVector) = MvNormal(first.(fs), exp.(last.(fs)))

test/likelihoods/gaussian.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,25 @@
1111
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
1212
@test keys(Functors.functor(lik)[1]) == (:σ²,)
1313
end
14+
15+
@testset "HeteroscedasticGaussianLikelihood" begin
16+
rng = MersenneTwister(123)
17+
gp = GP(IndependentMOKernel(SqExponentialKernel()))
18+
IN_DIM = 3
19+
OUT_DIM = 2 # one for the mean the other for the log-standard deviation
20+
N = 10
21+
x = [rand(rng, IN_DIM) for _ in 1:N]
22+
X = MOInput(x, OUT_DIM)
23+
lik = HeteroscedasticGaussianLikelihood()
24+
lgp = LatentGP(gp, lik, 1e-5)
25+
lfgp = lgp(X)
26+
27+
Y = rand(rng, lfgp.fx)
28+
29+
y = [Y[[i + j*N for j in 0:(OUT_DIM - 1)]] for i in 1:N]
30+
# Replace with mo_inverse_transform once it is merged
31+
32+
@test lik(y) isa Distribution
33+
@test length(rand(rng, lik(y))) == 10
34+
@test Functors.functor(lik)[1] == ()
35+
end

0 commit comments

Comments
 (0)