|
| 1 | +using LogDensityProblems |
| 2 | + |
| 3 | +abstract type AbstractHierNormal end |
| 4 | + |
| 5 | +struct HierNormal <: AbstractHierNormal |
| 6 | + data::NamedTuple |
| 7 | +end |
| 8 | + |
| 9 | +struct ConditionedHierNormal{conditioned_vars} <: AbstractHierNormal |
| 10 | + data::NamedTuple |
| 11 | + conditioned_values::NamedTuple{conditioned_vars} |
| 12 | +end |
| 13 | + |
| 14 | +function log_joint(; mu, tau2, x) |
| 15 | + # mu is the mean |
| 16 | + # tau2 is the variance |
| 17 | + # x is data |
| 18 | + |
| 19 | + # μ ~ Normal(0, 1) |
| 20 | + # τ² ~ InverseGamma(1, 1) |
| 21 | + # xᵢ ~ Normal(μ, √τ²) |
| 22 | + |
| 23 | + logp = 0.0 |
| 24 | + mu = only(mu) |
| 25 | + tau2 = only(tau2) |
| 26 | + |
| 27 | + mu_prior = Normal(0, 1) |
| 28 | + logp += logpdf(mu_prior, mu) |
| 29 | + |
| 30 | + tau2_prior = InverseGamma(1, 1) |
| 31 | + logp += logpdf(tau2_prior, tau2) |
| 32 | + |
| 33 | + obs_prior = Normal(mu, sqrt(tau2)) |
| 34 | + logp += sum(logpdf(obs_prior, xi) for xi in x) |
| 35 | + |
| 36 | + return logp |
| 37 | +end |
| 38 | + |
| 39 | +function condition(hn::HierNormal, conditioned_values::NamedTuple) |
| 40 | + return ConditionedHierNormal(hn.data, conditioned_values) |
| 41 | +end |
| 42 | + |
| 43 | +function LogDensityProblems.logdensity( |
| 44 | + hn::ConditionedHierNormal{names}, params::AbstractVector |
| 45 | +) where {names} |
| 46 | + if Set(names) == Set([:mu]) # conditioned on mu, so params are tau2 |
| 47 | + return log_joint(; mu=hn.conditioned_values.mu, tau2=params, x=hn.data.x) |
| 48 | + elseif Set(names) == Set([:tau2]) # conditioned on tau2, so params are mu |
| 49 | + return log_joint(; mu=params, tau2=hn.conditioned_values.tau2, x=hn.data.x) |
| 50 | + else |
| 51 | + error("Unsupported conditioning configuration.") |
| 52 | + end |
| 53 | +end |
| 54 | + |
| 55 | +function LogDensityProblems.capabilities(::HierNormal) |
| 56 | + return LogDensityProblems.LogDensityOrder{0}() |
| 57 | +end |
| 58 | + |
| 59 | +function LogDensityProblems.capabilities(::ConditionedHierNormal) |
| 60 | + return LogDensityProblems.LogDensityOrder{0}() |
| 61 | +end |
| 62 | + |
| 63 | +function flatten(nt::NamedTuple) |
| 64 | + return only(values(nt)) |
| 65 | +end |
| 66 | + |
| 67 | +function unflatten(vec::AbstractVector, group::Tuple) |
| 68 | + return NamedTuple((only(group) => vec,)) |
| 69 | +end |
| 70 | + |
| 71 | +function recompute_logprob!!(hn::ConditionedHierNormal, vals, state) |
| 72 | + return setlogp!!(state, LogDensityProblems.logdensity(hn, vals)) |
| 73 | +end |
0 commit comments