Skip to content

Commit 5472d9d

Browse files
torfjeldeyebai
andcommitted
Some small improvements (#291)
A couple of small improvements/fixes that I noticed recently. Co-authored-by: Hong Ge <[email protected]>
1 parent 5609335 commit 5472d9d

File tree

5 files changed

+14
-10
lines changed

5 files changed

+14
-10
lines changed

src/compiler.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,12 @@ end
9898
function unwrap_right_left_vns(
9999
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
100100
)
101+
# This an expression such as `x .~ MvNormal()` which we interpret as
102+
# x[:, i] ~ MvNormal()
103+
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
104+
# and we therefore add the `Colon()` below.
101105
vns = map(axes(left, 2)) do i
102-
return VarName(vn, (vn.indexing..., Tuple(i)))
106+
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
103107
end
104108
return unwrap_right_left_vns(right, left, vns)
105109
end

src/context_implementations.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
1414
require_gradient(spl::Sampler) = false
1515
require_particles(spl::Sampler) = false
1616

17-
_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
17+
_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
1818
_getindex(x, inds::Tuple{}) = x
1919

2020
# assume
@@ -227,11 +227,8 @@ end
227227

228228
# fallback without sampler
229229
function assume(dist::Distribution, vn::VarName, vi)
230-
if !haskey(vi, vn)
231-
error("variable $vn does not exist")
232-
end
233230
r = vi[vn]
234-
return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn))
231+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
235232
end
236233

237234
# SampleFromPrior and SampleFromUniform
@@ -430,12 +427,13 @@ function dot_assume(
430427
# m .~ Normal()
431428
#
432429
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
433-
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior())
430+
r = vi[vns]
434431
lp = sum(zip(vns, eachcol(r))) do vn, ri
435432
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
436433
end
437434
return r, lp
438435
end
436+
439437
function dot_assume(
440438
rng,
441439
spl::Union{SampleFromPrior,SampleFromUniform},
@@ -462,7 +460,7 @@ function dot_assume(
462460
# m .~ Normal()
463461
#
464462
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
465-
r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior())
463+
r = reshape(vi[vec(vns)], size(vns))
466464
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
467465
return r, lp
468466
end

src/varinfo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,8 @@ end
501501
end
502502

503503
@inline function findranges(f_ranges, f_idcs)
504+
# Old implementation was using `mapreduce` but turned out
505+
# to be type-unstable.
504506
results = Int[]
505507
for i in f_idcs
506508
append!(results, f_ranges[i])

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2020

2121
[compat]
2222
AbstractMCMC = "2.1, 3.0"
23-
AbstractPPL = "0.1.3"
23+
AbstractPPL = "0.1.4, 0.2"
2424
Bijectors = "0.9.5"
2525
Distributions = "< 0.25.11"
2626
DistributionsAD = "0.6.3"

test/turing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
66

77
[compat]
8-
DynamicPPL = "0.12"
8+
DynamicPPL = "0.13"
99
Turing = "0.15, 0.16"
1010
julia = "1.3"

0 commit comments

Comments
 (0)