Skip to content

Commit 7d0ba7c

Browse files
committed
add hierarchical normal problem
1 parent f758a4c commit 7d0ba7c

File tree

2 files changed

+115
-2
lines changed

2 files changed

+115
-2
lines changed

gibbs_example/gibbs.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,41 @@ end
9292

9393
## tests
9494

95+
# generate data
96+
N = 100 # Number of data points
97+
mu_true = 0.5 # True mean
98+
tau2_true = 2.0 # True variance
99+
100+
# Generate data based on true parameters
101+
x_data = rand(Normal(mu_true, sqrt(tau2_true)), N)
102+
103+
# Store the generated data in the HierNormal structure
104+
hn = HierNormal((x=x_data,))
105+
106+
##
107+
108+
samples = sample(
109+
hn,
110+
Gibbs(
111+
OrderedDict(
112+
(:mu,) => RWMH(1),
113+
(:tau2,) => PriorMH(product_distribution([InverseGamma(1, 1)])),
114+
),
115+
),
116+
100000;
117+
initial_params=(mu=[0.0], tau2=[1.0]),
118+
)
119+
120+
mu_samples = [sample.values.mu for sample in samples][20001:end]
121+
tau2_samples = [sample.values.tau2 for sample in samples][20001:end]
122+
123+
mean(mu_samples)
124+
mean(tau2_samples)
125+
126+
##
127+
128+
# this is too difficult of a problem
129+
95130
gmm = GMM((; x=x))
96131

97132
samples = sample(
@@ -104,12 +139,17 @@ samples = sample(
104139
),
105140
),
106141
100000;
107-
initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]),
142+
initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]),
108143
);
109144

110145
z_samples = [sample.values.z for sample in samples][20001:end]
111146
μ_samples = [sample.values.μ for sample in samples][20001:end]
112-
w_samples = [sample.values.w for sample in samples][20001:end]
147+
w_samples = [sample.values.w for sample in samples][20001:end];
148+
149+
# thin these samples
150+
z_samples = z_samples[1:100:end]
151+
μ_samples = μ_samples[1:100:end]
152+
w_samples = w_samples[1:100:end];
113153

114154
mean(μ_samples)
115155
mean(w_samples)

gibbs_example/hier_normal.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using LogDensityProblems
2+
3+
abstract type AbstractHierNormal end
4+
5+
struct HierNormal <: AbstractHierNormal
6+
data::NamedTuple
7+
end
8+
9+
struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal
10+
data::NamedTuple
11+
conditioned_values::NamedTuple{conditioned_vars}
12+
end
13+
14+
function log_joint(; mu, tau2, x)
15+
# mu is the mean
16+
# tau2 is the variance
17+
# x is data
18+
19+
# μ ~ Normal(0, 1)
20+
# τ² ~ InverseGamma(1, 1)
21+
# xᵢ ~ Normal(μ, √τ²)
22+
23+
logp = 0.0
24+
mu = only(mu)
25+
tau2 = only(tau2)
26+
27+
mu_prior = Normal(0, 1)
28+
logp += logpdf(mu_prior, mu)
29+
30+
tau2_prior = InverseGamma(1, 1)
31+
logp += logpdf(tau2_prior, tau2)
32+
33+
obs_prior = Normal(mu, sqrt(tau2))
34+
logp += sum(logpdf(obs_prior, xi) for xi in x)
35+
36+
return logp
37+
end
38+
39+
function condition(hn::HierNormal, conditioned_values::NamedTuple)
40+
return ConditionedHierNormal(hn.data, conditioned_values)
41+
end
42+
43+
function LogDensityProblems.logdensity(
44+
hn::ConditionedHierNormal{names}, params::AbstractVector
45+
) where {names}
46+
if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2
47+
return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x)
48+
elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu
49+
return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x)
50+
else
51+
error("Unsupported conditioning configuration.")
52+
end
53+
end
54+
55+
function LogDensityProblems.capabilities(::HierNormal)
56+
return LogDensityProblems.LogDensityOrder{0}()
57+
end
58+
59+
function LogDensityProblems.capabilities(::ConditionedHierNormal)
60+
return LogDensityProblems.LogDensityOrder{0}()
61+
end
62+
63+
function flatten(nt::NamedTuple)
64+
return only(values(nt))
65+
end
66+
67+
function unflatten(vec::AbstractVector, group::Tuple)
68+
return NamedTuple((only(group) => vec,))
69+
end
70+
71+
function recompute_logprob!!(hn::ConditionedHierNormal, vals, state)
72+
return setlogp!!(state, LogDensityProblems.logdensity(hn, vals))
73+
end

0 commit comments

Comments
 (0)