
Commit 6fbebf8

Authored by FredericWantiez (Frederic Wantiez), yebai, and storopoli
Match AdvancedPS 0.4 API (#1858)
* Conflict
* Weird typo
* Fix `LogPoisson` typo in docstring (#1888): simple fix that I saw.
* test
* Update state, still missing one
* Sneaky RNG
* Fix logevidence
* Small test
* AdvancedPS
* Updating the RNG properly
* Update Inference.jl
* Update TuringCI.yml
* Update DynamicHMC.yml
* Update Numerical.yml
* Scaffolding
* Clean up
* Forward kwargs
* Optional args, TArray
* Fix invalid reference error
* Remove mentions of TArray to avoid confusion
* Remove section on Libtask
* Proper copy of TracedModel
* Updating the VarInfo
* Clean up comments, update reset_model
* Increase sample numbers

Co-authored-by: Frederic Wantiez <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Jose Storopoli <[email protected]>
Co-authored-by: Frederic Wantiez <[email protected]>
1 parent ed3f094 commit 6fbebf8

File tree

16 files changed: +110 / -131 lines

.github/workflows/DynamicHMC.yml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.6'
+          - '1.7'
           - '1'
         os:
           - ubuntu-latest

.github/workflows/Numerical.yml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.6'
+          - '1.7'
           - '1'
         os:
           - ubuntu-latest

.github/workflows/TuringCI.yml

Lines changed: 7 additions & 7 deletions
@@ -13,7 +13,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.6'
+          - '1.7'
           - '1'
         os:
           - ubuntu-latest
@@ -23,15 +23,15 @@ jobs:
           - 1
           - 2
         include:
-          - version: '1.6'
+          - version: '1.7'
             os: ubuntu-latest
             arch: x86
             num_threads: 2
-          - version: '1.6'
+          - version: '1.7'
            os: windows-latest
            arch: x64
            num_threads: 2
-          - version: '1.6'
+          - version: '1.7'
            os: macOS-latest
            arch: x64
            num_threads: 2
@@ -58,13 +58,13 @@ jobs:
         env:
           JULIA_NUM_THREADS: ${{ matrix.num_threads }}
       - uses: julia-actions/julia-processcoverage@v1
-        if: matrix.version == '1.6' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
+        if: matrix.version == '1.7' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
       - uses: codecov/codecov-action@v1
-        if: matrix.version == '1.6' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
+        if: matrix.version == '1.7' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
        with:
          file: lcov.info
      - uses: coverallsapp/github-action@master
-        if: matrix.version == '1.6' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
+        if: matrix.version == '1.7' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          path-to-lcov: lcov.info

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 AbstractMCMC = "4"
 AdvancedHMC = "0.3.0"
 AdvancedMH = "0.6.8"
-AdvancedPS = "0.3.4"
+AdvancedPS = "0.4"
 AdvancedVI = "0.1"
 BangBang = "0.3"
 Bijectors = "0.8, 0.9, 0.10"
@@ -49,8 +49,8 @@ DocStringExtensions = "0.8, 0.9"
 DynamicPPL = "0.21"
 EllipticalSliceSampling = "0.5, 1"
 ForwardDiff = "0.10.3"
-Libtask = "0.6.7, 0.7"
 LogDensityProblems = "0.12, 1"
+Libtask = "0.7, 0.8"
 MCMCChains = "5"
 NamedArrays = "0.9"
 Reexport = "0.2, 1"
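
For reference, one quick way to confirm that an updated environment has resolved versions inside the new compat bounds is standard Pkg usage; this snippet is generic tooling, not part of the commit.

```julia
using Pkg

# Resolved versions should satisfy AdvancedPS = "0.4" and Libtask = "0.7, 0.8"
# after updating the environment.
Pkg.status("AdvancedPS")
Pkg.status("Libtask")
```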

docs/src/library/api.md

Lines changed: 0 additions & 12 deletions
@@ -42,15 +42,3 @@ BinomialLogit
 VecBinomialLogit
 OrderedLogistic
 ```
-
-
-## Data Structures
-```@docs
-TArray
-```
-
-## Utilities
-```@docs
-tzeros
-```
-

docs/src/using-turing/guide.md

Lines changed: 2 additions & 44 deletions
@@ -10,7 +10,7 @@ title: Guide
 
 ### Introduction
 
-A probabilistic program is Julia code wrapped in a `@model` macro. It can use arbitrary Julia code, but to ensure correctness of inference it should not have external effects or modify global state. Stack-allocated variables are safe, but mutable heap-allocated objects may lead to subtle bugs when using task copying. To help avoid those we provide a Turing-safe datatype `TArray` that can be used to create mutable arrays in Turing programs.
+A probabilistic program is Julia code wrapped in a `@model` macro. It can use arbitrary Julia code, but to ensure correctness of inference it should not have external effects or modify global state. Stack-allocated variables are safe, but mutable heap-allocated objects may lead to subtle bugs when using task copying. By default, Libtask deepcopies `Array` and `Dict` objects when copying a task, to avoid bugs with data stored in mutable structures in Turing models.
 
 
 To specify distributions of random variables, Turing programs should use the `~` notation:
@@ -369,7 +369,7 @@ The element type of a vector (or matrix) of random variables should match the `e
 1. `Real` to enable auto-differentiation through the model which uses special number types that are sub-types of `Real`, or
 2. Some type parameter `T` defined in the model header using the type parameter syntax, e.g. `function gdemo(x, ::Type{T} = Float64) where {T}`.
 Similarly, when using a particle sampler, the Julia variable used should either be:
-1. A `TArray`, or
+1. An `Array`, or
 2. An instance of some type parameter `T` defined in the model header using the type parameter syntax, e.g. `function gdemo(x, ::Type{T} = Vector{Float64}) where {T}`.
 
 
@@ -596,48 +596,6 @@ plot(chn) # Plots statistics of the samples.
 There are numerous functions in addition to `describe` and `plot` in the `MCMCChains` package, such as those used in convergence diagnostics. For more information on the package, please see the [GitHub repository](https://github.com/TuringLang/MCMCChains.jl).
 
 
-### Working with Libtask.jl
-
-
-The [Libtask.jl](https://github.com/TuringLang/Libtask.jl) library provides write-on-copy data structures that are safe for use in Turing's particle-based samplers. One data structure in particular is often required for use – the [`TArray`](http://turing.ml/docs/library/#Libtask.TArray). The following sampler types require the use of a `TArray` to store distributions:
-
-
-* `IPMCMC`
-* `IS`
-* `PG`
-* `PMMH`
-* `SMC`
-
-
-If you do not use a `TArray` to store arrays of distributions when using a particle-based sampler, you may experience errors.
-
-
-Here is an example of how the `TArray` (using a `TArray` constructor function called `tzeros`) can be applied in this way:
-
-
-```julia
-# Turing model definition.
-@model function BayesHmm(y)
-    # Declare a TArray with a length of N.
-    s = tzeros(Int, N)
-    m = Vector{Real}(undef, K)
-    T = Vector{Vector{Real}}(undef, K)
-    for i = 1:K
-        T[i] ~ Dirichlet(ones(K)/K)
-        m[i] ~ Normal(i, 0.01)
-    end
-
-    # Draw from a distribution for each element in s.
-    s[1] ~ Categorical(K)
-    for i = 2:N
-        s[i] ~ Categorical(vec(T[s[i-1]]))
-        y[i] ~ Normal(m[s[i]], 0.1)
-    end
-    return (s, m)
-end;
-```
-
-
 ### Changing Default Settings
 
 
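Since the guide now points users to plain `Array`s (or a model-level type parameter) instead of `TArray` for particle samplers, a minimal sketch of that pattern follows. The model name `gdemo_particles`, the data, and the sampler settings are illustrative assumptions, not taken from this commit.

```julia
using Turing

# Hypothetical model illustrating the updated guidance: per-step states live in
# a plain Vector allocated through the type-parameter pattern, which is what
# particle samplers such as PG and SMC work with after this change.
@model function gdemo_particles(y, ::Type{T}=Vector{Float64}) where {T}
    N = length(y)
    m ~ Normal(0, 1)
    x = T(undef, N)  # plain Array; no TArray required with AdvancedPS 0.4
    for i in 1:N
        x[i] ~ Normal(m, 1)
        y[i] ~ Normal(x[i], 0.5)
    end
end

# Particle Gibbs with 20 particles; the numbers are illustrative only.
chain = sample(gdemo_particles(randn(10)), PG(20), 100)
```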
src/essential/container.jl

Lines changed: 31 additions & 25 deletions
@@ -9,14 +9,13 @@ function TracedModel(
     model::Model,
     sampler::AbstractSampler,
     varinfo::AbstractVarInfo,
+    rng::Random.AbstractRNG
 )
-    # evaluate!!(m.model, varinfo, SamplingContext(Random.AbstractRNG, m.sampler, DefaultContext()))
-    context = SamplingContext(DynamicPPL.Random.GLOBAL_RNG, sampler, DefaultContext())
+    context = SamplingContext(rng, sampler, DefaultContext())
     evaluator = _get_evaluator(model, varinfo, context)
     return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
 end
 
-# Smiliar to `evaluate!!` except that we return the evaluator signature without execution.
 # TODO: maybe move to DynamicPPL
 @generated function _get_evaluator(
     model::Model{_F,argnames}, varinfo, context
@@ -39,42 +38,49 @@ end
     end
 end
 
-function Base.copy(trace::AdvancedPS.Trace{<:TracedModel})
-    f = trace.f
-    newf = TracedModel(f.model, f.sampler, deepcopy(f.varinfo))
-    return AdvancedPS.Trace(newf, copy(trace.task))
+
+function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
+    newtask = copy(model.ctask)
+    newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
+    gen_model = AdvancedPS.GenericModel(newmodel, newtask)
+    return gen_model
 end
 
-function AdvancedPS.advance!(trace::AdvancedPS.Trace{<:TracedModel})
-    DynamicPPL.increment_num_produce!(trace.f.varinfo)
-    score = consume(trace.task)
+function AdvancedPS.advance!(trace::AdvancedPS.Trace{<:AdvancedPS.GenericModel{<:TracedModel}}, isref::Bool=false)
+    # Make sure we load/reset the rng in the new replaying mechanism
+    DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
+    isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
+    score = consume(trace.model.ctask)
     if score === nothing
         return
     else
-        return score + DynamicPPL.getlogp(trace.f.varinfo)
+        return score + DynamicPPL.getlogp(trace.model.f.varinfo)
     end
 end
 
-function AdvancedPS.delete_retained!(f::TracedModel)
-    DynamicPPL.set_retained_vns_del_by_spl!(f.varinfo, f.sampler)
-    return
+function AdvancedPS.delete_retained!(trace::TracedModel)
+    DynamicPPL.set_retained_vns_del_by_spl!(trace.varinfo, trace.sampler)
+    return trace
 end
 
-function AdvancedPS.reset_model(f::TracedModel)
-    newvarinfo = deepcopy(f.varinfo)
-    DynamicPPL.reset_num_produce!(newvarinfo)
-    return TracedModel(f.model, f.sampler, newvarinfo)
+function AdvancedPS.reset_model(trace::TracedModel)
+    DynamicPPL.reset_num_produce!(trace.varinfo)
+    return trace
 end
 
-function AdvancedPS.reset_logprob!(f::TracedModel)
-    DynamicPPL.resetlogp!!(f.varinfo)
-    return
+function AdvancedPS.reset_logprob!(trace::TracedModel)
+    DynamicPPL.resetlogp!!(trace.model.varinfo)
+    return trace
 end
 
-function Libtask.TapedTask(model::TracedModel)
-    return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...)
+function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
+    args = trace.model.ctask.args
+    _, _, container, = args
+    rng = container.rng
+    trace.rng = rng
+    return trace
 end
 
-function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG)
-    return Libtask.TapedTask(model)
+function Libtask.TapedTask(model::TracedModel, rng::Random.AbstractRNG; kwargs...)
+    return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
 end
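
The user-facing consequence of this refactor is that the RNG is threaded through `SamplingContext` and the AdvancedPS traces instead of falling back to the global RNG. A minimal sketch under that assumption follows; the model, data, and seed are illustrative and not part of the commit.

```julia
using Random
using Turing

# Illustrative model; any Turing model works with a particle sampler.
@model function coinflip(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

data = [1, 0, 1, 1, 0, 1]

# Passing an explicit RNG to `sample` (standard AbstractMCMC API); with the RNG
# now propagated into each particle trace, repeated seeded runs are expected to
# be reproducible (an expectation of the new wiring, not a guarantee asserted
# by this commit).
chain1 = sample(MersenneTwister(42), coinflip(data), SMC(), 200)
chain2 = sample(MersenneTwister(42), coinflip(data), SMC(), 200)
```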
