diff --git a/Project.toml b/Project.toml index 863ae9e8b..4b8de9192 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.1" +version = "0.35.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/varinfo.jl b/src/varinfo.jl index f70582428..4151625b6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -215,7 +215,11 @@ function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases # where e.g. x is a type gradient of some AD backend. - return VarInfo(md, Base.RefValue{eltype(x)}(getlogp(vi)), Ref(get_num_produce(vi))) + return VarInfo( + md, + Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), + Ref(get_num_produce(vi)), + ) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in diff --git a/test/varinfo.jl b/test/varinfo.jl index 80eb05480..74feb42f6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1017,4 +1017,15 @@ end @test vi.metadata.b.orders == [2] @test DynamicPPL.get_num_produce(vi) == 3 end + + @testset "issue #842" begin + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + varinfo = VarInfo(model) + + n = length(varinfo[:]) + # `Bool`. + @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + # `Int`. + @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + end end