Skip to content

Commit 156afff

Browse files
Merge branch 'master' of https://github.com/FredericWantiez/AdvancedPS.jl into feature/traced_rng
2 parents f125369 + 176640e commit 156afff

File tree

14 files changed

+229
-183
lines changed

14 files changed

+229
-183
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style="blue"

.github/workflows/Format.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: Format
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
8+
9+
jobs:
10+
format:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v2
14+
- uses: julia-actions/setup-julia@latest
15+
with:
16+
version: 1
17+
- name: Format code
18+
run: |
19+
using Pkg
20+
Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899")
21+
using JuliaFormatter
22+
format("."; verbose=true)
23+
shell: julia --color=yes {0}
24+
- uses: reviewdog/action-suggester@v1
25+
if: github.event_name == 'pull_request'
26+
with:
27+
tool_name: JuliaFormatter
28+
fail_on_error: true

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.2.0"
4+
version = "0.2.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,8 +11,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1212

1313
[compat]
14-
AbstractMCMC = "2"
15-
Distributions = "0.23, 0.24"
14+
AbstractMCMC = "2, 3"
15+
Distributions = "0.23, 0.24, 0.25"
1616
Libtask = "0.5"
1717
StatsFuns = "0.9"
1818
julia = "1.3"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
[![Build Status](https://github.com/TuringLang/AdvancedPS.jl/workflows/CI/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedPS.jl/actions?query=workflow%3ACI%20branch%3Amaster)
44
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedPS.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedPS.jl)
55
[![Coverage](https://coveralls.io/repos/github/TuringLang/AdvancedPS.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/AdvancedPS.jl?branch=master)
6+
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
67

78
### Reference
89

src/AdvancedPS.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module AdvancedPS
22

3-
import AbstractMCMC
4-
import Distributions
5-
import Libtask
6-
import Random
7-
import StatsFuns
3+
using AbstractMCMC: AbstractMCMC
4+
using Distributions: Distributions
5+
using Libtask: Libtask
6+
using Random: Random
7+
using StatsFuns: StatsFuns
88

99
include("resampling.jl")
1010
include("rng.jl")

src/container.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ reset_model(f) = deepcopy(f)
4343
delete_retained!(f) = nothing
4444

4545
# Task copying version of fork for Trace.
46-
function fork(trace::Trace, isref::Bool = false)
46+
function fork(trace::Trace, isref::Bool=false)
4747
newtrace = copy(trace)
4848
isref && delete_retained!(newtrace.f)
4949

@@ -58,7 +58,7 @@ end
5858
function forkr(trace::Trace)
5959
newf = reset_model(trace.f)
6060

61-
ctask = let f=trace.ctask.task.code
61+
ctask = let f = trace.ctask.task.code
6262
Libtask.CTask() do
6363
res = f()(trace.rng)
6464
Libtask.produce(nothing)
@@ -115,7 +115,7 @@ end
115115
function Base.push!(pc::ParticleContainer, p::Particle)
116116
push!(pc.vals, p)
117117
push!(pc.logWs, 0.0)
118-
pc
118+
return pc
119119
end
120120

121121
# clones a theta-particle
@@ -126,7 +126,7 @@ function Base.copy(pc::ParticleContainer)
126126
# copy weights
127127
logWs = copy(pc.logWs)
128128

129-
ParticleContainer(vals, logWs)
129+
return ParticleContainer(vals, logWs)
130130
end
131131

132132
"""
@@ -193,9 +193,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
193193
function resample_propagate!(
194194
rng::Random.AbstractRNG,
195195
pc::ParticleContainer,
196-
randcat = resample_systematic,
197-
ref::Union{Particle, Nothing} = nothing;
198-
weights = getweights(pc)
196+
randcat=resample_systematic,
197+
ref::Union{Particle,Nothing}=nothing;
198+
weights=getweights(pc),
199199
)
200200
# check that weights are not NaN
201201
@assert !any(isnan, weights)
@@ -242,24 +242,24 @@ function resample_propagate!(
242242
pc.vals = children
243243
reset_logweights!(pc)
244244

245-
pc
245+
return pc
246246
end
247247

248248
function resample_propagate!(
249249
rng::Random.AbstractRNG,
250250
pc::ParticleContainer,
251251
resampler::ResampleWithESSThreshold,
252-
ref::Union{Particle,Nothing} = nothing;
253-
weights = getweights(pc)
252+
ref::Union{Particle,Nothing}=nothing;
253+
weights=getweights(pc),
254254
)
255255
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
256256
ess = inv(sum(abs2, weights))
257257

258258
if ess resampler.threshold * length(pc)
259-
resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights)
259+
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
260260
end
261261

262-
pc
262+
return pc
263263
end
264264

265265
"""
@@ -300,9 +300,13 @@ function reweight!(pc::ParticleContainer)
300300

301301
# The posterior for models with random number of observations is not well-defined.
302302
if numdone != 0
303-
error("mis-aligned execution traces: # particles = ", n,
304-
" # completed trajectories = ", numdone,
305-
". Please make sure the number of observations is NOT random.")
303+
error(
304+
"mis-aligned execution traces: # particles = ",
305+
n,
306+
" # completed trajectories = ",
307+
numdone,
308+
". Please make sure the number of observations is NOT random.",
309+
)
306310
end
307311

308312
return false

0 commit comments

Comments
 (0)