Skip to content

Commit bcf8fe6

Browse files
committed
Merge branch 'csp/abstractmcmc' into cpfiffer-patch-1
2 parents 26b5be7 + e75086c commit bcf8fe6

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: CompatHelper
2+
3+
on:
4+
schedule:
5+
- cron: '00 00 * * *'
6+
7+
jobs:
8+
CompatHelper:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: julia-actions/setup-julia@latest
12+
with:
13+
version: 1.3
14+
- name: Pkg.add("CompatHelper")
15+
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
16+
- name: CompatHelper.main()
17+
env:
18+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
19+
run: julia -e 'using CompatHelper; CompatHelper.main()'

.travis.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Documentation: http://docs.travis-ci.com/user/languages/julia/
22
language: julia
3+
branches:
4+
only:
5+
- master
36
os:
47
- linux
58
- osx
@@ -16,4 +19,6 @@ matrix:
1619
notifications:
1720
email: false
1821
after_success:
19-
- julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())'
22+
- if [[ $TRAVIS_JULIA_VERSION = 1.3 ]] && [[ $TRAVIS_OS_NAME = linux ]]; then
23+
julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Codecov.submit(process_folder())';
24+
fi

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111

1212
[compat]
13-
AbstractMCMC = "0.1, 0.2, 0.3"
13+
AbstractMCMC = "0.3"
1414
AdvancedHMC = "0.2.20"
1515
Bijectors = "0.5.2"
1616
Distributions = "0.22"

test/Turing/inference/Inference.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,22 @@ function AbstractMCMC.sample(
144144
model::AbstractModel,
145145
alg::InferenceAlgorithm,
146146
N::Integer;
147+
chain_type=Chains,
147148
kwargs...
148149
)
149-
return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], kwargs...)
150+
return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
150151
end
151152

152153
function AbstractMCMC.sample(
153154
model::AbstractModel,
154155
alg::InferenceAlgorithm,
155156
N::Integer;
156157
resume_from=nothing,
158+
chain_type=Chains,
157159
kwargs...
158160
)
159161
if resume_from === nothing
160-
return sample(model, Sampler(alg, model), N; progress=PROGRESS[], kwargs...)
162+
return sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
161163
else
162164
return resume(resume_from, N)
163165
end
@@ -169,9 +171,10 @@ function AbstractMCMC.psample(
169171
alg::InferenceAlgorithm,
170172
N::Integer,
171173
n_chains::Integer;
174+
chain_type=Chains,
172175
kwargs...
173176
)
174-
return psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, kwargs...)
177+
return psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, chain_type=chain_type, kwargs...)
175178
end
176179

177180
function AbstractMCMC.psample(
@@ -180,9 +183,10 @@ function AbstractMCMC.psample(
180183
alg::InferenceAlgorithm,
181184
N::Integer,
182185
n_chains::Integer;
186+
chain_type=Chains,
183187
kwargs...
184188
)
185-
return psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, kwargs...)
189+
return psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, chain_type=chain_type, kwargs...)
186190
end
187191

188192
function AbstractMCMC.sample_init!(
@@ -318,7 +322,8 @@ function AbstractMCMC.bundle_samples(
318322
model::AbstractModel,
319323
spl::Sampler,
320324
N::Integer,
321-
ts::Vector{<:AbstractTransition};
325+
ts::Vector{<:AbstractTransition},
326+
ct::Type{Chains};
322327
discard_adapt::Bool=true,
323328
save_state=true,
324329
kwargs...
@@ -375,7 +380,7 @@ function save(c::Chains, spl::AbstractSampler, model, vi, samples)
375380
return setinfo(c, merge(nt, c.info))
376381
end
377382

378-
function resume(c::Chains, n_iter::Int; kwargs...)
383+
function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...)
379384
@assert !isempty(c.info) "[Turing] cannot resume from a chain without state info"
380385

381386
# Sample a new chain.
@@ -386,6 +391,7 @@ function resume(c::Chains, n_iter::Int; kwargs...)
386391
n_iter;
387392
resume_from=c,
388393
reuse_spl_n=n_iter,
394+
chain_type=chain_type,
389395
kwargs...
390396
)
391397

0 commit comments

Comments
 (0)