Skip to content

Commit 13e1228

Browse files
committed
Add initial parameter keyword
1 parent 728c7b5 commit 13e1228

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.3.0"
3+
version = "0.3.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mh-core.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,14 @@ function step!(
155155
model::DensityModel,
156156
spl::MetropolisHastings,
157157
N::Integer;
158+
init_params=nothing,
158159
kwargs...
159160
)
160-
return propose(spl, model)
161+
if init_params === nothing
162+
return propose(spl, model)
163+
else
164+
return Transition(model, init_params)
165+
end
161166
end
162167

163168
# Define the other step functions. Returns a Transition containing

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,17 @@ using MCMCChains
7878
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
7979
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
8080
end
81+
82+
@testset "Initial parameters" begin
83+
# Set up our sampler with initial parameters.
84+
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
85+
86+
val = [0.4, 1.2]
87+
88+
# Sample from the posterior.
89+
chain1 = sample(model, spl1, 10, init_params = val)
90+
91+
@test chain1[1].params == val
92+
end
8193
end
8294

0 commit comments

Comments
 (0)