Skip to content

Commit f6f02ac

Browse files
YongchaoHuanggithub-actions[bot]devmotionyebaiCompatHelper Julia
authored
log probability interface for post-inference analysis (#438)
* extended methods for `logprior`, `loglikelihood`, `logposterior` for chains. * accept Github Actions. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * typed `AbstractChains`; removed Array inputs. * re-formatting. * removed comments to pass formatting test. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * 1. removed the import statements in `lop.jl`; 2. removed the DynamicPPL. namespace declarations of the functions; 3. used Distributions.loglikelihood (instead of `StatsBase.loglikelihood`); 4. moved the tests to test/logp.jl and included them in test/runtests.jl. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * modified src/logp.jl; added test/logp.jl * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Modified Docstrings; Changed names; Modified methods following Tor's suggestions. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed `chain_logprior`,`chain_loglikelihood',`chain_logposterior' to `logprior`,`loglikelihood',`logposterior' . * added `include("logdensityfunction.jl")` to `DynamicPPL.jl` * formatted `test/logp.jl`. * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatted `scr/logp.jl`. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatted `scr/logp.jl`. * Removed comments. * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update src/logp.jl Co-authored-by: David Widmann <[email protected]> * Update test/logp.jl Co-authored-by: David Widmann <[email protected]> * removed redundant methods (NamedTuples and Array inputs). * added REPL examples to docstrings. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added `start_idx` into `src/logp.jl`; rewrite `test/logp.jl` using `map-do`; added `MCMCChains` & `StableRNGs` to `DynamicPPL.jl`. * added `start_idx` to the 3 methods. * Reduced chainn size in the docstrings example. * applied formatting. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * upated signatures in docstrings. * applied formatting. * formatted again. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix doctests setup * Update docs/make.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/runtests.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * applied formatting. * Formatting. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix doc tests again. * Fixed formatting. * Merged `logp.jl` into `model.jl` * CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) (#448) * CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) * Update test/turing/Project.toml Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> * CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat) (#439) * CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat) * Update test/turing/Project.toml Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Hong Ge <[email protected]> * Fixed obsolete `TArray` reference. * Fixed incorrect code. * More bugfixes in logp tests. * Avoid calling Turing sampler. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Replace SampleFromPrior with synthetic chain. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Minor bugfix. * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update Project.toml * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update test/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Update src/model.jl Co-authored-by: David Widmann <[email protected]> * Added `logprior_true(model,NamedTuple)' and `loglikelihood_true(model, NamedTuple)' methods; revised test/model.jl accordingly (removed `MCMCChains.get_param()' ). * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fixed missing prefix and imports. * Move tests into convenience functions. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Removed constraints on floating number precision. * Fix type constraint again. * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * 1. removed `StableRNGs`; 2. replaced `map(1:N) do i` in test/model.jl by `for i in 1:N`. * Bugfix. * Import TestUtils -- it is not exported by DPPL. * Specialise on model type. * Improve test. * Update src/test_utils.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * midified the way chain value was extracted in all 3 methods. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Update src/model.jl * Update src/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl * rewrote the tests (mainly the way extracting parameter values from chain). * removed BangBang from doctest setup; fixed imcomplete end in test/model.jl. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed a naming bug (argvals_mat_dict) in src/model.jl. * fixed a typo - missing `var_info`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Explicitly added `using Distributions` in doctests; Accepted suggestion in test/model.jl, tests passed locally. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * rm unnecessary deps * replace contains with subsumes. * rm redundant deps in docs build script. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix format. * Replaced `subsumes` by `contains`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * replaced 'contains' by a new, temporary method 'subsumes_sym', just for testing purpose. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * modified `/test/model.jl`: 1. build a map between model parameter symbol (`s`,`m`) and chain parameter names (which is obtained via `varname_leaves`.) 2. use this naming map to collect sample values from chain and drop into `log_true` for validation. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fixed a mistake in `modify_value_representation`. * fixed `gdemo_default` * assigned `model=gdemo_default`. * src/model.jl: added `DynamicPPL.` to `logprior` and `logjoint`. * commented out `gdemo_d()` as a trial test. * used `Symbol(vn_child)` as keys in `chain_sym_map`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * explicitly loaded `varname_leaves` and `values_from_chain`. * added `print` statements for temporary diagnosis purpose. * added 'print` statements for temporary diagnostics purpose. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * diagnostics again. * Removed some `print` statements as it's working. * Update test/model.jl Co-authored-by: Hong Ge <[email protected]> * Update test/model.jl Co-authored-by: Hong Ge <[email protected]> * Update test/model.jl Co-authored-by: Hong Ge <[email protected]> * Update test/model.jl Co-authored-by: Hong Ge <[email protected]> * 1. moved helper functions to `test_util.jl`; 2. re-wrote the way `chain_mat` can be generated. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting. * formatting. * Update utils.jl * Update test_util.jl * Update Project.toml * replaced 'varname_leaves' by 'DynamicPPL.varname_leaves'. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Jose Storopoli <[email protected]>
1 parent ba206f4 commit f6f02ac

File tree

6 files changed

+199
-5
lines changed

6 files changed

+199
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.7"
3+
version = "0.23.8"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
44
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
55
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
66
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
7+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
78
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
89
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
910
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -17,3 +18,4 @@ LogDensityProblems = "2"
1718
MLUtils = "0.3, 0.4"
1819
Setfield = "0.7.1, 0.8, 1"
1920
StableRNGs = "1"
21+
MCMCChains = "5"

docs/make.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ using DynamicPPL: AbstractPPL
1010
using Distributions
1111

1212
# Doctest setup
13-
DocMeta.setdocmeta!(
14-
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
15-
)
13+
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
1614

1715
makedocs(;
1816
sitename="DynamicPPL",

src/model.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,42 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
10591059
return getlogp(last(evaluate!!(model, varinfo, DefaultContext())))
10601060
end
10611061

1062+
"""
1063+
logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
1064+
1065+
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
1066+
1067+
# Examples
1068+
1069+
```jldoctest
1070+
julia> using MCMCChains, Distributions
1071+
1072+
julia> @model function demo_model(x)
1073+
s ~ InverseGamma(2, 3)
1074+
m ~ Normal(0, sqrt(s))
1075+
for i in eachindex(x)
1076+
x[i] ~ Normal(m, sqrt(s))
1077+
end
1078+
end;
1079+
1080+
julia> # construct a chain of samples using MCMCChains
1081+
chain = Chains(rand(10, 2, 3), [:s, :m]);
1082+
1083+
julia> logjoint(demo_model([1., 2.]), chain);
1084+
```
1085+
"""
1086+
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
1087+
var_info = VarInfo(model) # extract variables info from the model
1088+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1089+
argvals_dict = OrderedDict(
1090+
vn_parent =>
1091+
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
1092+
vn_parent in keys(var_info)
1093+
)
1094+
DynamicPPL.logjoint(model, argvals_dict)
1095+
end
1096+
end
1097+
10621098
"""
10631099
logprior(model::Model, varinfo::AbstractVarInfo)
10641100
@@ -1070,6 +1106,42 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
10701106
return getlogp(last(evaluate!!(model, varinfo, PriorContext())))
10711107
end
10721108

1109+
"""
1110+
logprior(model::Model, chain::AbstractMCMC.AbstractChains)
1111+
1112+
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
1113+
1114+
# Examples
1115+
1116+
```jldoctest
1117+
julia> using MCMCChains, Distributions
1118+
1119+
julia> @model function demo_model(x)
1120+
s ~ InverseGamma(2, 3)
1121+
m ~ Normal(0, sqrt(s))
1122+
for i in eachindex(x)
1123+
x[i] ~ Normal(m, sqrt(s))
1124+
end
1125+
end;
1126+
1127+
julia> # construct a chain of samples using MCMCChains
1128+
chain = Chains(rand(10, 2, 3), [:s, :m]);
1129+
1130+
julia> logprior(demo_model([1., 2.]), chain);
1131+
```
1132+
"""
1133+
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
1134+
var_info = VarInfo(model) # extract variables info from the model
1135+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1136+
argvals_dict = OrderedDict(
1137+
vn_parent =>
1138+
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
1139+
vn_parent in keys(var_info)
1140+
)
1141+
DynamicPPL.logprior(model, argvals_dict)
1142+
end
1143+
end
1144+
10731145
"""
10741146
loglikelihood(model::Model, varinfo::AbstractVarInfo)
10751147
@@ -1081,6 +1153,42 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
10811153
return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext())))
10821154
end
10831155

1156+
"""
1157+
loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
1158+
1159+
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
1160+
1161+
# Examples
1162+
1163+
```jldoctest
1164+
julia> using MCMCChains, Distributions
1165+
1166+
julia> @model function demo_model(x)
1167+
s ~ InverseGamma(2, 3)
1168+
m ~ Normal(0, sqrt(s))
1169+
for i in eachindex(x)
1170+
x[i] ~ Normal(m, sqrt(s))
1171+
end
1172+
end;
1173+
1174+
julia> # construct a chain of samples using MCMCChains
1175+
chain = Chains(rand(10, 2, 3), [:s, :m]);
1176+
1177+
julia> loglikelihood(demo_model([1., 2.]), chain);
1178+
```
1179+
"""
1180+
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
1181+
var_info = VarInfo(model) # extract variables info from the model
1182+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1183+
argvals_dict = OrderedDict(
1184+
vn_parent =>
1185+
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
1186+
vn_parent in keys(var_info)
1187+
)
1188+
loglikelihood(model, argvals_dict)
1189+
end
1190+
end
1191+
10841192
"""
10851193
generated_quantities(model::Model, chain::AbstractChains)
10861194

test/model.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727

2828
@testset "model.jl" begin
2929
@testset "convenience functions" begin
30-
model = gdemo_default
30+
model = gdemo_default # defined in test/test_util.jl
3131

3232
# sample from model and extract variables
3333
vi = VarInfo(model)
@@ -49,6 +49,77 @@ end
4949
ljoint = logjoint(model, vi)
5050
@test ljoint lprior + llikelihood
5151
@test ljoint lp
52+
53+
#### logprior, logjoint, loglikelihood for MCMC chains ####
54+
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
55+
var_info = VarInfo(model)
56+
vns = DynamicPPL.TestUtils.varnames(model)
57+
syms = unique(DynamicPPL.getsym.(vns))
58+
59+
# generate a chain of sample parameter values.
60+
N = 200
61+
vals_OrderedDict = mapreduce(hcat, 1:N) do _
62+
rand(OrderedDict, model)
63+
end
64+
vals_mat = mapreduce(hcat, 1:N) do i
65+
[vals_OrderedDict[i][vn] for vn in vns]
66+
end
67+
i = 1
68+
for col in eachcol(vals_mat)
69+
col_flattened = []
70+
[push!(col_flattened, x...) for x in col]
71+
if i == 1
72+
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
73+
else
74+
chain_mat = vcat(
75+
chain_mat, reshape(col_flattened, 1, length(col_flattened))
76+
)
77+
end
78+
i += 1
79+
end
80+
chain_mat = convert(Matrix{Float64}, chain_mat)
81+
82+
# devise parameter names for chain
83+
sample_values_vec = collect(values(vals_OrderedDict[1]))
84+
symbol_names = []
85+
chain_sym_map = Dict()
86+
for k in 1:length(keys(var_info))
87+
vn_parent = keys(var_info)[k]
88+
sym = DynamicPPL.getsym(vn_parent)
89+
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
90+
for vn_child in vn_children
91+
chain_sym_map[Symbol(vn_child)] = sym
92+
symbol_names = [symbol_names; Symbol(vn_child)]
93+
end
94+
end
95+
chain = Chains(chain_mat, symbol_names)
96+
97+
# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
98+
logpriors = logprior(model, chain)
99+
loglikelihoods = loglikelihood(model, chain)
100+
logjoints = logjoint(model, chain)
101+
# compare them with true values
102+
for i in 1:N
103+
samples_dict = Dict()
104+
for chain_key in keys(chain)
105+
value = chain[i, chain_key, 1]
106+
key = chain_sym_map[chain_key]
107+
existing_value = get(samples_dict, key, Float64[])
108+
push!(existing_value, value)
109+
samples_dict[key] = existing_value
110+
end
111+
samples = (; samples_dict...)
112+
samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
113+
@test logpriors[i]
114+
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
115+
@test loglikelihoods[i] DynamicPPL.TestUtils.loglikelihood_true(
116+
model, samples[:s], samples[:m]
117+
)
118+
@test logjoints[i]
119+
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
120+
end
121+
println("\n model $(model) passed !!! \n")
122+
end
52123
end
53124

54125
@testset "rng" begin

test/test_util.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,18 @@ short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
8282
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
8383
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
8484
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
85+
86+
# convenient functions for testing model.jl
87+
# function to modify the representation of values based on their length
88+
function modify_value_representation(nt::NamedTuple)
89+
modified_nt = NamedTuple()
90+
for (key, value) in zip(keys(nt), values(nt))
91+
if length(value) == 1 # Scalar value
92+
modified_value = value[1]
93+
else # Non-scalar value
94+
modified_value = value
95+
end
96+
modified_nt = merge(modified_nt, (key => modified_value,))
97+
end
98+
return modified_nt
99+
end

0 commit comments

Comments
 (0)