From 5846a4632d50eceab8c3619591c4b6a0a4c0c38e Mon Sep 17 00:00:00 2001 From: torfjelde Date: Thu, 13 Mar 2025 10:21:04 +0000 Subject: [PATCH 1/3] fix for #842 --- src/varinfo.jl | 2 +- test/varinfo.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f70582428..93d758c95 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -215,7 +215,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases # where e.g. x is a type gradient of some AD backend. - return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) + return VarInfo(md, Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), Ref(get_num_produce(vi))) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in diff --git a/test/varinfo.jl b/test/varinfo.jl index 80eb05480..74feb42f6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1017,4 +1017,15 @@ end @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end + + @testset "issue #842" begin + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + varinfo = VarInfo(model) + + n = length(varinfo[:]) + # `Bool`. + @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + # `Int`. + @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + end end From b05047b8ec7fb4b81d0375a4ee9256e6b7da05b9 Mon Sep 17 00:00:00 2001 From: torfjelde Date: Thu, 13 Mar 2025 10:21:49 +0000 Subject: [PATCH 2/3] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 863ae9e8b..4b8de9192 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.1" +version = "0.35.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 3d3cafececbed37b22c847432c5a764b03a620f4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 13 Mar 2025 10:28:45 +0000 Subject: [PATCH 3/3] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 93d758c95..4151625b6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -215,7 +215,11 @@ function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases # where e.g. x is a type gradient of some AD backend. - return VarInfo(md, Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), Ref(get_num_produce(vi))) + return VarInfo( + md, + Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), + Ref(get_num_produce(vi)), + ) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in