From 759e61e53a79455b12b57c0d10b7713a8f2fc17f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Jan 2025 10:36:44 +0000 Subject: [PATCH 1/8] Add test merging VarInfos with different dimensions for a variable --- test/varinfo.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 9a55cffb9..3f3c83ea9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -869,6 +869,17 @@ end @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right end + + # The below used to error, testing to avoid regression. + @testset "merge different dimensions" begin + vn = @varname(x) + vi_single = VarInfo() + vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_double = VarInfo() + vi_double = push!!(vi_double, vn, [2.0, 3.0], Normal()) + @test merge(vi_single, vi_double)[vn] == 1.0 + @test merge(vi_double, vi_single)[vn] == [2.0, 3.0] + end end @testset "VarInfo with selectors" begin From 87f0159ba10fa9cb15839a8efc824a1f338b17b4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Jan 2025 10:46:19 +0000 Subject: [PATCH 2/8] Fix merge_metadata for differing dimensions --- src/varinfo.jl | 79 +++++++++----------------------------------------- 1 file changed, 14 insertions(+), 65 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3ebb505e0..3f36cc293 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -521,73 +521,22 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = 0 for (idx, vn) in enumerate(vns_both) - # `idcs` idcs[vn] = idx - # `vns` push!(vns, vn) - if vn in vns_left && vn in vns_right - # `vals`: only valid if they're the length. - vals_left = getindex_internal(metadata_left, vn) - vals_right = getindex_internal(metadata_right, vn) - @assert length(vals_left) == length(vals_right) - append!(vals, vals_right) - # `ranges` - r = (offset + 1):(offset + length(vals_left)) - push!(ranges, r) - offset = r[end] - # `dists`: only valid if they're the same. - dist_right = getdist(metadata_right, vn) - # Give precedence to `metadata_right`. - push!(dists, dist_right) - gid = metadata_right.gids[getidx(metadata_right, vn)] - push!(gids, gid) - # `orders`: giving precedence to `metadata_right` - push!(orders, getorder(metadata_right, vn)) - # `flags` - for k in keys(flags) - # Using `metadata_right`; should we? - push!(flags[k], is_flagged(metadata_right, vn, k)) - end - elseif vn in vns_left - # Just extract the metadata from `metadata_left`. - # `vals` - vals_left = getindex_internal(metadata_left, vn) - append!(vals, vals_left) - # `ranges` - r = (offset + 1):(offset + length(vals_left)) - push!(ranges, r) - offset = r[end] - # `dists` - dist_left = getdist(metadata_left, vn) - push!(dists, dist_left) - gid = metadata_left.gids[getidx(metadata_left, vn)] - push!(gids, gid) - # `orders` - push!(orders, getorder(metadata_left, vn)) - # `flags` - for k in keys(flags) - push!(flags[k], is_flagged(metadata_left, vn, k)) - end - else - # Just extract the metadata from `metadata_right`. - # `vals` - vals_right = getindex_internal(metadata_right, vn) - append!(vals, vals_right) - # `ranges` - r = (offset + 1):(offset + length(vals_right)) - push!(ranges, r) - offset = r[end] - # `dists` - dist_right = getdist(metadata_right, vn) - push!(dists, dist_right) - gid = metadata_right.gids[getidx(metadata_right, vn)] - push!(gids, gid) - # `orders` - push!(orders, getorder(metadata_right, vn)) - # `flags` - for k in keys(flags) - push!(flags[k], is_flagged(metadata_right, vn, k)) - end + metadata_for_vn = vn in vns_right ? metadata_right : metadata_left + + val = getindex_internal(metadata_for_vn, vn) + append!(vals, val) + r = (offset + 1):(offset + length(val)) + push!(ranges, r) + offset = r[end] + dist = getdist(metadata_for_vn, vn) + push!(dists, dist) + gid = metadata_for_vn.gids[getidx(metadata_for_vn, vn)] + push!(gids, gid) + push!(orders, getorder(metadata_for_vn, vn)) + for k in keys(flags) + push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end From 0909d57d9b6cbdec12009ea08c39897d5d01630d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Jan 2025 10:47:16 +0000 Subject: [PATCH 3/8] Bump patch version to 0.34.1. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fb9a1c55f..ec0c719b9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.34.0" +version = "0.34.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 8806597f79b4f91b5e7d4a16fc7a5b63ba346170 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Jan 2025 10:49:59 +0000 Subject: [PATCH 4/8] Fix test --- test/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 3f3c83ea9..ec177669e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -877,8 +877,8 @@ end vi_single = push!!(vi_single, vn, 1.0, Normal()) vi_double = VarInfo() vi_double = push!!(vi_double, vn, [2.0, 3.0], Normal()) - @test merge(vi_single, vi_double)[vn] == 1.0 - @test merge(vi_double, vi_single)[vn] == [2.0, 3.0] + @test merge(vi_single, vi_double)[vn] == [2.0, 3.0] + @test merge(vi_double, vi_single)[vn] == 1.0 end end From 309dfc54a9f929ea2c2ae33433eb6171e1b4a1c8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 21 Jan 2025 10:53:54 +0000 Subject: [PATCH 5/8] Fix test more --- test/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index ec177669e..fce87b2f3 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -876,8 +876,8 @@ end vi_single = VarInfo() vi_single = push!!(vi_single, vn, 1.0, Normal()) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [2.0, 3.0], Normal()) - @test merge(vi_single, vi_double)[vn] == [2.0, 3.0] + vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end end From 190ac890f822481d50cf44f9f05aa15eb1fe6843 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 10:21:11 +0000 Subject: [PATCH 6/8] Pin KernelAbstractions to v0.9.31 --- Project.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index ec0c719b9..c6ce52c54 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,9 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see +# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -55,6 +58,9 @@ Compat = "4" ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" +# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 +# for why KernelAbstractions is pinned like this. +KernelAbstractions = "=0.9.31" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" JET = "0.9" From 229afd0164cd39d7f8de44f04d5442ae7414a1fd Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 22 Jan 2025 11:20:27 +0000 Subject: [PATCH 7/8] Make KernelAbstractions version bound an upper bound Co-authored-by: Penelope Yong --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c6ce52c54..9c1a43123 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ Distributions = "0.25" DocStringExtensions = "0.9" # TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 # for why KernelAbstractions is pinned like this. -KernelAbstractions = "=0.9.31" +KernelAbstractions = "<= 0.9.31" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" JET = "0.9" From 0547f1bbe80e48f337f3377d1ca634813a051c58 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 22 Jan 2025 11:24:59 +0000 Subject: [PATCH 8/8] Fix syntax --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9c1a43123..bd553c0cc 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ Distributions = "0.25" DocStringExtensions = "0.9" # TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767 # for why KernelAbstractions is pinned like this. -KernelAbstractions = "<= 0.9.31" +KernelAbstractions = "< 0.9.32" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" JET = "0.9"