Skip to content

Commit 775140a

Browse files
sharanrydevmotion
andauthored
Add Bernoulli likelihood (#15)
* Add Bernoulli likelihood * Use StatsFuns * Update src/likelihoods/bernoulli.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 3a540d9 commit 775140a

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

src/GPLikelihoods.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ using StatsFuns: logistic, softmax
88

99
import Distributions
1010

11-
export CategoricalLikelihood,
11+
export BernoulliLikelihood,
12+
CategoricalLikelihood,
1213
GaussianLikelihood,
1314
HeteroscedasticGaussianLikelihood,
1415
PoissonLikelihood
1516

1617
# Likelihoods
18+
include("likelihoods/bernoulli.jl")
1719
include("likelihoods/categorical.jl")
1820
include("likelihoods/gaussian.jl")
1921
include("likelihoods/poisson.jl")

src/likelihoods/bernoulli.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
BernoulliLikelihood
3+
4+
Bernoulli likelihood is to be used if we assume that the
5+
uncertainity associated with the data follows a Bernoulli distribution.
6+
7+
```math
8+
p(y|f) = Bernoulli(y | f)
9+
```
10+
On calling, this would return a Bernoulli distribution with `f` probability of `true`.
11+
"""
12+
struct BernoulliLikelihood end
13+
14+
(l::BernoulliLikelihood)(f::Real) = Bernoulli(logistic(f))
15+
16+
(l::BernoulliLikelihood)(fs::AbstractVector{<:Real}) = Product(Bernoulli.(logistic.(fs)))

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
44
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
78

89
[compat]
910
AbstractGPs = "0.2"
1011
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23"
1112
Functors = "0.1"
13+
StatsFuns = "0.9"
1214
julia = "1.3"

test/likelihoods/bernoulli.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@testset "BernoulliLikelihood" begin
2+
rng = MersenneTwister(123)
3+
gp = GP(SqExponentialKernel())
4+
x = rand(rng, 10)
5+
y = rand(rng, 10)
6+
lik = BernoulliLikelihood()
7+
lgp = LatentGP(gp, lik, 1e-5)
8+
lfgp = lgp(x)
9+
10+
@test typeof(lik(rand(rng, lfgp.fx))) <: Distribution
11+
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
12+
@test Functors.functor(lik)[1] == ()
13+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Distributions
88
@testset "GPLikelihoods.jl" begin
99

1010
@testset "likelihoods" begin
11+
include("likelihoods/bernoulli.jl")
1112
include("likelihoods/categorical.jl")
1213
include("likelihoods/gaussian.jl")
1314
include("likelihoods/poisson.jl")

0 commit comments

Comments
 (0)