Skip to content

Commit 7b80078

Browse files
committed
Remove unicode and boilerplate code
1 parent a16fa77 commit 7b80078

File tree

3 files changed

+82
-49
lines changed

3 files changed

+82
-49
lines changed

src/mh-core.jl

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
Fields:
55
6-
- `init_θ` is the vector form of the parameters needed for the likelihood function.
6+
- `init_params` is the vector form of the parameters needed for the likelihood function.
77
- `proposal` is a function that dynamically constructs a conditional distribution.
88
99
Example:
@@ -12,31 +12,12 @@ Example:
1212
MetropolisHastings([0.0, 0.0], x -> MvNormal(x, 1.0))
1313
````
1414
"""
15-
struct MetropolisHastings{P<:ProposalStyle, T, D} <: Metropolis
15+
mutable struct MetropolisHastings{P<:ProposalStyle, D, T} <: Metropolis
1616
proposal_type :: P
17-
init_θ :: T
1817
proposal :: D
18+
init_params :: T
1919
end
2020

21-
# Default constructors.
22-
MetropolisHastings(init_θ::Real) = MetropolisHastings(init_θ, Normal(0,1))
23-
MetropolisHastings(init_θ::Vector{<:Real}) = MetropolisHastings(init_θ, MvNormal(length(init_θ),1))
24-
25-
"""
26-
propose(spl::MetropolisHastings, model::DensityModel, t::Transition)
27-
28-
Generates a new parameter proposal conditional on the model, the sampler, and the previous
29-
sample.
30-
"""
31-
@inline propose(spl::MetropolisHastings, model::DensityModel, t::Transition) = propose(spl, model, t.θ)
32-
33-
"""
34-
q(spl::MetropolisHastings, t1::Transition, t2::Transition)
35-
36-
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
37-
"""
38-
@inline q(spl::MetropolisHastings, t1::Transition, t2::Transition) = q(spl, t1.θ, t2.θ)
39-
4021
# Define the first step! function, which is called at the
4122
# beginning of sampling. Return the initial parameter used
4223
# to define the sampler.
@@ -47,7 +28,7 @@ function step!(
4728
N::Integer;
4829
kwargs...
4930
)
50-
return Transition(model, spl.init_θ)
31+
return Transition(model, spl.init_params)
5132
end
5233

5334
# Define the other step functions. Returns a Transition containing
@@ -58,19 +39,20 @@ function step!(
5839
model::DensityModel,
5940
spl::MetropolisHastings,
6041
::Integer,
61-
θ_prev::Transition;
42+
params_prev::Transition;
6243
kwargs...
6344
)
6445
# Generate a new proposal.
65-
θ = propose(spl, model, θ_prev)
46+
params = propose(spl, model, params_prev)
6647

6748
# Calculate the log acceptance probability.
68-
α = ℓπ(model, θ) - ℓπ(model, θ_prev) + q(spl, θ_prev, θ) - q(spl, θ, θ_prev)
49+
α = logdensity(model, params) - logdensity(model, params_prev) +
50+
q(spl, params_prev, params) - q(spl, params, params_prev)
6951

70-
# Decide whether to return the previous θ or the new one.
52+
# Decide whether to return the previous params or the new one.
7153
if log(rand(rng)) < min(α, 0.0)
72-
return θ
54+
return params
7355
else
74-
return θ_prev
56+
return params_prev
7557
end
7658
end

src/rwmh.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Random walk Metropolis-Hastings.
88
99
Fields:
1010
11-
- `init_θ` is the vector form of the parameters needed for the likelihood function.
11+
- `init_params` is the vector form of the parameters needed for the likelihood function.
1212
- `proposal` is a function that dynamically constructs a conditional distribution.
1313
1414
Example:
@@ -17,20 +17,33 @@ Example:
1717
RWMH([0.0, 0.0], x -> MvNormal(x, 1.0))
1818
````
1919
"""
20-
RWMH(init_theta::Real, proposal = Normal(init_theta, 1)) = MetropolisHastings(RandomWalk(), init_theta, proposal)
21-
RWMH(init_theta::Vector{<:Real}, proposal = MvNormal(init_theta, 1)) = MetropolisHastings(RandomWalk(), init_theta, proposal)
20+
RWMH(init_theta, proposal) = MetropolisHastings(RandomWalk(), proposal, init_theta)
21+
function RWMH(init_theta::Vector, proposal)
22+
if proposal isa Vector
23+
# Verify that there are proposal distributions for each parameter.
24+
length(proposal) == length(init_theta) ||
25+
throw("The number of proposal distributions must match the number of parameters.")
26+
end
27+
28+
return MetropolisHastings(RandomWalk(), proposal, init_theta)
29+
end
2230

2331
# Define a function that makes a basic proposal depending on a univariate
2432
# parameterization or a multivariate parameterization.
25-
propose(spl::MetropolisHastings{RandomWalk}, model::DensityModel, θ::Real) = Transition(model, θ + rand(spl.proposal))
26-
propose(spl::MetropolisHastings{RandomWalk}, model::DensityModel, θ::Vector{<:Real}) = Transition(model, θ + rand(spl.proposal))
33+
propose(spl::MetropolisHastings{RandomWalk, <:Distribution}, model::DensityModel, t::Transition) = Transition(model, t.params + rand(spl.proposal))
34+
function propose(spl::MetropolisHastings{RandomWalk, <:AbstractArray}, model::DensityModel, t::Transition)
35+
props = map(x -> x[2] + rand(x[1]), zip(spl.proposal, t.params))
36+
return Transition(model, props)
37+
end
2738

2839
"""
29-
q(θ::Real, dist::Sampleable)
30-
q(θ::Vector{<:Real}, dist::Sampleable)
40+
q(params::Real, dist::Sampleable)
41+
q(params::Vector{<:Real}, dist::Sampleable)
3142
q(t1::Transition, dist::Sampleable)
3243
33-
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
44+
Calculates the probability `q(params | paramscond)`, using the proposal distribution `spl.proposal`.
3445
"""
35-
@inline q(spl::MetropolisHastings{RandomWalk}, θ::Real, θcond::Real) = logpdf(spl.proposal, θ - θcond)
36-
@inline q(spl::MetropolisHastings{RandomWalk}, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf(spl.proposal, θ - θcond)
46+
q(spl::MetropolisHastings{RandomWalk, <:Distribution}, t::Transition, t_cond::Transition) = logpdf(spl.proposal, t.params - t_cond.params)
47+
function q(spl::MetropolisHastings{RandomWalk, <:AbstractArray}, t::Transition, t_cond::Transition)
48+
return sum(map(x -> logpdf(x[1], x[2] - x[3]), zip(spl.proposal, t.params, t_cond.params)))
49+
end

src/staticmh.jl

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Static Metropolis-Hastings. Proposes only from the prior distribution.
88
99
Fields:
1010
11-
- `init_θ` is the vector form of the parameters needed for the likelihood function.
11+
- `init_params` is the vector form of the parameters needed for the likelihood function.
1212
- `proposal` is a distribution.
1313
1414
Example:
@@ -17,21 +17,59 @@ Example:
1717
RWMH([0.0, 0.0], MvNormal(x, 1.0))
1818
````
1919
"""
20-
StaticMH(init_theta::Real, proposal = Normal(init_theta, 1)) = MetropolisHastings(Static(), init_theta, proposal)
21-
StaticMH(init_theta::Vector{<:Real}, proposal = MvNormal(init_theta, 1)) = MetropolisHastings(Static(), init_theta, proposal)
20+
StaticMH(init_theta, proposal) = MetropolisHastings(Static(), proposal, init_theta)
21+
function StaticMH(init_theta::Vector, proposal)
22+
if proposal isa Vector
23+
# Verify that there are proposal distributions for each parameter.
24+
length(proposal) == length(init_theta) ||
25+
throw("The number of proposal distributions must match the number of parameters.")
26+
end
27+
28+
return MetropolisHastings(Static(), proposal, init_theta)
29+
end
2230

2331
# Define a function that makes a basic proposal depending on a univariate
2432
# parameterization or a multivariate parameterization.
25-
propose(spl::MetropolisHastings{Static}, model::DensityModel, θ::Real) = Transition(model, rand(spl.proposal))
26-
propose(spl::MetropolisHastings{Static}, model::DensityModel, θ::Vector{<:Real}) = Transition(model, rand(spl.proposal))
33+
propose(spl::MetropolisHastings{Static, <:Distribution}, model::DensityModel, params::Transition) = Transition(model, rand(spl.proposal))
34+
function propose(spl::MetropolisHastings{Static, <:AbstractArray}, model::DensityModel, params::Transition)
35+
props = map(rand, spl.proposal)
36+
return Transition(model, props)
37+
end
38+
function propose(spl::MetropolisHastings{Static, <:NamedTuple}, model::DensityModel, params::Transition)
39+
return Transition(model, _propose(spl.proposal))
40+
end
41+
@generated function _propose(proposals::NamedTuple{names}) where {names}
42+
expr = Expr(:tuple)
43+
map(names) do f
44+
push!(expr.args, Expr(:(=), f, :(rand(proposals.$f)) ))
45+
end
46+
return expr
47+
end
2748

2849
"""
29-
q(θ::Real, dist::Sampleable)
30-
q(θ::Vector{<:Real}, dist::Sampleable)
50+
q(params::Real, dist::Sampleable)
51+
q(params::Vector{<:Real}, dist::Sampleable)
3152
q(t1::Transition, dist::Sampleable)
3253
33-
Calculates the probability `q(θ | θcond)`, using the proposal distribution `spl.proposal`.
54+
Calculates the probability `q(params | paramscond)`, using the proposal distribution `spl.proposal`.
3455
"""
35-
q(spl::MetropolisHastings{Static}, θ::Real, θcond::Real) = logpdf(spl.proposal, θ)
36-
q(spl::MetropolisHastings{Static}, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf(spl.proposal, θ)
56+
function q(spl::MetropolisHastings{Static, <:Distribution}, t::Transition, t_cond::Transition)
57+
return logpdf(spl.proposal, t.params)
58+
end
59+
60+
function q(spl::MetropolisHastings{Static, <:AbstractArray}, t::Transition, t_cond::Transition)
61+
return sum(map(x -> logpdf(x[1], x[2]), zip(spl.proposal, t.params)))
62+
end
63+
64+
function q(spl::MetropolisHastings{Static, <:NamedTuple}, t::Transition, t_cond::Transition)
65+
total = 0.0
66+
for p in keys(t.params)
67+
if length(t.params[p]) == 1
68+
total += logpdf(spl.proposal[p], t.params[p][1])
69+
else
70+
total += logpdf(spl.proposal[p], t.params[p])
71+
end
72+
end
73+
return total
74+
end
3775

0 commit comments

Comments
 (0)