diff --git a/HISTORY.md b/HISTORY.md index 40a671dc1..b8c9d30e0 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.36.9 + +Fixed a failure when sampling from `ProductNamedTupleDistribution` due to +missing `tovec` methods for `NamedTuple` and `Tuple`. + ## 0.36.8 Made `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel`. diff --git a/Project.toml b/Project.toml index 2fc1d984c..fd5d20c9b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.8" +version = "0.36.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index 71919480c..d828fd771 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -443,6 +443,8 @@ to_linked_vec_transform(x) = inverse(from_linked_vec_transform(x)) # or fix `tovec` to flatten the full matrix instead of using `Bijectors.triu_to_vec`. tovec(x::Real) = [x] tovec(x::AbstractArray) = vec(x) +tovec(t::Tuple) = mapreduce(tovec, vcat, t) +tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt)) tovec(C::Cholesky) = tovec(Matrix(C.UL)) """ diff --git a/test/model.jl b/test/model.jl index dd5a35fe6..829ddd302 100644 --- a/test/model.jl +++ b/test/model.jl @@ -617,4 +617,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() ]) ≈ 1.0 * xs_test[2] rtol = 0.1 end end + + @testset "ProductNamedTupleDistribution sampling" begin + priors = (a=Normal(), b=Normal()) + d = product_distribution(priors) + @model function sample_nt(priors_dist) + x ~ priors_dist + return x + end + model = sample_nt(d) + @test model() isa NamedTuple + end end diff --git a/test/utils.jl b/test/utils.jl index d683f132d..7a7338fa7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -47,6 +47,12 @@ dist = LKJCholesky(2, 1) x = rand(dist) @test DynamicPPL.tovec(x) == vec(x.UL) + + nt = (a=[1, 2], b=3.0) + @test DynamicPPL.tovec(nt) == [1, 2, 3.0] + + t = (2.0, [3.0, 4.0]) + @test DynamicPPL.tovec(t) == [2.0, 3.0, 4.0] end @testset "unique_syms" begin