@@ -55,50 +55,8 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
55
55
@test ljoint ≈ lp
56
56
57
57
# ### logprior, logjoint, loglikelihood for MCMC chains ####
58
- for model in DynamicPPL. TestUtils. DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59
- var_info = VarInfo (model)
60
- vns = DynamicPPL. TestUtils. varnames (model)
61
- syms = unique (DynamicPPL. getsym .(vns))
62
-
63
- # generate a chain of sample parameter values.
64
- N = 200
65
- vals_OrderedDict = mapreduce (hcat, 1 : N) do _
66
- rand (OrderedDict, model)
67
- end
68
- vals_mat = mapreduce (hcat, 1 : N) do i
69
- [vals_OrderedDict[i][vn] for vn in vns]
70
- end
71
- i = 1
72
- for col in eachcol (vals_mat)
73
- col_flattened = []
74
- [push! (col_flattened, x... ) for x in col]
75
- if i == 1
76
- chain_mat = Matrix (reshape (col_flattened, 1 , length (col_flattened)))
77
- else
78
- chain_mat = vcat (
79
- chain_mat, reshape (col_flattened, 1 , length (col_flattened))
80
- )
81
- end
82
- i += 1
83
- end
84
- chain_mat = convert (Matrix{Float64}, chain_mat)
85
-
86
- # devise parameter names for chain
87
- sample_values_vec = collect (values (vals_OrderedDict[1 ]))
88
- symbol_names = []
89
- chain_sym_map = Dict ()
90
- for k in 1 : length (keys (var_info))
91
- vn_parent = keys (var_info)[k]
92
- sym = DynamicPPL. getsym (vn_parent)
93
- vn_children = DynamicPPL. varname_leaves (vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
94
- for vn_child in vn_children
95
- chain_sym_map[Symbol (vn_child)] = sym
96
- symbol_names = [symbol_names; Symbol (vn_child)]
97
- end
98
- end
99
- chain = Chains (chain_mat, symbol_names)
100
-
101
- # calculate the pointwise loglikelihoods for the whole chain using the newly written functions
58
+ for model in DynamicPPL. TestUtils. DEMO_MODELS
59
+ chain = make_chain_from_prior (model, 200 )
102
60
logpriors = logprior (model, chain)
103
61
loglikelihoods = loglikelihood (model, chain)
104
62
logjoints = logjoint (model, chain)
@@ -125,6 +83,19 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
125
83
end
126
84
end
127
85
86
+ @testset " DynamicPPL#684: threadsafe evaluation with multiple types" begin
87
+ @model function multiple_types (x)
88
+ ns ~ filldist (Normal (0 , 2.0 ), 3 )
89
+ m ~ Uniform (0 , 1 )
90
+ return x ~ Normal (m, 1 )
91
+ end
92
+ model = multiple_types (1 )
93
+ chain = make_chain_from_prior (model, 10 )
94
+ loglikelihood (model, chain)
95
+ logprior (model, chain)
96
+ logjoint (model, chain)
97
+ end
98
+
128
99
@testset " rng" begin
129
100
model = gdemo_default
130
101
0 commit comments