Commit 9937cb3

torfjeld and yebai committed
Perform invlinking in assume rather than implicitly in getindex (#360)
Currently, in `assume`, etc., `invlink` is called implicitly in `getindex` using the distribution extracted from `vi`. This has a couple of drawbacks:

1. We can only use the distribution for a particular `vn` stored in `vi` obtained during the initial run. This means that we can't even run models where the distributions have dynamic domains, i.e. the domain of a particular random variable is dependent on the realizations of other random variables.
2. We have to store the distribution for each `vn` in `vi`. This was fine when we only had `VarInfo` because we also need it for other functionality, but this is not the case in `SimpleVarInfo` (nor will it be).

So, in this PR we introduce a `getindex_raw`, which is `getindex` but without `invlink` if it's already linked, and use this within `assume`, etc., where we now use the distributions that are passed to `assume` rather than those stored in `vi`.

E.g. the following now works:

```julia
julia> @model demo() = x ~ InverseGamma(2, 3)
demo (generic function with 2 methods)

julia> vi = SimpleVarInfo((x = 10.0, ), true)
SimpleVarInfo((x = 10.0,), 0.0, true)

julia> _, vi = DynamicPPL.evaluate!!(model, vi, DefaultContext())
(22026.465794806718, SimpleVarInfo{NamedTuple{(:x,), Tuple{Float64}}, Float64}((x = 10.0,), -17.80291162245307, true))
```

Co-authored-by: Hong Ge <[email protected]>
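
The numbers in the commit message's `InverseGamma` example can be checked by hand. The following is a small sketch in plain Julia, not DynamicPPL code. It assumes that the stored value `10.0` lives in the unconstrained (log) space, that `InverseGamma(2, 3)` has shape 2 and scale 3, and that the log-density in the unconstrained space is the constrained log-density plus the log-Jacobian term, which for `x = exp(y)` is simply `y`. The helper name `inversegamma_logpdf` is made up for illustration.

```julia
# Hypothetical helper (not part of any package): log-density of an inverse-gamma
# distribution with integer shape `alpha` and scale `theta`, evaluated at `x > 0`.
# Restricting to integer shape avoids special functions: Γ(alpha) = (alpha - 1)!.
function inversegamma_logpdf(alpha::Integer, theta::Real, x::Real)
    log_gamma_alpha = log(factorial(alpha - 1))
    return alpha * log(theta) - log_gamma_alpha - (alpha + 1) * log(x) - theta / x
end

y = 10.0      # value stored in the linked (transformed) variable container
x = exp(y)    # inverse link for a positive variable (assumed to be `exp`)

# Density in the unconstrained space: constrained log-density plus log|dx/dy| = y.
lp = inversegamma_logpdf(2, 3.0, x) + y

println(x)
println(lp)
```

Under these assumptions, `x` and `lp` should agree with the two numbers returned by `evaluate!!` in the commit message, up to floating-point rounding.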
1 parent 9ecf3dc commit 9937cb3

20 files changed, with 1068 additions and 329 deletions

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.3' # minimum supported version
+          - '1.6' # minimum supported version
           - '1' # current stable version
         os:
           - ubuntu-latest

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ jobs:
               # force it to use this PR's version of the package
               Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
               Pkg.update()
-              Pkg.test() # resolver may fail with test time deps
+              Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps
             catch err
               err isa Pkg.Resolve.ResolverError || rethrow()
               # If we can't resolve that means this is incompatible by SemVer and this is fine

Project.toml

Lines changed: 6 additions & 2 deletions
@@ -1,14 +1,16 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.19.3"
+version = "0.20.0"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -22,8 +24,10 @@ AbstractPPL = "0.5.1"
 BangBang = "0.3"
 Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10"
 ChainRulesCore = "0.9.7, 0.10, 1"
+ConstructionBase = "1"
 Distributions = "0.23.8, 0.24, 0.25"
+DocStringExtensions = "0.8"
 MacroTools = "0.5.6"
 Setfield = "0.7.1, 0.8"
 ZygoteRules = "0.2"
-julia = "1.3"
+julia = "1.6"

docs/src/api.md

Lines changed: 21 additions & 1 deletion
@@ -103,8 +103,14 @@ NamedDist
103103
DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.
104104

105105
```@docs
106-
DynamicPPL.TestUtils.test_sampler_demo_models
106+
DynamicPPL.TestUtils.test_sampler
107+
DynamicPPL.TestUtils.test_sampler_on_demo_models
107108
DynamicPPL.TestUtils.test_sampler_continuous
109+
DynamicPPL.TestUtils.marginal_mean_of_samples
110+
```
111+
112+
```@docs
113+
DynamicPPL.TestUtils.DEMO_MODELS
108114
```
109115

110116
For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.
@@ -115,6 +121,20 @@ DynamicPPL.TestUtils.loglikelihood_true
115121
DynamicPPL.TestUtils.logjoint_true
116122
```
117123

124+
And in the case where the model includes constrained variables, it can also be useful to define
125+
126+
```@docs
127+
DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian
128+
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian
129+
```
130+
131+
Finally, the following methods can also be of use:
132+
133+
```@docs
134+
DynamicPPL.TestUtils.varnames
135+
DynamicPPL.TestUtils.posterior_mean
136+
```
137+
118138
## Advanced
119139

120140
### Variable names

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ using MacroTools: MacroTools
 using Setfield: Setfield
 using ZygoteRules: ZygoteRules
 
+using DocStringExtensions
+
 using Random: Random
 
 import Base:

src/context_implementations.jl

Lines changed: 56 additions & 38 deletions
@@ -55,7 +55,7 @@ end
 function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
     if haskey(context.vars, getsym(vn))
         vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
-        settrans!(vi, false, vn)
+        settrans!!(vi, false, vn)
     end
     return tilde_assume(PriorContext(), right, vn, vi)
 end
@@ -64,15 +64,15 @@ function tilde_assume(
 )
     if haskey(context.vars, getsym(vn))
         vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
-        settrans!(vi, false, vn)
+        settrans!!(vi, false, vn)
     end
     return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
 end
 
 function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
     if haskey(context.vars, getsym(vn))
         vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
-        settrans!(vi, false, vn)
+        settrans!!(vi, false, vn)
     end
     return tilde_assume(LikelihoodContext(), right, vn, vi)
 end
@@ -86,7 +86,7 @@ function tilde_assume(
 )
     if haskey(context.vars, getsym(vn))
         vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
-        settrans!(vi, false, vn)
+        settrans!!(vi, false, vn)
     end
     return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
 end
@@ -194,7 +194,7 @@ end
 
 # fallback without sampler
 function assume(dist::Distribution, vn::VarName, vi)
-    r = vi[vn]
+    r = vi[vn, dist]
     return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
 end
 
@@ -211,16 +211,21 @@ function assume(
         if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
             unset_flag!(vi, vn, "del")
             r = init(rng, dist, sampler)
-            vi[vn] = vectorize(dist, r)
-            settrans!(vi, false, vn)
+            vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
             setorder!(vi, vn, get_num_produce(vi))
         else
-            r = vi[vn]
+            # Otherwise we just extract it.
+            r = vi[vn, dist]
         end
     else
         r = init(rng, dist, sampler)
-        push!!(vi, vn, r, dist, sampler)
-        settrans!(vi, false, vn)
+        if istrans(vi)
+            push!!(vi, vn, link(dist, r), dist, sampler)
+            # By default `push!!` sets the transformed flag to `false`.
+            settrans!!(vi, true, vn)
+        else
+            push!!(vi, vn, r, dist, sampler)
+        end
     end
 
     return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
@@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
         var = get(context.vars, vn)
         _right, _left, _vns = unwrap_right_left_vns(right, var, vn)
         set_val!(vi, _vns, _right, _left)
-        settrans!.(Ref(vi), false, _vns)
+        settrans!!.((vi,), false, _vns)
         dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
     else
         dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
@@ -305,19 +310,20 @@ function dot_tilde_assume(
         var = get(context.vars, vn)
         _right, _left, _vns = unwrap_right_left_vns(right, var, vn)
         set_val!(vi, _vns, _right, _left)
-        settrans!.(Ref(vi), false, _vns)
+        settrans!!.((vi,), false, _vns)
         dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
     else
         dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
     end
 end
+
 function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
-    return dot_assume(NoDist.(right), left, vn, vi)
+    return dot_assume(nodist(right), left, vn, vi)
 end
 function dot_tilde_assume(
     rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi
 )
-    return dot_assume(rng, sampler, NoDist.(right), vn, left, vi)
+    return dot_assume(rng, sampler, nodist(right), vn, left, vi)
 end
 
 # `PriorContext`
@@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
         var = get(context.vars, vn)
         _right, _left, _vns = unwrap_right_left_vns(right, var, vn)
         set_val!(vi, _vns, _right, _left)
-        settrans!.(Ref(vi), false, _vns)
+        settrans!!.((vi,), false, _vns)
         dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
     else
         dot_tilde_assume(PriorContext(), right, left, vn, vi)
@@ -345,7 +351,7 @@ function dot_tilde_assume(
         var = get(context.vars, vn)
         _right, _left, _vns = unwrap_right_left_vns(right, var, vn)
         set_val!(vi, _vns, _right, _left)
-        settrans!.(Ref(vi), false, _vns)
+        settrans!!.((vi,), false, _vns)
         dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
     else
         dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
@@ -383,14 +389,14 @@ function dot_assume(
     vns::AbstractVector{<:VarName},
     vi::AbstractVarInfo,
 )
-    @assert length(dist) == size(var, 1)
+    @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))"
     # NOTE: We cannot work with `var` here because we might have a model of the form
     #
     #     m = Vector{Float64}(undef, n)
     #     m .~ Normal()
     #
     # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
-    r = vi[vns]
+    r = vi[vns, dist]
     lp = sum(zip(vns, eachcol(r))) do (vn, ri)
         return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
     end
@@ -412,19 +418,21 @@ function dot_assume(
 end
 
 function dot_assume(
-    dists::Union{Distribution,AbstractArray{<:Distribution}},
+    dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
+)
+    r = getindex.((vi,), vns, (dist,))
+    lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)))
+    return r, lp, vi
+end
+
+function dot_assume(
+    dists::AbstractArray{<:Distribution},
     var::AbstractArray,
     vns::AbstractArray{<:VarName},
     vi,
 )
-    # NOTE: We cannot work with `var` here because we might have a model of the form
-    #
-    #     m = Vector{Float64}(undef, n)
-    #     m .~ Normal()
-    #
-    # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
-    r = reshape(vi[vec(vns)], size(vns))
-    lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
+    r = getindex.((vi,), vns, dists)
+    lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
     return r, lp, vi
 end
 
@@ -438,7 +446,7 @@ function dot_assume(
 )
     r = get_and_set_val!(rng, vi, vns, dists, spl)
     # Make sure `r` is not a matrix for multivariate distributions
-    lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
+    lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
     return r, lp, vi
 end
 function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
@@ -462,19 +470,23 @@ function get_and_set_val!(
             r = init(rng, dist, spl, n)
             for i in 1:n
                 vn = vns[i]
-                vi[vn] = vectorize(dist, r[:, i])
-                settrans!(vi, false, vn)
+                vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
                 setorder!(vi, vn, get_num_produce(vi))
             end
         else
-            r = vi[vns]
+            r = vi[vns, dist]
         end
     else
         r = init(rng, dist, spl, n)
         for i in 1:n
             vn = vns[i]
-            push!!(vi, vn, r[:, i], dist, spl)
-            settrans!(vi, false, vn)
+            if istrans(vi)
+                push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl)
+                # `push!!` sets the trans-flag to `false` by default.
+                settrans!!(vi, true, vn)
+            else
+                push!!(vi, vn, r[:, i], dist, spl)
+            end
         end
     end
     return r
@@ -496,12 +508,13 @@ function get_and_set_val!(
             for i in eachindex(vns)
                 vn = vns[i]
                 dist = dists isa AbstractArray ? dists[i] : dists
-                vi[vn] = vectorize(dist, r[i])
-                settrans!(vi, false, vn)
+                vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
                 setorder!(vi, vn, get_num_produce(vi))
             end
         else
-            r = reshape(vi[vec(vns)], size(vns))
+            # r = reshape(vi[vec(vns)], size(vns))
+            r_raw = getindex_raw(vi, vec(vns))
+            r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
         end
     else
         f = (vn, dist) -> init(rng, dist, spl)
@@ -511,8 +524,13 @@ function get_and_set_val!(
         # 1. Figure out the broadcast size and use a `foreach`.
         # 2. Define an anonymous function which returns `nothing`, which
         # we then broadcast. This will allocate a vector of `nothing` though.
-        push!!.(Ref(vi), vns, r, dists, Ref(spl))
-        settrans!.(Ref(vi), false, vns)
+        if istrans(vi)
+            push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
+            # `push!!` sets the trans-flag to `false` by default.
+            settrans!!.((vi,), true, vns)
+        else
+            push!!.((vi,), vns, r, dists, (spl,))
+        end
     end
     return r
 end
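
The changes in `src/context_implementations.jl` replace `vi[vn]` with `vi[vn, dist]`, so the inverse transform uses the distribution passed to `assume`. Why that matters for dynamic domains can be sketched with toy stand-ins in plain Julia. Everything here (`ToyUniform`, `toy_link`, `toy_invlink`, `toy_maybe_link`, `toy_maybe_invlink`) is hypothetical and only mimics the roles of the real functions. It assumes a variable `x ~ Uniform(0, b)` whose upper bound `b` changed from 1 to 5 between the initial run and the current one.

```julia
# Toy stand-ins for illustration only; NOT the DynamicPPL or Bijectors API.
struct ToyUniform
    lo::Float64
    hi::Float64
end

# Map a value in (lo, hi) to the real line via the logit of the rescaled value.
function toy_link(d::ToyUniform, x::Real)
    u = (x - d.lo) / (d.hi - d.lo)
    return log(u / (1 - u))
end

# Inverse map: real line back to (lo, hi).
toy_invlink(d::ToyUniform, y::Real) = d.lo + (d.hi - d.lo) / (1 + exp(-y))

# Only transform when the "is linked" flag is set.
toy_maybe_link(islinked::Bool, d::ToyUniform, x::Real) = islinked ? toy_link(d, x) : x
toy_maybe_invlink(islinked::Bool, d::ToyUniform, y::Real) = islinked ? toy_invlink(d, y) : y

stored_dist = ToyUniform(0.0, 1.0)    # distribution recorded during the initial run (b = 1)
current_dist = ToyUniform(0.0, 5.0)   # distribution in the current run (b = 5)

# Unconstrained representation of x = 4 under the current distribution.
y = toy_maybe_link(true, current_dist, 4.0)

x_stale = toy_maybe_invlink(true, stored_dist, y)    # inverts with the stale distribution
x_fresh = toy_maybe_invlink(true, current_dist, y)   # inverts with the distribution passed in

# Broadcasting over several values while treating the distribution as a scalar,
# using the same 1-tuple wrapping as `(vi,)` in the diff.
xs = toy_maybe_invlink.(true, (current_dist,), [-1.0, 0.0, 1.0])

println((x_stale, x_fresh))
println(xs)
```

In this sketch only `x_fresh` recovers the original value; `x_stale` lands inside the old domain `(0, 1)`, which is the kind of mismatch the first drawback in the commit message describes.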

src/distribution_wrappers.jl

Lines changed: 27 additions & 4 deletions
@@ -13,6 +13,9 @@ end
 
 NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())
 
+Base.length(dist::NamedDist) = Base.length(dist.dist)
+Base.size(dist::NamedDist) = Base.size(dist.dist)
+
 Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
 function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
     return Distributions.logpdf(dist.dist, x)
@@ -24,12 +27,20 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
     return Distributions.loglikelihood(dist.dist, x)
 end
 
+Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)
+
 struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
        Distribution{variate,support}
     dist::Td
 end
 NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)
 
+nodist(dist::Distribution) = NoDist(dist)
+nodist(dists::AbstractArray) = nodist.(dists)
+
+Base.length(dist::NoDist) = Base.length(dist.dist)
+Base.size(dist::NoDist) = Base.size(dist.dist)
+
 Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
 Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
@@ -40,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
 Distributions.minimum(d::NoDist) = minimum(d.dist)
 Distributions.maximum(d::NoDist) = maximum(d.dist)
 
-Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0
-Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
-function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
+Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
+function Bijectors.logpdf_with_trans(
+    d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool
+)
+    return 0
+end
+function Bijectors.logpdf_with_trans(
+    d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
+)
     return zeros(Int, size(x, 2))
 end
-Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
+function Bijectors.logpdf_with_trans(
+    d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool
+)
+    return 0
+end
+
+Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)
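
The wrapper pattern in `src/distribution_wrappers.jl` (a type that keeps the wrapped distribution but contributes zero log-density, plus a `nodist` helper that accepts either one distribution or an array of them) can be sketched without any packages. The names `ToyDist`, `ToyNormal`, `ToyNoDist`, `toy_logpdf` and `toy_nodist` are hypothetical stand-ins, not the Distributions.jl or DynamicPPL types.

```julia
# Toy stand-ins for illustration only; NOT the Distributions.jl types.
abstract type ToyDist end

struct ToyNormal <: ToyDist
    mu::Float64
end

# Log-density of a unit-variance normal centred at `mu`.
toy_logpdf(d::ToyNormal, x::Real) = -0.5 * (x - d.mu)^2 - 0.5 * log(2π)

# Wrapper that keeps the wrapped distribution around (e.g. for its support)
# but contributes nothing to the log-density.
struct ToyNoDist{D<:ToyDist} <: ToyDist
    dist::D
end
toy_logpdf(::ToyNoDist, ::Real) = 0

# One entry point for both a single distribution and an array of distributions.
toy_nodist(d::ToyDist) = ToyNoDist(d)
toy_nodist(ds::AbstractArray) = toy_nodist.(ds)

single = toy_nodist(ToyNormal(0.0))
many = toy_nodist([ToyNormal(0.0), ToyNormal(1.0)])

println(toy_logpdf(single, 3.0))
println(toy_logpdf.(many, 1.0))
```

Dispatching on the argument type means call sites do not need to know whether they hold a single distribution or a collection, which is the shape of the `NoDist.(right)` to `nodist(right)` change in the diff.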
