Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.34.0"
version = "0.34.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -55,6 +58,9 @@ Compat = "4"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
# for why KernelAbstractions is pinned like this.
KernelAbstractions = "<= 0.9.31"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
Expand Down
79 changes: 14 additions & 65 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
offset = 0

for (idx, vn) in enumerate(vns_both)
# `idcs`
idcs[vn] = idx
# `vns`
push!(vns, vn)
if vn in vns_left && vn in vns_right
# `vals`: only valid if they're the length.
vals_left = getindex_internal(metadata_left, vn)
vals_right = getindex_internal(metadata_right, vn)
@assert length(vals_left) == length(vals_right)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`: only valid if they're the same.
dist_right = getdist(metadata_right, vn)
# Give precedence to `metadata_right`.
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
# Using `metadata_right`; should we?
push!(flags[k], is_flagged(metadata_right, vn, k))
end
elseif vn in vns_left
# Just extract the metadata from `metadata_left`.
# `vals`
vals_left = getindex_internal(metadata_left, vn)
append!(vals, vals_left)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`
dist_left = getdist(metadata_left, vn)
push!(dists, dist_left)
gid = metadata_left.gids[getidx(metadata_left, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_left, vn, k))
end
else
# Just extract the metadata from `metadata_right`.
# `vals`
vals_right = getindex_internal(metadata_right, vn)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_right))
push!(ranges, r)
offset = r[end]
# `dists`
dist_right = getdist(metadata_right, vn)
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_right, vn, k))
end
metadata_for_vn = vn in vns_right ? metadata_right : metadata_left

val = getindex_internal(metadata_for_vn, vn)
append!(vals, val)
r = (offset + 1):(offset + length(val))
push!(ranges, r)
offset = r[end]
dist = getdist(metadata_for_vn, vn)
push!(dists, dist)
gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
push!(gids, gid)
push!(orders, getorder(metadata_for_vn, vn))
for k in keys(flags)
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,17 @@ end
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
end

# The below used to error, testing to avoid regression.
@testset "merge different dimensions" begin
vn = @varname(x)
vi_single = VarInfo()
vi_single = push!!(vi_single, vn, 1.0, Normal())
vi_double = VarInfo()
vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
@test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
@test merge(vi_double, vi_single)[vn] == 1.0
end
end

@testset "VarInfo with selectors" begin
Expand Down
Loading