Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.34.0"
version = "0.34.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
79 changes: 14 additions & 65 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
offset = 0

for (idx, vn) in enumerate(vns_both)
# `idcs`
idcs[vn] = idx
# `vns`
push!(vns, vn)
if vn in vns_left && vn in vns_right
# `vals`: only valid if they're the length.
vals_left = getindex_internal(metadata_left, vn)
vals_right = getindex_internal(metadata_right, vn)
@assert length(vals_left) == length(vals_right)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`: only valid if they're the same.
dist_right = getdist(metadata_right, vn)
# Give precedence to `metadata_right`.
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
# Using `metadata_right`; should we?
push!(flags[k], is_flagged(metadata_right, vn, k))
end
elseif vn in vns_left
# Just extract the metadata from `metadata_left`.
# `vals`
vals_left = getindex_internal(metadata_left, vn)
append!(vals, vals_left)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`
dist_left = getdist(metadata_left, vn)
push!(dists, dist_left)
gid = metadata_left.gids[getidx(metadata_left, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_left, vn, k))
end
else
# Just extract the metadata from `metadata_right`.
# `vals`
vals_right = getindex_internal(metadata_right, vn)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_right))
push!(ranges, r)
offset = r[end]
# `dists`
dist_right = getdist(metadata_right, vn)
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_right, vn, k))
end
metadata_for_vn = vn in vns_right ? metadata_right : metadata_left

val = getindex_internal(metadata_for_vn, vn)
append!(vals, val)
r = (offset + 1):(offset + length(val))
push!(ranges, r)
offset = r[end]
dist = getdist(metadata_for_vn, vn)
push!(dists, dist)
gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)]
push!(gids, gid)
push!(orders, getorder(metadata_for_vn, vn))
for k in keys(flags)
push!(flags[k], is_flagged(metadata_for_vn, vn, k))
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,17 @@ end
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
end

# The below used to error, testing to avoid regression.
@testset "merge different dimensions" begin
vn = @varname(x)
vi_single = VarInfo()
vi_single = push!!(vi_single, vn, 1.0, Normal())
vi_double = VarInfo()
vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0))
@test merge(vi_single, vi_double)[vn] == [0.5, 0.6]
@test merge(vi_double, vi_single)[vn] == 1.0
end
end

@testset "VarInfo with selectors" begin
Expand Down
Loading