Skip to content

Commit 559e622

Browse files
committed
Add missing tovec methods for NamedTuple and Tuple.
1 parent a8a7026 commit 559e622

File tree

5 files changed

+22
-1
lines changed

5 files changed

+22
-1
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.9
4+
5+
Fixed a failure when sampling from `ProductNamedTupleDistribution` due to
6+
missing `tovec` methods for `NamedTuple` and `Tuple`.
7+
38
## 0.36.8
49

510
Made `LogDensityFunction` a subtype of `AbstractMCMC.AbstractModel`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.8"
3+
version = "0.36.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ to_linked_vec_transform(x) = inverse(from_linked_vec_transform(x))
443443
# or fix `tovec` to flatten the full matrix instead of using `Bijectors.triu_to_vec`.
444444
tovec(x::Real) = [x]
445445
tovec(x::AbstractArray) = vec(x)
446+
tovec(t::Tuple) = mapreduce(tovec, vcat, t)
447+
tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt))
446448
tovec(C::Cholesky) = tovec(Matrix(C.UL))
447449

448450
"""

test/model.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,4 +617,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
617617
]) 1.0 * xs_test[2] rtol = 0.1
618618
end
619619
end
620+
621+
@testset "ProductNamedTupleDistribution sampling" begin
622+
priors = (a = Normal(), b = Normal())
623+
d = product_distribution(priors)
624+
@model function sample_nt(priors_dist)
625+
x ~ priors_dist
626+
return x
627+
end
628+
model = sample_nt(d)
629+
@test model() isa NamedTuple
630+
end
620631
end

test/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
dist = LKJCholesky(2, 1)
4848
x = rand(dist)
4949
@test DynamicPPL.tovec(x) == vec(x.UL)
50+
51+
nt = (a = [1, 2], b = 3.0)
52+
@test DynamicPPL.tovec(nt) == [1, 2, 3.0]
5053
end
5154

5255
@testset "unique_syms" begin

0 commit comments

Comments
 (0)