Skip to content

Commit b711362

Browse files
committed
Simplify tests with make_chain_from_prior
1 parent caf7a7b commit b711362

File tree

1 file changed

+1
-43
lines changed

1 file changed

+1
-43
lines changed

test/model.jl

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -56,49 +56,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
5656

5757
#### logprior, logjoint, loglikelihood for MCMC chains ####
5858
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
59-
var_info = VarInfo(model)
60-
vns = DynamicPPL.TestUtils.varnames(model)
61-
syms = unique(DynamicPPL.getsym.(vns))
62-
63-
# generate a chain of sample parameter values.
64-
N = 200
65-
vals_OrderedDict = mapreduce(hcat, 1:N) do _
66-
rand(OrderedDict, model)
67-
end
68-
vals_mat = mapreduce(hcat, 1:N) do i
69-
[vals_OrderedDict[i][vn] for vn in vns]
70-
end
71-
i = 1
72-
for col in eachcol(vals_mat)
73-
col_flattened = []
74-
[push!(col_flattened, x...) for x in col]
75-
if i == 1
76-
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
77-
else
78-
chain_mat = vcat(
79-
chain_mat, reshape(col_flattened, 1, length(col_flattened))
80-
)
81-
end
82-
i += 1
83-
end
84-
chain_mat = convert(Matrix{Float64}, chain_mat)
85-
86-
# devise parameter names for chain
87-
sample_values_vec = collect(values(vals_OrderedDict[1]))
88-
symbol_names = []
89-
chain_sym_map = Dict()
90-
for k in 1:length(keys(var_info))
91-
vn_parent = keys(var_info)[k]
92-
sym = DynamicPPL.getsym(vn_parent)
93-
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
94-
for vn_child in vn_children
95-
chain_sym_map[Symbol(vn_child)] = sym
96-
symbol_names = [symbol_names; Symbol(vn_child)]
97-
end
98-
end
99-
chain = Chains(chain_mat, symbol_names)
100-
101-
# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
59+
chain = make_chain_from_prior(chain, 200)
10260
logpriors = logprior(model, chain)
10361
loglikelihoods = loglikelihood(model, chain)
10462
logjoints = logjoint(model, chain)

0 commit comments

Comments
 (0)