|
| 1 | +# model_interface.jl |
| 2 | +# ------------------ |
| 3 | +# |
| 4 | +# This file contains the functions that the models inside models.jl should |
| 5 | +# implement. |
| 6 | + |
| 7 | +""" |
| 8 | + logprior_true(model, args...) |
| 9 | +
|
| 10 | +Return the `logprior` of `model` for `args`. |
| 11 | +
|
| 12 | +This should generally be implemented by hand for every specific `model`. |
| 13 | +
|
| 14 | +See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). |
| 15 | +""" |
| 16 | +function logprior_true end |
| 17 | + |
| 18 | +""" |
| 19 | + loglikelihood_true(model, args...) |
| 20 | +
|
| 21 | +Return the `loglikelihood` of `model` for `args`. |
| 22 | +
|
| 23 | +This should generally be implemented by hand for every specific `model`. |
| 24 | +
|
| 25 | +See also: [`logjoint_true`](@ref), [`logprior_true`](@ref). |
| 26 | +""" |
| 27 | +function loglikelihood_true end |
| 28 | + |
| 29 | +""" |
| 30 | + logjoint_true(model, args...) |
| 31 | +
|
| 32 | +Return the `logjoint` of `model` for `args`. |
| 33 | +
|
| 34 | +Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. |
| 35 | +
|
| 36 | +This should generally be implemented by hand for every specific `model` |
| 37 | +so that the returned value can be used as a ground-truth for testing things like: |
| 38 | +
|
| 39 | +1. Validity of evaluation of `model` using a particular implementation of `AbstractVarInfo`. |
| 40 | +2. Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions, e.g. `logjoint_true`, and once targeting `model`. |
| 41 | +
|
| 42 | +And more. |
| 43 | +
|
| 44 | +See also: [`logprior_true`](@ref), [`loglikelihood_true`](@ref). |
| 45 | +""" |
| 46 | +function logjoint_true(model::Model, args...) |
| 47 | + return logprior_true(model, args...) + loglikelihood_true(model, args...) |
| 48 | +end |
| 49 | + |
| 50 | +""" |
| 51 | + logjoint_true_with_logabsdet_jacobian(model::Model, args...) |
| 52 | +
|
| 53 | +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`. |
| 54 | +
|
| 55 | +Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the |
| 56 | +log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. |
| 57 | +
|
| 58 | +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` |
| 59 | +will be unconstrained. |
| 60 | +
|
| 61 | +This should generally not be implemented directly, instead one should implement |
| 62 | +[`logprior_true_with_logabsdet_jacobian`](@ref) for a given `model`. |
| 63 | +
|
| 64 | +See also: [`logjoint_true`](@ref), [`logprior_true_with_logabsdet_jacobian`](@ref). |
| 65 | +""" |
| 66 | +function logjoint_true_with_logabsdet_jacobian(model::Model, args...) |
| 67 | + args_unconstrained, lp = logprior_true_with_logabsdet_jacobian(model, args...) |
| 68 | + return args_unconstrained, lp + loglikelihood_true(model, args...) |
| 69 | +end |
| 70 | + |
| 71 | +""" |
| 72 | + logprior_true_with_logabsdet_jacobian(model::Model, args...) |
| 73 | +
|
| 74 | +Return a tuple `(args_unconstrained, logprior_unconstrained)` of `model` for `args...`. |
| 75 | +
|
| 76 | +Unlike [`logprior_true`](@ref), the returned logprior computation includes the |
| 77 | +log-absdet-jacobian adjustment, thus computing logprior for the unconstrained variables. |
| 78 | +
|
| 79 | +Note that `args` are assumed be in the support of `model`, while `args_unconstrained` |
| 80 | +will be unconstrained. |
| 81 | +
|
| 82 | +See also: [`logprior_true`](@ref). |
| 83 | +""" |
| 84 | +function logprior_true_with_logabsdet_jacobian end |
| 85 | + |
| 86 | +""" |
| 87 | + varnames(model::Model) |
| 88 | +
|
| 89 | +Return a collection of `VarName` as they are expected to appear in the model. |
| 90 | +
|
| 91 | +Even though it is recommended to implement this by hand for a particular `Model`, |
| 92 | +a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. |
| 93 | +""" |
| 94 | +function varnames(model::Model) |
| 95 | + return collect( |
| 96 | + keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) |
| 97 | + ) |
| 98 | +end |
| 99 | + |
| 100 | +""" |
| 101 | + posterior_mean(model::Model) |
| 102 | +
|
| 103 | +Return a `NamedTuple` compatible with `varnames(model)` where the values represent |
| 104 | +the posterior mean under `model`. |
| 105 | +
|
| 106 | +"Compatible" means that a `varname` from `varnames(model)` can be used to extract the |
| 107 | +corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`. |
| 108 | +""" |
| 109 | +function posterior_mean end |
| 110 | + |
| 111 | +""" |
| 112 | + rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model) |
| 113 | +
|
| 114 | +Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`. |
| 115 | +""" |
| 116 | +function rand_prior_true(model::DynamicPPL.Model) |
| 117 | + return rand_prior_true(Random.default_rng(), model) |
| 118 | +end |
0 commit comments