Commit 63028d3

setup as a test

1 parent 923c116 commit 63028d3

6 files changed: +193 -152 lines changed
Project.toml

Lines changed: 5 additions & 1 deletion
@@ -10,10 +10,12 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -34,10 +36,12 @@ Transducers = "0.4.30"
 julia = "1.6"

 [extras]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["FillArrays", "IJulia", "Statistics", "Test"]
+test = ["FillArrays", "Distributions", "IJulia", "OrderedCollections", "Statistics", "Test"]
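With `Distributions` and `OrderedCollections` added to `[extras]` and the `test` target, the Gibbs example's dependencies resolve through the package's own test environment. A standard Pkg workflow to exercise this (plain Pkg usage, nothing specific to this commit) might look like:

```julia
# From the package root: Pkg assembles the test environment from the
# [extras]/[targets] entries above, then runs test/runtests.jl.
using Pkg
Pkg.test()
```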

docs/src/gibbs.md

Lines changed: 46 additions & 0 deletions
@@ -345,3 +345,49 @@ Some points worth noting:
 - update the `vi` with the new values from the sampler state

 Again, the `state` interface in AbstractMCMC allows the Gibbs sampler to be agnostic of the details of the sampler state, and acquire the values of the parameters from individual sampler states.
+
+Now we can use the Gibbs sampler to sample from the hierarchical normal model.
+
+First we generate some data:
+
+```julia
+N = 100 # Number of data points
+mu_true = 0.5 # True mean
+tau2_true = 2.0 # True variance
+
+x_data = rand(Normal(mu_true, sqrt(tau2_true)), N)
+```
+
+Then we can create a `HierNormal` model with the data.
+
+```julia
+hn = HierNormal((x=x_data,))
+```
+
+Sampling is easy: we use random-walk MH for `mu` and prior MH for `tau2`, because `tau2` is supported on the positive reals.
+
+```julia
+samples = sample(
+    hn,
+    Gibbs(
+        OrderedDict(
+            (:mu,) => RWMH(1),
+            (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])),
+        ),
+    ),
+    100000;
+    initial_params=(mu=[0.0], tau2=[1.0]),
+)
+```
+
+Then we can extract the samples and compute their means, discarding the first 20,000 draws as burn-in.
+
+```julia
+mu_samples = [sample.values.mu for sample in samples][20001:end]
+tau2_samples = [sample.values.tau2 for sample in samples][20001:end]
+
+mean(mu_samples)
+mean(tau2_samples)
+```
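The docs addition above calls `RWMH` and `PriorMH`, which the example defines in `test/gibbs_example/mh.jl` (not shown in this commit). For orientation, a random-walk Metropolis update amounts to the following minimal sketch using only Distributions.jl; `rwmh_step`, `logpi`, and the step size are illustrative names, not the example's API.

```julia
using Distributions, Random

# One random-walk Metropolis step: propose x' = x + σ·ε with ε ~ N(0, I).
# The proposal is symmetric, so the acceptance ratio reduces to
# min(1, exp(logπ(x') - logπ(x))).
function rwmh_step(rng::AbstractRNG, logpi, x::Vector{Float64}, sigma::Real)
    proposal = x .+ sigma .* randn(rng, length(x))
    accept = log(rand(rng)) < logpi(proposal) - logpi(x)
    return accept ? proposal : x
end

# Toy usage against a 1-D standard normal target.
rwmh_step(MersenneTwister(1), x -> logpdf(Normal(), x[1]), [0.0], 1.0)
```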

test/gibbs_example/Project.toml

Lines changed: 0 additions & 11 deletions
This file was deleted.

test/gibbs_example/gibbs.jl

Lines changed: 46 additions & 42 deletions
@@ -3,8 +3,11 @@ using Distributions
 using LogDensityProblems
 using OrderedCollections
 using Random
+using Test

-##
+include("hier_normal.jl")
+# include("gmm.jl")
+include("mh.jl")

 struct Gibbs <: AbstractMCMC.AbstractSampler
     sampler_map::OrderedDict
@@ -64,7 +67,7 @@ function AbstractMCMC.step(
     vi = state.vi
     for group in keys(spl.sampler_map)
         for (group, sub_state) in state.states
-            vi = merge(vi, unflatten(get_params(sub_state), group))
+            vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group))
         end
         sub_spl = spl.sampler_map[group]
         sub_state = state.states[group]
@@ -73,7 +76,7 @@
            Tuple([vi[g] for g in group_complement])
        )
        cond_logdensity = condition(logdensity_model.logdensity, cond_val)
-        sub_state = recompute_logprob!!(cond_logdensity, get_params(sub_state), sub_state)
+        sub_state = recompute_logprob!!(cond_logdensity, AbstractMCMC.get_params(sub_state), sub_state)
        sub_state = last(
            AbstractMCMC.step(
                rng,
@@ -87,15 +90,15 @@
         state.states[group] = sub_state
     end
     for (group, sub_state) in state.states
-        vi = merge(vi, unflatten(get_params(sub_state), group))
+        vi = merge(vi, unflatten(AbstractMCMC.get_params(sub_state), group))
     end
     return GibbsTransition(vi), GibbsState(vi, state.states)
 end

-## tests
+## tests with hierarchical normal model

 # generate data
-N = 100 # Number of data points
+N = 1000 # Number of data points
 mu_true = 0.5 # True mean
 tau2_true = 2.0 # True variance

@@ -105,8 +108,6 @@ x_data = rand(Normal(mu_true, sqrt(tau2_true)), N)
 # Store the generated data in the HierNormal structure
 hn = HierNormal((x=x_data,))

-##
-
 samples = sample(
     hn,
     Gibbs(
@@ -115,43 +116,46 @@ samples = sample(
             (:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])),
         ),
     ),
-    100000;
+    200000;
     initial_params=(mu=[0.0], tau2=[1.0]),
 )

-mu_samples = [sample.values.mu for sample in samples][20001:end]
-tau2_samples = [sample.values.tau2 for sample in samples][20001:end]
-
-mean(mu_samples)
-mean(tau2_samples)
-
-##
-
-# this is too difficult of a problem
-
-gmm = GMM((; x=x))
-
-samples = sample(
-    gmm,
-    Gibbs(
-        OrderedDict(
-            (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
-            (:w,) => PriorMH(Dirichlet(2, 1.0)),
-            (:μ,) => RWMH(1),
-        ),
-    ),
-    100000;
-    initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]),
-);
+mu_samples = [sample.values.mu for sample in samples][40001:end]
+tau2_samples = [sample.values.tau2 for sample in samples][40001:end]

-z_samples = [sample.values.z for sample in samples][20001:end]
-μ_samples = [sample.values.μ for sample in samples][20001:end]
-w_samples = [sample.values.w for sample in samples][20001:end];
+mu_mean = mean(mu_samples)[1]
+tau2_mean = mean(tau2_samples)[1]

-# thin these samples
-z_samples = z_samples[1:100:end]
-μ_samples = μ_samples[1:100:end]
-w_samples = w_samples[1:100:end];
+@testset "hierarchical normal with gibbs" begin
+    @test mu_mean ≈ mu_true atol = 0.1
+    @test tau2_mean ≈ tau2_true atol = 0.3
+end

-mean(μ_samples)
-mean(w_samples)
+## test with gmm -- too hard, doesn't converge
+
+# gmm = GMM((; x=x))
+
+# samples = sample(
+#     gmm,
+#     Gibbs(
+#         OrderedDict(
+#             (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
+#             (:w,) => PriorMH(Dirichlet(2, 1.0)),
+#             (:μ,) => RWMH(1),
+#         ),
+#     ),
+#     100000;
+#     initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]),
+# );
+
+# z_samples = [sample.values.z for sample in samples][20001:end]
+# μ_samples = [sample.values.μ for sample in samples][20001:end]
+# w_samples = [sample.values.w for sample in samples][20001:end];
+
+# # thin these samples
+# z_samples = z_samples[1:100:end]
+# μ_samples = μ_samples[1:100:end]
+# w_samples = w_samples[1:100:end];
+
+# mean(μ_samples)
+# mean(w_samples)
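The revised `step` above implements one sweep of a Gibbs-style scheme: for each group, the joint density is conditioned on the current values of every other group, and that group's sub-sampler advances one step. A self-contained miniature of the same control flow, with a toy two-group target and random-walk MH for both updates (all names here are illustrative; the real code routes this through sampler states, `unflatten`, and `condition`):

```julia
using Distributions, Random

# Two parameter groups, each updated by a random-walk MH step while the
# other is held at its current value -- the same sweep structure as
# AbstractMCMC.step above, minus the state bookkeeping.
function gibbs_mh_sweeps(rng, logjoint, mu, tau2; iters=10_000, sigma=0.5)
    for _ in 1:iters
        # Update mu conditional on the current tau2.
        mu_prop = mu + sigma * randn(rng)
        if log(rand(rng)) < logjoint(mu_prop, tau2) - logjoint(mu, tau2)
            mu = mu_prop
        end
        # Update tau2 conditional on the current mu.
        tau2_prop = tau2 + sigma * randn(rng)
        if log(rand(rng)) < logjoint(mu, tau2_prop) - logjoint(mu, tau2)
            tau2 = tau2_prop
        end
    end
    return (mu=mu, tau2=tau2)
end

# Toy joint density; -Inf rejects proposals outside tau2's support.
logjoint(mu, tau2) =
    tau2 > 0 ? logpdf(Normal(0, 1), mu) + logpdf(InverseGamma(1, 1), tau2) : -Inf

gibbs_mh_sweeps(MersenneTwister(1), logjoint, 0.0, 1.0)
```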

test/gibbs_example/gmm.jl

Lines changed: 43 additions & 45 deletions
@@ -1,5 +1,3 @@
-using LogDensityProblems
-
 abstract type AbstractGMM end

 struct GMM <: AbstractGMM
@@ -81,65 +79,65 @@ function unflatten(vec::AbstractVector, group::Tuple)
 end

 function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
-    return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals))
+    return set_logp!!(state, LogDensityProblems.logdensity(gmm, vals))
 end

 ## test using Turing

-# data generation
+# # data generation

-using FillArrays
+# using FillArrays

-w = [0.5, 0.5]
-μ = [-3.5, 0.5]
-mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)
+# w = [0.5, 0.5]
+# μ = [-3.5, 0.5]
+# mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)

-N = 60
-x = rand(mixturemodel, N);
+# N = 60
+# x = rand(mixturemodel, N);

-# Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/
-using Turing
+# # Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/
+# using Turing

-@model function gaussian_mixture_model(x)
-    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
-    K = 2
-    μ ~ MvNormal(Zeros(K), I)
+# @model function gaussian_mixture_model(x)
+#     # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
+#     K = 2
+#     μ ~ MvNormal(Zeros(K), I)

-    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
-    w ~ Dirichlet(K, 1.0)
-    # Alternatively, one could use a fixed set of weights.
-    # w = fill(1/K, K)
+#     # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
+#     w ~ Dirichlet(K, 1.0)
+#     # Alternatively, one could use a fixed set of weights.
+#     # w = fill(1/K, K)

-    # Construct categorical distribution of assignments.
-    distribution_assignments = Categorical(w)
+#     # Construct categorical distribution of assignments.
+#     distribution_assignments = Categorical(w)

-    # Construct multivariate normal distributions of each cluster.
-    D, N = size(x)
-    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
+#     # Construct multivariate normal distributions of each cluster.
+#     D, N = size(x)
+#     distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

-    # Draw assignments for each datum and generate it from the multivariate normal distribution.
-    k = Vector{Int}(undef, N)
-    for i in 1:N
-        k[i] ~ distribution_assignments
-        x[:, i] ~ distribution_clusters[k[i]]
-    end
+#     # Draw assignments for each datum and generate it from the multivariate normal distribution.
+#     k = Vector{Int}(undef, N)
+#     for i in 1:N
+#         k[i] ~ distribution_assignments
+#         x[:, i] ~ distribution_clusters[k[i]]
+#     end

-    return μ, w, k, __varinfo__
-end
+#     return μ, w, k, __varinfo__
+# end

-model = gaussian_mixture_model(x);
+# model = gaussian_mixture_model(x);

-using Test
-# full model
-μ, w, k, vi = model()
-@test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi)
+# using Test
+# # full model
+# μ, w, k, vi = model()
+# @test log_joint(; μ=μ, w=w, z=k, x=x) ≈ DynamicPPL.getlogp(vi)

-gmm = GMM((; x=x))
+# gmm = GMM((; x=x))

-# cond model on μ, w
-μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))()
-@test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi)
+# # cond model on μ, w
+# μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))()
+# @test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) ≈ DynamicPPL.getlogp(vi)

-# cond model on z
-μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))()
-@test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi)
+# # cond model on z
+# μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))()
+# @test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) ≈ DynamicPPL.getlogp(vi)
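The `recompute_logprob!!` method above evaluates the conditional density through the LogDensityProblems interface, whose core contract is `logdensity`, `dimension`, and `capabilities`. A minimal conforming target, for readers unfamiliar with that interface (the `ToyTarget` type is illustrative, not part of this commit):

```julia
using Distributions, LogDensityProblems

# Smallest possible LogDensityProblems target: a 1-D standard normal.
struct ToyTarget end

LogDensityProblems.logdensity(::ToyTarget, x) = logpdf(Normal(), x[1])
LogDensityProblems.dimension(::ToyTarget) = 1
function LogDensityProblems.capabilities(::Type{ToyTarget})
    return LogDensityProblems.LogDensityOrder{0}()  # density only, no gradient
end

LogDensityProblems.logdensity(ToyTarget(), [0.0])  # ≈ -0.9189
```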
