Skip to content

Commit c0ea6e0

Browse files
committed
Specify default chain type in Turing
1 parent aa3cfcf commit c0ea6e0

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

src/mcmc/Inference.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ export InferenceAlgorithm,
8080
# Abstract interface for inference algorithms #
8181
###############################################
8282

83+
const TURING_CHAIN_TYPE = MCMCChains.Chains
84+
8385
include("algorithm.jl")
8486

8587
####################

src/mcmc/algorithm.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ this wrapping occurs automatically.
1111
"""
1212
abstract type InferenceAlgorithm end
1313

14-
DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains
15-
1614
function DynamicPPL.init_strategy(sampler::Sampler{<:InferenceAlgorithm})
1715
return DynamicPPL.InitFromPrior()
1816
end

src/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function AbstractMCMC.sample(
8888
model::DynamicPPL.Model,
8989
sampler::Sampler{<:AdaptiveHamiltonian},
9090
N::Integer;
91-
chain_type=DynamicPPL.default_chain_type(sampler),
91+
chain_type=TURING_CHAIN_TYPE,
9292
resume_from=nothing,
9393
initial_params=DynamicPPL.init_strategy(sampler),
9494
initial_state=DynamicPPL.loadstate(resume_from),

src/mcmc/particle_mcmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function AbstractMCMC.sample(
142142
model::DynamicPPL.Model,
143143
sampler::Sampler{<:SMC},
144144
N::Integer;
145-
chain_type=DynamicPPL.default_chain_type(sampler),
145+
chain_type=TURING_CHAIN_TYPE,
146146
resume_from=nothing,
147147
initial_params=DynamicPPL.init_strategy(sampler),
148148
initial_state=DynamicPPL.loadstate(resume_from),

src/mcmc/repeat_sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function AbstractMCMC.sample(
9595
sampler::RepeatSampler{<:Sampler},
9696
N::Integer;
9797
initial_params=DynamicPPL.init_strategy(sampler),
98-
chain_type=MCMCChains.Chains,
98+
chain_type=TURING_CHAIN_TYPE,
9999
progress=PROGRESS[],
100100
kwargs...,
101101
)
@@ -119,7 +119,7 @@ function AbstractMCMC.sample(
119119
N::Integer,
120120
n_chains::Integer;
121121
initial_params=fill(DynamicPPL.init_strategy(sampler), n_chains),
122-
chain_type=MCMCChains.Chains,
122+
chain_type=TURING_CHAIN_TYPE,
123123
progress=PROGRESS[],
124124
kwargs...,
125125
)

0 commit comments

Comments
 (0)