Skip to content

Commit 99df54b

Browse files
committed
Fix for unwrap_left_right_vns (#297)
I just noticed a bug I introduced in a recent PR when looking at #295 . This PR fixes it. I'll add tests, a sec. @yebai
1 parent 5472d9d commit 99df54b

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.13.0"
3+
version = "0.13.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ left-hand side of a `.~` expression such as `x .~ Normal()`.
9090
9191
This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
9292
variables.
93+
94+
# Example
95+
```jldoctest; setup=:(using Distributions)
96+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); string(vns[end])
97+
"x[:,2]"
98+
99+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end])
100+
"x[:][1,2]"
101+
102+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); string(vns[end])
103+
"x[1][3]"
104+
105+
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string(vns[end])
106+
"x[1,2,3]"
107+
```
93108
"""
94109
unwrap_right_left_vns(right, left, vns) = right, left, vns
95110
function unwrap_right_left_vns(right::NamedDist, left, vns)
@@ -103,7 +118,7 @@ function unwrap_right_left_vns(
103118
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
104119
# and we therefore add the `Colon()` below.
105120
vns = map(axes(left, 2)) do i
106-
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
121+
return VarName(vn, (vn.indexing..., (Colon(), i)))
107122
end
108123
return unwrap_right_left_vns(right, left, vns)
109124
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2020

2121
[compat]
2222
AbstractMCMC = "2.1, 3.0"
23-
AbstractPPL = "0.1.4, 0.2"
23+
AbstractPPL = "0.2"
2424
Bijectors = "0.9.5"
2525
Distributions = "< 0.25.11"
2626
DistributionsAD = "0.6.3"

0 commit comments

Comments
 (0)