diff --git a/HISTORY.md b/HISTORY.md index 90db022e7..fa9e58e99 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,74 @@ **Breaking** +### `.~` right hand side must be a univariate distribution + +Previously we allowed statements like + +```julia +x .~ [Normal(), Gamma()] +``` + +where the right hand side of a `.~` was an array of distributions, and ones like + +```julia +x .~ MvNormal(fill(0.0, 2), I) +``` + +where the right hand side was a multivariate distribution. + +These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ Normal() +``` + +The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read. + +If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of + +```julia +x .~ [Normal(), Gamma()] +x .~ Normal.(y) +x .~ MvNormal(fill(0.0, 2), I) +``` + +do + +```julia +x ~ product_distribution([Normal(), Gamma()]) +x ~ product_distribution(Normal.(y)) +x ~ MvNormal(fill(0.0, 2), I) +``` + +This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as + +```julia +dists = Normal.(y) +for i in 1:length(dists) + x[i] ~ dists[i] +end +``` + +Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. 
For example, + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +x .~ MvNormal(fill(0, 2), I) +``` + +should be replaced with something like + +```julia +x = Array{Float64,3}(undef, 2, 3, 4) +for i in 1:3, j in 1:4 + x[:, i, j] ~ MvNormal(fill(0, 2), I) +end +``` + +This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side. + ### Remove indexing by samplers This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular, @@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia - `unflatten` no longer accepts a sampler as an argument - `eltype(::VarInfo)` no longer accepts a sampler as an argument - `keys(::VarInfo)` no longer accepts a sampler as an argument - - `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument. + - `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument. 
- `push!!` and `push!` no longer accept samplers or `Selector`s as arguments - `getgid`, `setgid!`, `updategid!`, `getspace`, and `inspace` no longer exist diff --git a/Project.toml b/Project.toml index be4586246..7cd47fdbb 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -33,7 +32,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] @@ -42,7 +40,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] ADTypes = "1" @@ -65,10 +62,9 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.59" +Mooncake = "0.4.95" OrderedCollections = "1" Random = "1.6" Requires = "1" Test = "1.6" -ZygoteRules = "0.2" julia = "1.10" diff --git a/docs/src/api.md b/docs/src/api.md index b9cafaaf4..4d3c6bc97 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -440,10 +440,8 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume -dot_tilde_assume ``` ```@docs tilde_observe -dot_tilde_observe ``` diff --git a/ext/DynamicPPLZygoteRulesExt.jl b/ext/DynamicPPLZygoteRulesExt.jl deleted file mode 100644 index 78831fdc4..000000000 --- a/ext/DynamicPPLZygoteRulesExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module DynamicPPLZygoteRulesExt - -if isdefined(Base, :get_extension) - using 
DynamicPPL: DynamicPPL, Distributions - using ZygoteRules: ZygoteRules -else - using ..DynamicPPL: DynamicPPL, Distributions - using ..ZygoteRules: ZygoteRules -end - -# https://github.com/TuringLang/Turing.jl/issues/1595 -ZygoteRules.@adjoint function DynamicPPL.dot_observe( - spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform}, - dists::AbstractArray{<:Distributions.Distribution}, - value::AbstractArray, - vi, -) - function dot_observe_fallback(spl, dists, value, vi) - DynamicPPL.increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)), vi - end - return ZygoteRules.pullback(dot_observe_fallback, __context__, spl, dists, value, vi) -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8fea43e50..f0d42f98c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -98,13 +98,9 @@ export AbstractVarInfo, PrefixContext, ConditionContext, assume, - dot_assume, observe, - dot_observe, tilde_assume, tilde_observe, - dot_tilde_assume, - dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, diff --git a/src/compiler.jl b/src/compiler.jl index 8743641af..8bde5e784 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -161,7 +161,16 @@ Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` other """ isliteral(e) = false isliteral(::Number) = true -isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) +function isliteral(e::Expr) + # In the special case that the expression is of the form `abc[blahblah]`, we consider it + # to be a literal if `abc` is a literal. This is necessary for cases like + # [1.0, 2.0][idx...] ~ Normal() + # which are generated when turning `.~` expressions into loops over `~` expressions. 
+ if e.head == :ref + return isliteral(e.args[1]) + end + return !isempty(e.args) && all(isliteral, e.args) +end """ check_tilde_rhs(x) @@ -172,7 +181,7 @@ Check if the right-hand side `x` of a `~` is a `Distribution` or an array of function check_tilde_rhs(@nospecialize(x)) return throw( ArgumentError( - "the right-hand side of a `~` must be a `Distribution` or an array of `Distribution`s", + "the right-hand side of a `~` must be a `Distribution`, an array of `Distribution`s, or a submodel", ), ) end @@ -184,6 +193,27 @@ function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} return Sampleable{typeof(model),AutoPrefix}(model) end +""" + check_dot_tilde_rhs(x) + +Check if the right-hand side `x` of a `.~` is a `UnivariateDistribution`, then return `x`. +""" +function check_dot_tilde_rhs(@nospecialize(x)) + return throw( + ArgumentError("the right-hand side of a `.~` must be a `UnivariateDistribution`") + ) +end +function check_dot_tilde_rhs(::AbstractArray{<:Distribution}) + msg = """ + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """ + return throw(ArgumentError(msg)) +end +check_dot_tilde_rhs(x::UnivariateDistribution) = x + """ unwrap_right_vn(right, vn) @@ -356,11 +386,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return Base.remove_linenums!( - generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), - ), + return generate_mainbody!( + mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn ) end @@ -487,56 +514,16 @@ end Generate the expression that replaces `left .~ right` in the model body. 
""" function generate_dot_tilde(left, right) - isliteral(left) && return generate_tilde_literal(left, right) - - # Otherwise it is determined by the model or its value, - # if the LHS represents an observation - @gensym vn isassumption value + @gensym dist left_axes idx return quote - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right - ) - $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $(DynamicPPL.isfixed(left, vn)) - $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn) - elseif $isassumption - $(generate_dot_tilde_assume(left, right, vn)) - else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn) - end - - $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $(maybe_view(left)), - $vn, - __varinfo__, - ) - $value + $dist = DynamicPPL.check_dot_tilde_rhs($right) + $left_axes = axes($left) + for $idx in Iterators.product($left_axes...) + $left[$idx...] ~ $dist end end end -function generate_dot_tilde_assume(left, right, vn) - # We don't need to use `Setfield.@set` here since - # `.=` is always going to be inplace + needs `left` to - # be something that supports `.=`. - @gensym value - return quote - $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - __varinfo__, - ) - $left .= $value - $value - end -end - # Note that we cannot use `MacroTools.isdef` because # of https://github.com/FluxML/MacroTools.jl/issues/154. 
""" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4594902dc..af04d0f57 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -258,384 +258,3 @@ function observe(right::Distribution, left, vi) increment_num_produce!(vi) return Distributions.loglikelihood(right, left), vi end - -# .~ functions - -# assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) -``` -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) - return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, vi - ) -end - -# `DefaultContext` -function dot_tilde_assume(context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) -end -function dot_tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), rng, context, args...) -end - -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) - return dot_assume(right, left, vns, vi) -end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -function dot_tilde_assume(::IsParent, context::AbstractContext, args...) - return dot_tilde_assume(childcontext(context), args...) -end -function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) - return dot_tilde_assume(rng, childcontext(context), args...) 
-end - -function dot_tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, left, vns, vi -) - return dot_assume(rng, sampler, right, vns, left, vi) -end - -# `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) - return dot_assume(nodist(right), left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi -) - return dot_assume(rng, sampler, nodist(right), vn, left, vi) -end - -# `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) -end - -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi -) - return dot_tilde_assume( - rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi - ) -end - -""" - dot_tilde_assume!!(context, right, left, vn, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. 
-""" -function dot_tilde_assume!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`.~` with a model on the right-hand side is not supported; please use `~`" - ), - ) - value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp) -end - -# `dot_assume` -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns, dist] - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) - end - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::AbstractVarInfo, -) - @assert length(dist) == size(var, 1) - r = get_and_set_val!(rng, vi, vns, dist, spl) - lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp, vi -end - -function dot_assume( - dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi -) - r = getindex.((vi,), vns, (dist,)) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - dists::AbstractArray{<:Distribution}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) - r = getindex.((vi,), vns, dists) - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - 
dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::AbstractVarInfo, -) - r = get_and_set_val!(rng, vi, vns, dists, spl) - # Make sure `r` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns))) - return r, lp, vi -end -function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement" - ) -end - -# HACK: These methods are only used in the `get_and_set_val!` methods below. -# FIXME: Remove these. -function _link_broadcast_new(vi, vn, dist, r) - b = to_linked_internal_transform(vi, vn, dist) - return b(r) -end - -function _maybe_invlink_broadcast(vi, vn, dist) - xvec = getindex_internal(vi, vn) - b = from_maybe_linked_internal_transform(vi, vn, dist) - return b(xvec) -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - n = length(vns) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[:, i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - r = vi[vns, dist] - end - else - r = init(rng, dist, spl, n) - for i in 1:n - vn = vns[i] - if istrans(vi) - ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) - push!!(vi, vn, ri_linked, dist) - # `push!!` sets the trans-flag to `false` by default. 
- settrans!!(vi, true, vn) - else - push!!(vi, vn, r[:, i], dist) - end - end - end - return r -end - -function get_and_set_val!( - rng, - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - spl::Union{SampleFromPrior,SampleFromUniform}, -) - if haskey(vi, vns[1]) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if - # that's okay. - unset_flag!(vi, vns[1], "del", true) - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - for i in eachindex(vns) - vn = vns[i] - dist = dists isa AbstractArray ? dists[i] : dists - f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) - setindex!!(vi, f_link_maybe(r[i]), vn) - setorder!(vi, vn, get_num_produce(vi)) - end - else - rs = _maybe_invlink_broadcast.((vi,), vns, dists) - r = reshape(rs, size(vns)) - end - else - f = (vn, dist) -> init(rng, dist, spl) - r = f.(vns, dists) - # TODO: This will inefficient since it will allocate an entire vector. - # We could either: - # 1. Figure out the broadcast size and use a `foreach`. - # 2. Define an anonymous function which returns `nothing`, which - # we then broadcast. This will allocate a vector of `nothing` though. - if istrans(vi) - push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists) - # NOTE: Need to add the correction. - # FIXME: This is not great. - acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) - # `push!!` sets the trans-flag to `false` by default. 
- settrans!!.((vi,), true, vns) - else - push!!.((vi,), vns, r, dists) - end - end - return r -end - -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractVector{<:VarName}, - dist::MultivariateDistribution, - val::AbstractMatrix, -) - @assert size(val, 2) == length(vns) - foreach(enumerate(vns)) do (i, vn) - setindex!!(vi, val[:, i], vn) - end - return val -end -function set_val!( - vi::VarInfoOrThreadSafeVarInfo, - vns::AbstractArray{<:VarName}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - val::AbstractArray, -) - @assert size(val) == size(vns) - foreach(CartesianIndices(val)) do ind - setindex!!(vi, tovec(val[ind]), vns[ind]) - end - return val -end - -# observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function dot_tilde_observe(context::AbstractContext, args...) - return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) -end -dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) -function dot_tilde_observe(::IsParent, context::AbstractContext, args...) - return dot_tilde_observe(childcontext(context), args...) 
-end - -dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi - -# `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = dot_tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end - -# `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vname, vi) - -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable -name and indices; if needed, these can be accessed through this function, though. -""" -function dot_tilde_observe!!(context, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return dot_tilde_observe!!(context, right, left, vi) -end - -""" - dot_tilde_observe!!(context, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value and updated `vi`. - -Falls back to `dot_tilde_observe(context, right, left, vi)`. -""" -function dot_tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) -end - -# Falls back to non-sampler definition. 
-function dot_observe(::AbstractSampler, dist, value, vi) - return dot_observe(dist, value, vi) -end -function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value), vi -end -function dot_observe(dists::Distribution, value::AbstractArray, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value), vi -end -function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) - increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)), vi -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 43b5054d5..328fe6983 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -113,52 +113,6 @@ function Base.show(io::IO, stmt::ObserveStmt) return print(io, ")") end -Base.@kwdef struct DotAssumeStmt <: Stmt - varname - left - right - value - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotAssumeStmt) - io = add_io_context(io) - print(io, " assume: ") - show_varname(io, stmt.varname) - print(io, " = ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " ") - print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - -Base.@kwdef struct DotObserveStmt <: Stmt - left - right - logp - varinfo = nothing -end - -function Base.show(io::IO, stmt::DotObserveStmt) - io = add_io_context(io) - print(io, "observe: ") - print(io, stmt.left) - print(io, " .~ ") - show_right(io, stmt.right) - print(io, " ") - print(io, RESULT_SYMBOL) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") -end - # Some utility methods for extracting information from a trace. """ varnames_in_trace(trace) @@ -168,24 +122,14 @@ Return all the varnames present in the trace. 
varnames_in_trace(trace::AbstractVector) = mapreduce(varnames_in_stmt, vcat, trace) varnames_in_stmt(stmt::AssumeStmt) = [stmt.varname] -function varnames_in_stmt(stmt::DotAssumeStmt) - return stmt.varname isa VarName ? [stmt.varname] : stmt.varname -end varnames_in_stmt(::ObserveStmt) = [] -varnames_in_stmt(::DotObserveStmt) = [] function distributions_in_trace(trace::AbstractVector) return mapreduce(distributions_in_stmt, vcat, trace) end distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotAssumeStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] -function distributions_in_stmt(stmt::DotObserveStmt) - return stmt.right isa AbstractArray ? vec(stmt.right) : [stmt.right] -end """ DebugContext <: AbstractContext @@ -382,95 +326,6 @@ function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, v return logp, vi end -# dot-assume -function record_pre_dot_tilde_assume!(context::DebugContext, vn, left, right, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Variable $(vn) has missing has missing value(s)!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end - - # TODO: Can we do without the memory allocation here? - record_varname!.(broadcast_safe(context), vn, broadcast_safe(right)) - - # Check that `left` does not contain any `` - return nothing -end - -function record_post_dot_tilde_assume!( - context::DebugContext, vns, left, right, value, logp, varinfo -) - stmt = DotAssumeStmt(; - varname=vns, - left=left, - right=right, - value=value, - logp=logp, - varinfo=context.record_varinfo ? 
deepcopy(varinfo) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - - return nothing -end - -function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - childcontext(context), right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi -) - record_pre_dot_tilde_assume!(context, vn, left, right, vi) - value, logp, vi = DynamicPPL.dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - record_post_dot_tilde_assume!(context, vn, left, right, value, logp, vi) - return value, logp, vi -end - -# dot-observe -function record_pre_dot_tilde_observe!(context::DebugContext, left, right, vi) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - # TODO: Once `observe` statements receive `vn`, refer to this in the - # error message. - error( - "Encountered missing value(s) in observe!\n" * - "Usage of `missing` is not supported for dotted syntax, such as " * - "`@. x ~ dist` or `x .~ dist`", - ) - end -end - -function record_post_dot_tilde_observe!(context::DebugContext, left, right, logp, vi) - stmt = DotObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? 
deepcopy(vi) : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing -end -function DynamicPPL.dot_tilde_observe(context::DebugContext, right, left, vi) - record_pre_dot_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.dot_tilde_observe(childcontext(context), right, left, vi) - record_post_dot_tilde_observe!(context, left, right, logp, vi) - return logp, vi -end - _conditioned_varnames(d::AbstractDict) = keys(d) _conditioned_varnames(d) = map(sym -> VarName{sym}(), keys(d)) function conditioned_varnames(context) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index dd5aeeb04..0f312fa2c 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -39,11 +39,6 @@ function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) end -function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi) -end - """ extract_priors([rng::Random.AbstractRNG, ]model::Model) diff --git a/src/model.jl b/src/model.jl index 3601d77fd..0fb18f463 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 8c18163e3..cb9ea4894 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -100,52 +100,6 @@ function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, v return left, acclogp!!(vi, logp) end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to 
child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return dot_tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_observe` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - # Note on submodels (penelopeysm) # # We don't need to overload tilde_observe!! 
for Sampleables (yet), because it @@ -174,44 +128,6 @@ function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) return value, acclogp!!(vi, logp) end -function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) - !_include_prior(context) && - return (dot_tilde_assume!!(context.context, right, left, vns, vi)) - value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) - # Track loglikelihood values. - for (vn, logp) in zip(vns, logps) - push!(context, vn, logp) - end - return value, acclogp!!(vi, sum(logps)) -end - -function _pointwise_tilde_assume(context, right, left, vns, vi) - # We need to drop the `vi` returned. - values_and_logps = broadcast(right, left, vns) do r, l, vn - # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated - # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, - # and b) even if the variables aren't stored in the vi correctly, we're not going to use - # this vi for anything downstream anyways, i.e. I don't see a case where this would matter - # for this particular use case. - val, logp, _ = tilde_assume(context, r, vn, vi) - return val, logp - end - return map(first, values_and_logps), map(last, values_and_logps) -end -function _pointwise_tilde_assume( - context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi -) - # We need to drop the `vi` returned. - values_and_logps = map(eachcol(left), vns) do l, vn - val, logp, _ = tilde_assume(context, right, vn, vi) - return val, logp - end - # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. - # But this also means that we need to first flatten the entire `values` component before recombining. 
- values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) - return values, map(last, values_and_logps) -end - """ pointwise_logdensities(model::Model, chain::Chains, keytype = String) @@ -357,7 +273,7 @@ end """ pointwise_loglikelihoods(model, chain[, keytype, context]) - + Compute the pointwise log-likelihoods of the model given the chain. This is the same as `pointwise_logdensities(model, chain, context)`, but only including the likelihood terms. diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 00d6b3437..173eaa9e1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -471,57 +471,6 @@ function assume( return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Union{Distribution,AbstractArray{<:Distribution}}, - vns::AbstractArray{<:VarName}, - var::AbstractArray, - vi::SimpleOrThreadSafeSimple, -) - f = (vn, dist) -> init(rng, dist, spl) - value = f.(vns, dists) - - # Transform if we're working in transformed space. - value_raw = if dists isa Distribution - to_maybe_linked_internal.((vi,), vns, (dists,), value) - else - to_maybe_linked_internal.((vi,), vns, dists, value) - end - - # Update `vi` - vi = BangBang.setindex!!(vi, value_raw, vns) - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans.(dists, value, istrans.((vi,), vns))) - return value, lp, vi -end - -function dot_assume( - rng, - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - vns::AbstractVector{<:VarName}, - var::AbstractMatrix, - vi::SimpleOrThreadSafeSimple, -) - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - - # r = get_and_set_val!(rng, vi, vns, dist, spl) - n = length(vns) - value = init(rng, dist, spl, n) - - # Update `vi`. 
- for (vn, val) in zip(vns, eachcol(value)) - val_linked = to_maybe_linked_internal(vi, vn, dist, val) - vi = BangBang.setindex!!(vi, val_linked, vn) - end - - # Compute logp. - lp = sum(Bijectors.logpdf_with_trans(dist, value, istrans(vi))) - return value, lp, vi -end - # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 93bb02d3b..5150be64b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -26,22 +26,10 @@ function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, v value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end -function DynamicPPL.dot_tilde_assume( - context::TestLogModifyingChildContext, right, left, vn, vi -) - value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) - return value, logp * context.mod, vi -end function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end -function DynamicPPL.dot_tilde_observe( - context::TestLogModifyingChildContext, right, left, vi -) - logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end # Dummy context to test nested behaviors. 
struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index c506e1ba3..e29614982 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -186,31 +186,29 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end -@model function demo_dot_assume_dot_observe( - x=[1.5, 2.0], ::Type{TV}=Vector{Float64} -) where {TV} +@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) m = TV(undef, length(x)) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe)}, s, m + model::Model{typeof(demo_dot_assume_observe)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] +function varnames(model::Model{typeof(demo_dot_assume_observe)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_index_observe( @@ -276,7 +274,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ 
Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) for i in eachindex(x) x[i] ~ Normal(m[i], sqrt(s[i])) end @@ -295,7 +293,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -355,7 +353,7 @@ end s = TV(undef, 2) m = TV(undef, 2) s .~ InverseGamma(2, 3) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) @@ -376,7 +374,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function demo_assume_observe_literal() @@ -431,7 +429,7 @@ end s = TV(undef, 2) s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) return s, m end @@ -460,7 +458,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end @model function _likelihood_multivariate_observe(s, m, x) @@ -473,7 +471,7 @@ end s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) # Submodel likelihood # With to_submodel, we have to have a left-hand side variable 
to @@ -494,76 +492,39 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_dot_assume_dot_observe_matrix( +@model function demo_dot_assume_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal.(0, sqrt.(s)) + m ~ product_distribution(Normal.(0, sqrt.(s))) - # Dotted observe for `Matrix`. - x .~ MvNormal(m, Diagonal(s)) + x[:, 1] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) - return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end -function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] -end - -@model function demo_dot_assume_matrix_dot_observe_matrix( - x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} -) where {TV} - n = length(x) - d = length(x) ÷ 2 - s = TV(undef, d, 2) - s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - s_vec = vec(s) - m ~ MvNormal(zeros(n), Diagonal(s_vec)) - - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s_vec)) - - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m -) - n = length(model.args.x) - s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + - logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) -end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) - return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - s = zeros(1, 2) # used for varname concretization only - return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] +function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)}) + return [@varname(s[1]), @varname(s[2]), @varname(m)] end -@model function demo_assume_matrix_dot_observe_matrix( +@model function demo_assume_matrix_observe_matrix_index( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) @@ -572,33 +533,32 @@ end s_vec = vec(s) m ~ MvNormal(zeros(n), Diagonal(s_vec)) - # Dotted observe for `Matrix`. 
- x .~ MvNormal(m, Diagonal(s_vec)) + x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) s_vec = vec(s) return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}, s, m + model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}) return [@varname(s), @varname(m)] end const DemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -609,9 +569,8 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, - Model{typeof(demo_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, + Model{typeof(demo_assume_matrix_observe_matrix_index)}, } const UnivariateAssumeDemoModels = Union{ @@ -637,7 +596,7 @@ function 
rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoMod end const MultivariateAssumeDemoModels = Union{ - Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_dot_assume_observe)}, Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, @@ -645,8 +604,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, - Model{typeof(demo_dot_assume_dot_observe_matrix)}, - Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_observe_matrix_index)}, } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. @@ -699,7 +657,7 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoM end const MatrixvariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_matrix_dot_observe_matrix)} + Model{typeof(demo_assume_matrix_observe_matrix_index)} } function posterior_mean(model::MatrixvariateAssumeDemoModels) # Get some containers to fill. 
@@ -786,7 +744,7 @@ And for the multivariate one (the latter one): """ const DEMO_MODELS = ( - demo_dot_assume_dot_observe(), + demo_dot_assume_observe(), demo_assume_index_observe(), demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), @@ -797,7 +755,6 @@ const DEMO_MODELS = ( demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), - demo_dot_assume_dot_observe_matrix(), - demo_dot_assume_matrix_dot_observe_matrix(), - demo_assume_matrix_dot_observe_matrix(), + demo_dot_assume_observe_matrix_index(), + demo_assume_matrix_observe_matrix_index(), ) diff --git a/src/transforming.jl b/src/transforming.jl index 1a26d212f..0239725ae 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -30,67 +30,6 @@ function tilde_assume( return r, lp, setindex!!(vi, r_transformed, vn) end -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::Distribution, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi, -) where {isinverse} - r = getindex.((vi,), vns, (dist,)) - b = link_transform(dist) - - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? 
r : b.(r) - lp = sum(Bijectors.logpdf_with_trans.((dist,), r, (!isinverse,))) - return r, lp, setindex!!(vi, r_transformed, vns) -end - -function dot_tilde_assume( - ::DynamicTransformationContext{isinverse}, - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::AbstractVarInfo, -) where {isinverse} - @assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))" - r = vi[vns, dist] - - # Compute `logpdf` with logabsdet-jacobian correction. - lp = sum(zip(vns, eachcol(r))) do (vn, ri) - return Bijectors.logpdf_with_trans(dist, ri, !isinverse) - end - - # Transform _all_ values. - is_trans_uniques = unique(istrans.((vi,), vns)) - @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" - is_trans = first(is_trans_uniques) - if is_trans - @assert isinverse "Trying to link already transformed variables" - else - @assert !isinverse "Trying to invlink non-transformed variables" - end - - b = link_transform(dist) - for (vn, ri) in zip(vns, eachcol(r)) - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - vi = setindex!!(vi, isinverse ? ri : b(ri), vn) - end - - return r, lp, vi -end - function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4cef5fa4e..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -90,29 +90,6 @@ function tilde_assume( return value, logp, vi end -# `dot_tilde_assume` -function dot_tilde_assume(context::ValuesAsInModelContext, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi) - - # Save the value. 
- _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, left, vn, vi -) - value, logp, vi = dot_tilde_assume( - rng, childcontext(context), sampler, right, left, vn, vi - ) - # Save the value. - _right, _left, _vns = unwrap_right_left_vns(right, left, vn) - broadcast_push!(context, _vns, value) - - return value, logp, vi -end - """ values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index f76ce6f6e..e6b23f379 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -26,32 +26,4 @@ test_model_ad(wishart_ad(), logp_wishart_ad) end - - # https://github.com/TuringLang/Turing.jl/issues/1595 - @testset "dot_observe" begin - function f_dot_observe(x) - logp, _ = DynamicPPL.dot_observe( - SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() - ) - return logp - end - function f_dot_observe_manual(x) - return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) - end - - # Manual computation of the gradient. - x = randn(2) - val = f_dot_observe_manual(x) - grad = ForwardDiff.gradient(f_dot_observe_manual, x) - - @test ForwardDiff.gradient(f_dot_observe, x) ≈ grad - - y, back = Tracker.forward(f_dot_observe, x) - @test Tracker.data(y) ≈ val - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(f_dot_observe, x) - @test y ≈ val - @test back(1)[1] ≈ grad - end end diff --git a/test/compiler.jl b/test/compiler.jl index 051eba618..8d81c530a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -288,6 +288,33 @@ module Issue537 end x = vdemo()() @test all((isassigned(x, i) for i in eachindex(x))) end + + # A couple of uses of .~ that are no longer valid as of v0.35. 
+ @testset "old .~ syntax" begin + @model function multivariate_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ MvNormal(zeros(2), I) + return x + end + expected_error = ArgumentError( + "the right-hand side of a `.~` must be a `UnivariateDistribution`" + ) + @test_throws expected_error (multivariate_dot_tilde()(); true) + + @model function vector_dot_tilde() + x = Vector{Float64}(undef, 2) + x .~ [Normal(), Normal()] + return x + end + expected_error = ArgumentError(""" + As of v0.35, DynamicPPL does not allow arrays of distributions in `.~`. \ + Please use `product_distribution` instead, or write a loop if necessary. \ + See https://github.com/TuringLang/DynamicPPL.jl/releases/tag/v0.35.0 for more \ + details.\ + """) + @test_throws expected_error (vector_dot_tilde()(); true) + end + @testset "nested model" begin function makemodel(p) @model function testmodel(x) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 8a795320d..0ec88c07c 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -4,7 +4,7 @@ @model function test(x) μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) - z .~ Categorical.(fill([0.5, 0.5], length(x))) + z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) for i in 1:length(x) x[i] ~ Normal(μ[z[i]], 0.1) end @@ -13,59 +13,36 @@ test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) end - # https://github.com/TuringLang/DynamicPPL.jl/issues/28#issuecomment-829223577 - @testset "dot tilde: arrays of distributions" begin + @testset "dot tilde with varying sizes" begin @testset "assume" begin @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) 
- y .~ Normal.(x) + y .~ Normal(x) return y, getlogp(__varinfo__) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - model = test(x, ysize) - y, lp = model() - @test lp ≈ sum(logpdf.(Normal.(x), y)) + x = randn() + model = test(x, ysize) + y, lp = model() + @test lp ≈ sum(logpdf.(Normal.(x), y)) - ys = [first(model()) for _ in 1:10_000] - @test norm(mean(ys) .- x, Inf) < 0.1 - @test norm(std(ys) .- 1, Inf) < 0.1 - end + ys = [first(model()) for _ in 1:10_000] + @test norm(mean(ys) .- x, Inf) < 0.1 + @test norm(std(ys) .- 1, Inf) < 0.1 end end @testset "observe" begin @model function test(x, y) - return y .~ Normal.(x) + return y .~ Normal(x) end for ysize in ((2,), (2, 3), (2, 3, 4)) - for x in ( - # scalar - randn(), - # drop trailing dimensions - ntuple(i -> randn(ysize[1:i]), length(ysize))..., - # singleton dimensions - ntuple( - i -> randn(ysize[1:(i - 1)]..., 1, ysize[(i + 1):end]...), - length(ysize), - )..., - ) - y = randn(ysize) - z = logjoint(test(x, y), VarInfo()) - @test z ≈ sum(logpdf.(Normal.(x), y)) - end + x = randn() + y = randn(ysize) + z = logjoint(test(x, y), VarInfo()) + @test z ≈ sum(logpdf.(Normal.(x), y)) end end end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 5c0b2e090..61c842638 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -48,8 +48,8 @@ end @testset "pointwise_logdensities chain" begin # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just - # to ensure that we don't accidentally break the the version on `Chains`. 
- model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + # to ensure that we don't accidentally break the version on `Chains`. + model = DynamicPPL.TestUtils.demo_assume_index_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 137c791c2..e67b5656a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -139,8 +139,6 @@ @testset "SimpleVarInfo on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() - # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) @@ -155,9 +153,10 @@ svi_nt, svi_dict, svi_vnv, - DynamicPPL.settrans!!(deepcopy(svi_nt), true), - DynamicPPL.settrans!!(deepcopy(svi_dict), true), - DynamicPPL.settrans!!(deepcopy(svi_vnv), true), + # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. + # DynamicPPL.settrans!!(deepcopy(svi_nt), true), + # DynamicPPL.settrans!!(deepcopy(svi_dict), true), + # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) # RandOM seed is set in each `@testset`, so we need to sample # a new realization for `m` here.