Skip to content

Commit 7a77bb2

Browse files
authored
Merge pull request #13 from TuringLang/csp/init_params
Allow parameter initialization
2 parents 728c7b5 + 5189c38 commit 7a77bb2

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.3.0"
3+
version = "0.3.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mh-core.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ The sampler is constructed using
2929
```julia
3030
spl = MetropolisHastings(proposal)
3131
```
32+
33+
When using `MetropolisHastings` with the function `sample`, the following keyword
34+
arguments are allowed:
35+
36+
- `init_params` defines the initial parameterization for your model. If
37+
none is given, the initial parameters will be drawn from the sampler's proposals.
38+
- `param_names` is a vector of strings to be assigned to parameters. This is only
39+
used if `chain_type=Chains`.
40+
- `chain_type` is the type of chain you would like returned to you. Supported
41+
types are `chain_type=Chains` if `MCMCChains` is imported, or
42+
`chain_type=StructArray` if `StructArrays` is imported.
3243
"""
3344
mutable struct MetropolisHastings{D} <: Metropolis
3445
proposal :: D
@@ -155,9 +166,14 @@ function step!(
155166
model::DensityModel,
156167
spl::MetropolisHastings,
157168
N::Integer;
169+
init_params=nothing,
158170
kwargs...
159171
)
160-
return propose(spl, model)
172+
if init_params === nothing
173+
return propose(spl, model)
174+
else
175+
return Transition(model, init_params)
176+
end
161177
end
162178

163179
# Define the other step functions. Returns a Transition containing

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,17 @@ using MCMCChains
7878
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
7979
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
8080
end
81+
82+
@testset "Initial parameters" begin
83+
# Set up our sampler with initial parameters.
84+
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
85+
86+
val = [0.4, 1.2]
87+
88+
# Sample from the posterior.
89+
chain1 = sample(model, spl1, 10, init_params = val)
90+
91+
@test chain1[1].params == val
92+
end
8193
end
8294

0 commit comments

Comments
 (0)