Skip to content

Commit ef49c73

Browse files
committed
Remove ProposalStyle and forward function proposals
1 parent 02a6cb9 commit ef49c73

File tree

4 files changed

+50
-86
lines changed

4 files changed

+50
-86
lines changed

src/AdvancedMH.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,14 @@ using Requires
88
using Random
99

1010
# Exports
11-
export MetropolisHastings, DensityModel, RWMH, StaticMH, Proposal, Static, RandomWalk
11+
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal, RandomWalkProposal
1212

1313
# Reexports
1414
using AbstractMCMC: sample, psample
1515
export sample, psample
1616

1717
# Abstract type for MH-style samplers.
1818
abstract type Metropolis <: AbstractMCMC.AbstractSampler end
19-
abstract type ProposalStyle end
20-
21-
struct RandomWalk <: ProposalStyle end
22-
struct Static <: ProposalStyle end
2319

2420
# Define a model type. Stores the log density function and the data to
2521
# evaluate the log density on.

src/mh-core.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ x = (a = 1.0, b=3.8)
1212
The proposal would be
1313
1414
```julia
15-
proposal = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), Normal(0,1)))
15+
proposal = (a=StaticProposal(Normal(0,1)), b=StaticProposal(Normal(0,1)))
1616
````
1717
18-
Other allowed proposal styles are
18+
Other allowed proposals are
1919
2020
```
21-
p1 = Proposal(Static(), Normal(0,1))
22-
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
23-
p3 = Proposal(Static(), (a=Normal(0,1), b=InverseGamma(2,3)))
24-
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
21+
p1 = StaticProposal(Normal(0,1))
22+
p2 = StaticProposal([Normal(0,1), InverseGamma(2,3)])
23+
p3 = StaticProposal(a=Normal(0,1), b=InverseGamma(2,3))
24+
p4 = StaticProposal((x=1.0) -> Normal(x, 1))
2525
```
2626
2727
The sampler is constructed using
@@ -45,8 +45,8 @@ mutable struct MetropolisHastings{D} <: Metropolis
4545
proposal :: D
4646
end
4747

48-
StaticMH(d) = MetropolisHastings(Proposal(Static(), d))
49-
RWMH(d) = MetropolisHastings(Proposal(RandomWalk(), d))
48+
StaticMH(d) = MetropolisHastings(StaticProposal(d))
49+
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
5050

5151
# Propose from a vector of proposals
5252
function propose(

src/proposal.jl

Lines changed: 37 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,131 +1,99 @@
1-
struct Proposal{T<:ProposalStyle, P}
2-
type :: T
3-
proposal :: P
1+
abstract type Proposal{P} end
2+
3+
struct StaticProposal{P} <: Proposal{P}
4+
proposal::P
45
end
56

6-
# Random draws
7-
Base.rand(p::Proposal{<:ProposalStyle, <:Distribution}) = rand(p.proposal)
8-
function Base.rand(p::Proposal{<:ProposalStyle, <:AbstractArray})
9-
return map(rand, p.proposal)
7+
struct RandomWalkProposal{P} <: Proposal{P}
8+
proposal::P
109
end
1110

11+
# Random draws
12+
Base.rand(p::Proposal{<:Distribution}) = rand(p.proposal)
13+
Base.rand(p::Proposal{<:AbstractArray}) = map(rand, p.proposal)
14+
1215
# Densities
13-
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:UnivariateDistribution}, v)
14-
return sum(logpdf(p.proposal, v))
15-
end
16-
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:MultivariateDistribution}, v)
17-
return sum(logpdf(p.proposal, v))
18-
end
19-
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:MatrixDistribution}, v)
20-
return sum(logpdf(p.proposal, v))
21-
end
22-
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:AbstractArray}, v)
23-
return sum(map(x -> logpdf(x[1], x[2]), zip(p.proposal, v)))
24-
end
25-
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:Function}, v)
26-
return logpdf(p.proposal(v), v)
16+
Distributions.logpdf(p::Proposal{<:Distribution}, v) = logpdf(p.proposal, v)
17+
function Distributions.logpdf(p::Proposal{<:AbstractArray}, v)
18+
# `mapreduce` with multiple iterators requires Julia 1.2 or later
19+
return mapreduce(((pi, vi),) -> logpdf(pi, vi), +, zip(p.proposal, v))
2720
end
2821

2922
###############
3023
# Random Walk #
3124
###############
3225

33-
function propose(
34-
proposal::Proposal{RandomWalk, <:Distribution},
35-
model::DensityModel,
36-
t
37-
)
38-
return t + rand(proposal)
26+
function propose(p::RandomWalkProposal, m::DensityModel)
27+
return propose(StaticProposal(p.proposal), m)
3928
end
4029

4130
function propose(
42-
proposal::Proposal{RandomWalk, <:AbstractArray},
31+
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
4332
model::DensityModel,
4433
t
4534
)
4635
return t + rand(proposal)
4736
end
4837

4938
function q(
50-
proposal::Proposal{RandomWalk, <:Distribution},
39+
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
5140
t,
5241
t_cond
5342
)
54-
return sum(logpdf(proposal, t - t_cond))
55-
end
56-
57-
function q(
58-
proposal::Proposal{RandomWalk, <:AbstractArray},
59-
t,
60-
t_cond
61-
)
62-
return sum(logpdf(proposal, t - t_cond))
43+
return logpdf(proposal, t - t_cond)
6344
end
6445

6546
##########
6647
# Static #
6748
##########
6849

69-
propose(p::Proposal{RandomWalk}, m::DensityModel) = propose(Proposal(Static(), p.proposal), m)
7050
function propose(
71-
proposal::Proposal{Static, <:Distribution},
72-
model::DensityModel,
51+
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
52+
model::DensityModel,
7353
t=nothing
7454
)
7555
return rand(proposal)
7656
end
7757

78-
function propose(
79-
p::Proposal{Static, <:AbstractArray},
80-
model::DensityModel,
81-
t=nothing
82-
)
83-
props = map(x -> rand(x), p.proposal)
84-
return props
85-
end
86-
8758
function q(
88-
proposal::Proposal{Static, <:Distribution},
59+
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
8960
t,
9061
t_cond
9162
)
92-
return sum(logpdf(proposal, t))
93-
end
94-
95-
function q(
96-
proposal::Proposal{Static, <:AbstractArray},
97-
t,
98-
t_cond
99-
)
100-
return sum(logpdf(proposal, t))
63+
return logpdf(proposal, t)
10164
end
10265

10366
############
10467
# Function #
10568
############
10669

70+
# function definition with abstract types requires Julia 1.3 or later
71+
for T in (StaticProposal, RandomWalkProposal)
72+
@eval begin
73+
(p::$T{<:Function})() = $T(p.proposal())
74+
(p::$T{<:Function})(t) = $T(p.proposal(t))
75+
end
76+
end
77+
10778
function propose(
108-
proposal::Proposal{<:ProposalStyle, <:Function},
79+
proposal::Proposal{<:Function},
10980
model::DensityModel
11081
)
111-
p = proposal.proposal
112-
return rand(proposal.proposal())
82+
return propose(proposal(), model)
11383
end
11484

11585
function propose(
116-
proposal::Proposal{<:ProposalStyle, <:Function},
86+
proposal::Proposal{<:Function},
11787
model::DensityModel,
11888
t
11989
)
120-
p = proposal.proposal
121-
return rand(proposal.proposal(t))
90+
return propose(proposal(t), model)
12291
end
12392

12493
function q(
125-
proposal::Proposal{<:ProposalStyle, <:Function},
94+
proposal::Proposal{<:Function},
12695
t,
12796
t_cond
12897
)
129-
p = proposal.proposal
130-
return sum(logpdf.(p(t_cond), t))
98+
return q(proposal(t_cond), t, t_cond)
13199
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ using Test
6868
m3 = DensityModel(x -> logpdf(Normal(x.a, x.b), 1.0))
6969
m4 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
7070

71-
p1 = Proposal(Static(), Normal(0,1))
72-
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
73-
p3 = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), InverseGamma(2,3)))
74-
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
71+
p1 = StaticProposal(Normal(0,1))
72+
p2 = StaticProposal([Normal(0,1), InverseGamma(2,3)])
73+
p3 = (a=StaticProposal(Normal(0,1)), b=StaticProposal(InverseGamma(2,3)))
74+
p4 = StaticProposal((x=1.0) -> Normal(x, 1))
7575

7676
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
7777
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})

0 commit comments

Comments
 (0)