Skip to content

Commit c9489aa

Browse files
Allow merge to work on VarInfo with different distributions (#562)
* allow different distributions in merge * added tests for merge with different distributions * bump patch version since this is a bug * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 485ebfb commit c9489aa

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.0"
3+
version = "0.24.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varinfo.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,9 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
397397
push!(ranges, r)
398398
offset = r[end]
399399
# `dists`: only valid if they're the same.
400-
dists_left = getdist(metadata_left, vn)
401-
dists_right = getdist(metadata_right, vn)
402-
@assert dists_left == dists_right
403-
push!(dists, dists_left)
400+
dist_right = getdist(metadata_right, vn)
401+
# Give precedence to `metadata_right`.
402+
push!(dists, dist_right)
404403
# `orders`: giving precedence to `metadata_right`
405404
push!(orders, getorder(metadata_right, vn))
406405
# `flags`
@@ -418,8 +417,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
418417
push!(ranges, r)
419418
offset = r[end]
420419
# `dists`
421-
dists_left = getdist(metadata_left, vn)
422-
push!(dists, dists_left)
420+
dist_left = getdist(metadata_left, vn)
421+
push!(dists, dist_left)
423422
# `orders`
424423
push!(orders, getorder(metadata_left, vn))
425424
# `flags`
@@ -436,8 +435,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
436435
push!(ranges, r)
437436
offset = r[end]
438437
# `dists`
439-
dists_right = getdist(metadata_right, vn)
440-
push!(dists, dists_right)
438+
dist_right = getdist(metadata_right, vn)
439+
push!(dists, dist_right)
441440
# `orders`
442441
push!(orders, getorder(metadata_right, vn))
443442
# `flags`

test/varinfo.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,30 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
594594
end
595595
end
596596
end
597+
598+
@testset "different models" begin
599+
@model function demo_merge_different_y()
600+
x ~ Uniform()
601+
return y ~ Normal()
602+
end
603+
@model function demo_merge_different_z()
604+
x ~ Normal()
605+
return z ~ Normal()
606+
end
607+
model_left = demo_merge_different_y()
608+
model_right = demo_merge_different_z()
609+
610+
varinfo_left = VarInfo(model_left)
611+
varinfo_right = VarInfo(model_right)
612+
613+
varinfo_merged = merge(varinfo_left, varinfo_right)
614+
vns = [@varname(x), @varname(y), @varname(z)]
615+
check_varinfo_keys(varinfo_merged, vns)
616+
617+
# Right has precedence.
618+
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
619+
@test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal
620+
end
597621
end
598622

599623
@testset "VarInfo with selectors" begin

0 commit comments

Comments
 (0)