Skip to content

Commit 3a540d9

Browse files
authored
Add Categorical Likelihood (#16)
* Add Categorical Likelihood * Fix docstring * Address code review * Usa isa * Update categorical likelihood * Avoid mutation
1 parent afe33b0 commit 3a540d9

File tree

7 files changed

+47
-3
lines changed

7 files changed

+47
-3
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1112

1213
[compat]
1314
AbstractGPs = "0.2"
1415
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23"
1516
Functors = "0.1"
17+
StatsFuns = "0.9"
1618
julia = "1.3"

src/GPLikelihoods.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@ using Distributions
44
using AbstractGPs
55
using Random
66
using Functors
7+
using StatsFuns: logistic, softmax
78

89
import Distributions
910

10-
export GaussianLikelihood,
11+
export CategoricalLikelihood,
12+
GaussianLikelihood,
1113
HeteroscedasticGaussianLikelihood,
1214
PoissonLikelihood
1315

1416
# Likelihoods
17+
include("likelihoods/categorical.jl")
1518
include("likelihoods/gaussian.jl")
1619
include("likelihoods/poisson.jl")
1720

src/likelihoods/categorical.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
CategoricalLikelihood
3+
4+
Categorical likelihood is to be used if we assume that the
5+
uncertainity associated with the data follows a Categorical distribution.
6+
```math
7+
p(y|f_1, f_2, \\dots, f_{n-1}) = Categorical(y | softmax(f_1, f_2, \\dots, f_{n-1}, 0))
8+
```
9+
On calling, this would return a Categorical distribution with `f_i`
10+
probability of `i` category.
11+
"""
12+
struct CategoricalLikelihood end
13+
14+
(l::CategoricalLikelihood)(f::AbstractVector{<:Real}) = Categorical(softmax(vcat(f, 0)))
15+
16+
(l::CategoricalLikelihood)(fs::AbstractVector) = Product(Categorical.(softmax.(vcat.(fs, 0))))

test/likelihoods/categorical.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@testset "CategoricalLikelihood" begin
2+
rng = MersenneTwister(123)
3+
gp = GP(IndependentMOKernel(SqExponentialKernel()))
4+
IN_DIM = 3
5+
OUT_DIM = 4
6+
N = 10
7+
x = [rand(rng, IN_DIM) for _=1:N]
8+
X = MOInput(x, OUT_DIM)
9+
lik = CategoricalLikelihood()
10+
lgp = LatentGP(gp, lik, 1e-5)
11+
lfgp = lgp(X)
12+
13+
Y = rand(rng, lfgp.fx)
14+
15+
y = [Y[[i + j*N for j in 0:(OUT_DIM - 1)]] for i in 1:N]
16+
# Replace with mo_inverse_transform once it is merged
17+
18+
@test length(lik(rand(3)).p) == 4
19+
@test lik(y) isa Distribution
20+
@test length(rand(rng, lik(y))) == 10
21+
@test Functors.functor(lik)[1] == ()
22+
end

test/likelihoods/gaussian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
lgp = LatentGP(gp, lik, 1e-5)
88
lfgp = lgp(x)
99

10-
@test typeof(lik(rand(rng, lfgp.fx))) <: Distribution
10+
@test lik(rand(rng, lfgp.fx)) isa Distribution
1111
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
1212
@test keys(Functors.functor(lik)[1]) == (:σ²,)
1313
end

test/likelihoods/poisson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
lgp = LatentGP(gp, lik, 1e-5)
88
lfgp = lgp(x)
99

10-
@test typeof(lik(rand(rng, lfgp.fx))) <: Distribution
10+
@test lik(rand(rng, lfgp.fx)) isa Distribution
1111
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
1212
@test Functors.functor(lik)[1] == ()
1313
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Distributions
88
@testset "GPLikelihoods.jl" begin
99

1010
@testset "likelihoods" begin
11+
include("likelihoods/categorical.jl")
1112
include("likelihoods/gaussian.jl")
1213
include("likelihoods/poisson.jl")
1314
end

0 commit comments

Comments
 (0)