diff --git a/.github/workflows/Enzyme.yml b/.github/workflows/Enzyme.yml
index 36f11b914..8074095f8 100644
--- a/.github/workflows/Enzyme.yml
+++ b/.github/workflows/Enzyme.yml
@@ -18,7 +18,7 @@ concurrency:
 jobs:
   enzyme:
-    runs-on: ubuntu-latest
+    runs-on: macos-latest
     steps:
       - uses: actions/checkout@v5
@@ -27,9 +27,19 @@ jobs:
           version: "1.11"
       - uses: julia-actions/cache@v2
+        id: julia-cache
       - name: Run AD with Enzyme on demo models
         working-directory: test/integration/enzyme
         run: |
           julia --project=. --color=yes -e 'using Pkg; Pkg.instantiate()'
           julia --project=. --color=yes main.jl
+
+      - name: Save Julia depot cache on cancel or failure
+        id: julia-cache-save
+        if: cancelled() || failure()
+        uses: actions/cache/save@v4
+        with:
+          path: |
+            ${{ steps.julia-cache.outputs.cache-paths }}
+          key: ${{ steps.julia-cache.outputs.cache-key }}
diff --git a/HISTORY.md b/HISTORY.md
index 0f0102ce4..777f3f32c 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -4,6 +4,20 @@
 
 ### Breaking changes
 
+#### Fast Log Density Functions
+
+This version reimplements `LogDensityFunction`, providing performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
+Exact speedups depend on the model size: larger models see smaller speedups because the bulk of the work is done in calls to `logpdf`.
+
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
+
+As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
+In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
+If you were previously relying on this behaviour, you will need to store a VarInfo separately.
+
+Along with this change, DynamicPPL now exposes the `fast_evaluate!!` method, which allows you to hook into this 'fast evaluation' pipeline directly.
+Please see the documentation for details.
+
 #### Parent and leaf contexts
 
 The `DynamicPPL.NodeTrait` function has been removed.
@@ -17,6 +31,17 @@ Leaf contexts require no changes, apart from a removal of the `NodeTrait` functi
 `ConditionContext` and `PrefixContext` are no longer exported.
 You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead.
 
+#### SimpleVarInfo
+
+`SimpleVarInfo` has been removed.
+Its main purpose was to allow models to be evaluated rapidly.
+However, `fast_evaluate!!` provides a cleaner way of doing this.
+In particular, if you want to evaluate a model at a given set of parameters, you can do:
+
+```julia
+retval, vi = DynamicPPL.fast_evaluate!!(rng, model, InitFromParams(params), accs)
+```
+
 #### Miscellaneous
 
 Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
@@ -24,18 +49,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
 
 The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
 This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
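+
+For example, under the new contract, a strategy that initialises every variable to its prior mean could be written as sketched below.
+This is illustrative only: `InitFromMean` is not part of DynamicPPL, and the `(rng, vn, dist, strategy)` argument order is assumed to be unchanged from previous releases.
+
+```julia
+using DynamicPPL, Distributions
+
+struct InitFromMean <: DynamicPPL.AbstractInitStrategy end
+
+function DynamicPPL.init(rng, vn, dist, ::InitFromMean)
+    # The value is generated directly in unlinked (model) space, so the
+    # transform that maps it back to unlinked space is just `identity`.
+    return mean(dist), identity
+end
+```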
-### Other changes - -#### FastLDF - -Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation. -Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`. - -Please note that `FastLDF` is currently considered internal and its API may change without warning. -We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it. - -For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments. - ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e8ffa7e0b..ba3439986 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -68,9 +68,7 @@ function run(; to_json=false) false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 0dc7ece6e..00d2e071b 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, VarName using DynamicPPL: DynamicPPL using DynamicPPL.TestUtils.AD: run_ad, NoTest using ADTypes: ADTypes @@ -60,8 +60,6 @@ and AD backend. Available varinfo choices: • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). @@ -76,12 +74,6 @@ function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::B DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed DynamicPPL.typed_varinfo(rng, model) - elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) - elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) elseif varinfo_choice == :typed_vector DynamicPPL.typed_vector_varinfo(rng, model) elseif varinfo_choice == :untyped_vector diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl index 2c881aa95..76d4b2e93 100644 --- a/benchmarks/src/Models.jl +++ b/benchmarks/src/Models.jl @@ -2,7 +2,7 @@ Models for benchmarking Turing.jl. Each model returns a NamedTuple of all the random variables in the model that are not -observed (this is used for constructing SimpleVarInfos). +observed. 
""" module Models diff --git a/docs/src/api.md b/docs/src/api.md index e81f18dc7..a3b9e2fdc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte LogDensityFunction ``` +Internally, this is accomplished using: + +```@docs +OnlyAccsVarInfo +fast_evaluate!! +``` + ## Condition and decondition A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). @@ -352,12 +359,6 @@ set_transformed!! Base.empty! ``` -#### `SimpleVarInfo` - -```@docs -SimpleVarInfo -``` - ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 2155fa161..8b3040757 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type # below. -struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} +struct LogDensityFunctionWrapper{ + L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo +} logdensity::L + # This field is used only to reconstruct the VarInfo later on; it's not needed for the + # actual log-density evaluation. + varinfo::V end function (lw::LogDensityFunctionWrapper)(x, _) return LogDensityProblems.logdensity(lw.logdensity, x) @@ -101,7 +106,7 @@ function DynamicPPL.marginalize( # Construct the marginal log-density model. f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... ) return mld end @@ -190,7 +195,7 @@ function DynamicPPL.VarInfo( unmarginalized_params::Union{AbstractVector,Nothing}=nothing, ) # Extract the original VarInfo. Its contents will in general be junk. - original_vi = mld.logdensity.logdensity.varinfo + original_vi = mld.logdensity.varinfo # Extract the stored parameters, which includes the modes for any marginalized # parameters full_params = MarginalLogDensities.cached_params(mld) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e9b902363..f985fda73 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,7 +46,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - SimpleVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -92,8 +91,10 @@ export AbstractVarInfo, getargnames, extract_priors, values_as_in_model, - # LogDensityFunction + # LogDensityFunction and fasteval LogDensityFunction, + fast_evaluate!!, + OnlyAccsVarInfo, # Leaf contexts AbstractContext, contextualize, @@ -172,7 +173,7 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). +See also: [`VarInfo`](@ref). 
""" abstract type AbstractVarInfo <: AbstractModelTrace end @@ -194,11 +195,10 @@ include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") -include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("logdensityfunction.jl") +include("fasteval.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..14528522b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -502,52 +502,6 @@ If no `Type` is provided, return values as stored in `varinfo`. # Examples -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - `VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest @@ -828,8 +782,8 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation + # Note that VarInfo has a dedicated implementation so this is only a generic + # fallback (previously used for SimpleVarInfo) model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, t) @@ -890,8 +844,8 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation + # Note that VarInfo has a dedicated implementation so this is only a generic + # fallback (previously used for SimpleVarInfo) model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, NoTransformation()) @@ -946,47 +900,6 @@ This will be called prior to `model` evaluation, allowing one to perform a singl basis as is done with [`DynamicTransformation`](@ref). See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref). - -# Examples -```julia-repl -julia> using DynamicPPL, Distributions, Bijectors - -julia> @model demo() = x ~ Normal() -demo (generic function with 2 methods) - -julia> # By subtyping `Transform`, we inherit the `(inv)link!!`. 
-       struct MyBijector <: Bijectors.Transform end
-
-julia> # Define some dummy `inverse` which will be used in the `link!!` call.
-       Bijectors.inverse(f::MyBijector) = identity
-
-julia> # We need to define `with_logabsdet_jacobian` for `MyBijector`
-       # (`identity` already has `with_logabsdet_jacobian` defined)
-       function Bijectors.with_logabsdet_jacobian(::MyBijector, x)
-           # Just using a large number of the logabsdet-jacobian term
-           # for demonstration purposes.
-           return (x, 1000)
-       end
-
-julia> # Change the `default_transformation` for our model to be a
-       # `StaticTransformation` using `MyBijector`.
-       function DynamicPPL.default_transformation(::Model{typeof(demo)})
-           return DynamicPPL.StaticTransformation(MyBijector())
-       end
-
-julia> model = demo();
-
-julia> vi = SimpleVarInfo(x=1.0)
-SimpleVarInfo((x = 1.0,), 0.0)
-
-julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity`
-       vi_linked = link!!(vi, model)
-Transformed SimpleVarInfo((x = 1.0,), 0.0)
-
-julia> # Now performs a single `invlink!!` before model evaluation.
-       logjoint(model, vi_linked)
--1001.4189385332047
-```
 """
 function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
     return maybe_invlink_before_eval!!(transformation(vi), vi, model)
diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl
index c2eee2863..0914d7a79 100644
--- a/src/contexts/transformation.jl
+++ b/src/contexts/transformation.jl
@@ -7,7 +7,7 @@ constrained space if `isinverse` or unconstrained if `!isinverse`.
 
 Note that some `AbstractVarInfo` types, most notably `VarInfo`, override the
 `DynamicTransformationContext` methods with more efficient implementations.
 `DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
-how to do the transformation, used by e.g. `SimpleVarInfo`.
+how to do the transformation.
 """
 struct DynamicTransformationContext{isinverse} <: AbstractContext end
diff --git a/src/experimental.jl b/src/experimental.jl
index c644c09b2..8c82dca68 100644
--- a/src/experimental.jl
+++ b/src/experimental.jl
@@ -2,8 +2,6 @@ module Experimental
 
 using DynamicPPL: DynamicPPL
 
-include("fasteval.jl")
-
 # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
 """
     is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
diff --git a/src/fasteval.jl b/src/fasteval.jl
index 4f402f4a8..639e5b6c1 100644
--- a/src/fasteval.jl
+++ b/src/fasteval.jl
@@ -29,7 +29,61 @@ import DifferentiationInterface as DI
 using Random: Random
 
 """
-    FastLDF(
+    DynamicPPL.fast_evaluate!!(
+        [rng::Random.AbstractRNG,]
+        model::Model,
+        strategy::AbstractInitStrategy,
+        accs::AccumulatorTuple,
+    )
+
+Evaluate a model using parameters obtained via `strategy`, computing only the results in
+the provided accumulators.
+
+It is assumed that the accumulators passed in have been initialised to appropriate values,
+as this function will not reset them. The default constructors for each accumulator will do
+this for you correctly.
+
+Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
+argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
+in the function name.
+"""
+@inline function fast_evaluate!!(
+    # Note that this `@inline` is mandatory for performance.
If it's not inlined, it leads + # to extra allocations (even for trivial models) and much slower runtime. + rng::Random.AbstractRNG, + model::Model, + strategy::AbstractInitStrategy, + accs::AccumulatorTuple, +) + ctx = InitContext(rng, strategy) + model = DynamicPPL.setleafcontext(model, ctx) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what + # it _should_ do, but this is wrong regardless. + # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 + vi = if Threads.nthreads() > 1 + param_eltype = DynamicPPL.get_param_eltype(strategy) + accs = map(accs) do acc + DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + end + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + else + OnlyAccsVarInfo(accs) + end + return DynamicPPL._evaluate!!(model, vi) +end +@inline function fast_evaluate!!( + model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple +) + # This `@inline` is also mandatory for performance + return fast_evaluate!!(Random.default_rng(), model, strategy, accs) +end + +""" + DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=VarInfo(model); @@ -60,10 +114,10 @@ There are several options for `getlogdensity` that are 'supported' out of the bo since transforms are only applied to random variables) !!! note - By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of - `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created - with a linked or unlinked VarInfo. This is done primarily to ease interoperability with - MCMC samplers. + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction` + was created with a linked or unlinked VarInfo. This is done primarily to ease + interoperability with MCMC samplers. If you provide one of these functions, a `VarInfo` will be automatically created for you. If you provide a different function, you have to manually create a VarInfo and pass it as the @@ -71,15 +125,16 @@ third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the gradient of the log density. -Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend -itself to have been loaded (e.g. with `import Backend`). +Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD +backend itself to have been loaded (e.g. with `import Backend`). ## Fields -Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: +Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart +from: -- `fastldf.model`: The original model from which this `FastLDF` was constructed. -- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD +- `ldf.model`: The original model from which this `LogDensityFunction` was constructed. +- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD type was provided. 
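+
+# Examples
+
+A minimal usage sketch, adapted from the docstring of the previous `LogDensityFunction`
+implementation:
+
+```jldoctest
+julia> using DynamicPPL, Distributions, LogDensityProblems
+
+julia> @model function demo(x)
+           m ~ Normal()
+           x ~ Normal(m, 1)
+       end
+demo (generic function with 2 methods)
+
+julia> f = LogDensityFunction(demo(1.0));
+
+julia> LogDensityProblems.logdensity(f, [0.0])
+-2.3378770664093453
+
+julia> LogDensityProblems.dimension(f)
+1
+```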
# Extended help @@ -117,8 +172,9 @@ Traditionally, this problem has been solved by `unflatten`, because that functio place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. However, we want to avoid using a metadata. Thus, here, we _extract this information from -the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we -store a mapping from VarNames to ranges in that vector, along with link status. +the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the +LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with +link status. For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all other VarNames, this is stored in a Dict. The internal data structure used to represent this @@ -130,13 +186,13 @@ ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quick parameter values from the vector. Note that this assumes that the ranges and link status are static throughout the lifetime of -the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable -numbers of parameters, or models which may visit random variables in different orders depending -on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a -general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` -approach also fails with such models. +the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle +models which have variable numbers of parameters, or models which may visit random variables +in different orders depending on stochastic control flow. **Indeed, silent errors may occur +with such models.** This is a general limitation of vectorised parameters: the original +`unflatten` + `evaluate!!` approach also fails with such models. 
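+
+As a conceptual sketch (illustrative only: the actual internal types, `RangeAndLinked` and
+`VectorWithRanges`, may differ in detail), the stored mapping behaves roughly like this:
+
+```julia
+# For a model with `m ~ Normal()` and `x[1] ~ Normal()`, evaluated against a
+# flat parameter vector of length 2:
+params = [0.5, -1.2]
+iden_ranges = (m = 1:1,)                    # identity-optic VarNames → NamedTuple
+other_ranges = Dict(@varname(x[1]) => 2:2)  # all other VarNames → Dict
+params[iden_ranges.m]                       # recovers the value(s) for `m`
+```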
""" -struct FastLDF{ +struct LogDensityFunction{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -151,7 +207,7 @@ struct FastLDF{ _adprep::ADP _dim::Int - function FastLDF( + function LogDensityFunction( model::Model, getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=VarInfo(model); @@ -169,7 +225,7 @@ struct FastLDF{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) @@ -206,85 +262,77 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} end -function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = InitContext( - Random.default_rng(), - InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing - ), +function (f::LogDensityAt)(params::AbstractVector{<:Real}) + strategy = InitFromParams( + VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - model = DynamicPPL.setleafcontext(f.model, ctx) accs = fast_ldf_accs(f.getlogdensity) - # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, - # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` - # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic - # here. - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. 
- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - accs = map( - acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc), - accs, - ) - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) - else - OnlyAccsVarInfo(accs) - end - _, vi = DynamicPPL._evaluate!!(model, vi) + _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs) return f.getlogdensity(vi) end -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges +function LogDensityProblems.logdensity( + ldf::LogDensityFunction, params::AbstractVector{<:Real} +) + return LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params ) end function LogDensityProblems.logdensity_and_gradient( - fldf::FastLDF, params::AbstractVector{<:Real} + ldf::LogDensityFunction, params::AbstractVector{<:Real} ) return DI.value_and_gradient( - FastLogDensityAt( - fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + LogDensityAt( + ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), - fldf._adprep, - fldf.adtype, + ldf._adprep, + ldf.adtype, params, ) end -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}} -) where {M} +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}} + ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} ) where {M} return LogDensityProblems.LogDensityOrder{1}() end -function LogDensityProblems.dimension(fldf::FastLDF) - return fldf._dim +function LogDensityProblems.dimension(ldf::LogDensityFunction) + return ldf._dim end +""" + tweak_adtype( + adtype::ADTypes.AbstractADType, + model::Model, + varinfo::AbstractVarInfo, + ) + +Return an 'optimised' form of the adtype. This is useful for doing +backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating +the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). +The model is passed as a parameter in case the optimisation depends on the +model. + +By default, this just returns the input unchanged. +""" +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype + ###################################################### # Helper functions to extract ranges and link status # ###################################################### -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. """ get_ranges_and_linked(varinfo::VarInfo) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl deleted file mode 100644 index 7c7438c9f..000000000 --- a/src/logdensityfunction.jl +++ /dev/null @@ -1,377 +0,0 @@ -using AbstractMCMC: AbstractModel -import DifferentiationInterface as DI - -""" - is_supported(adtype::AbstractADType) - -Check if the given AD type is formally supported by DynamicPPL. 
- -AD backends that are not formally supported can still be used for gradient -calculation; it is just that the DynamicPPL developers do not commit to -maintaining compatibility with them. -""" -is_supported(::ADTypes.AbstractADType) = false -is_supported(::ADTypes.AutoEnzyme) = true -is_supported(::ADTypes.AutoForwardDiff) = true -is_supported(::ADTypes.AutoMooncake) = true -is_supported(::ADTypes.AutoReverseDiff) = true - -""" - LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing - ) - -A struct which contains a model, along with all the information necessary to: - - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. - -This information can be extracted using the LogDensityProblems.jl interface, -specifically, using `LogDensityProblems.logdensity` and -`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only -`logdensity` is implemented. If `adtype` is a concrete AD backend type, then -`logdensity_and_gradient` is also implemented. - -There are several options for `getlogdensity` that are 'supported' out of the -box: - -- [`getlogjoint_internal`](@ref): calculate the log joint, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogprior_internal`](@ref): calculate the log prior, including the - log-Jacobian term for any variables that have been linked in the provided - VarInfo. -- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring - any effects of linking -- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring - any effects of linking -- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected - by linking, since transforms are only applied to random variables) - -!!! note - By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the - result of `LogDensityProblems.logdensity(f, x)` will depend on whether the - `LogDensityFunction` was created with a linked or unlinked VarInfo. This - is done primarily to ease interoperability with MCMC samplers. - -If you provide one of these functions, a `VarInfo` will be automatically created -for you. If you provide a different function, you have to manually create a -VarInfo and pass it as the third argument. - -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Distributions - -julia> using DynamicPPL: LogDensityFunction, setaccs!! - -julia> @model function demo(x) - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 2 methods) - -julia> model = demo(1.0); - -julia> f = LogDensityFunction(model); - -julia> # It implements the interface of LogDensityProblems.jl. - using LogDensityProblems - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> LogDensityProblems.dimension(f) -1 - -julia> # By default it uses `VarInfo` under the hood, but this is not necessary. 
- f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); - -julia> LogDensityProblems.logdensity(f, [0.0]) --2.3378770664093453 - -julia> # One can also specify evaluating e.g. the log prior only: - f_prior = LogDensityFunction(model, getlogprior); - -julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) -true - -julia> # If we also need to calculate the gradient, we can specify an AD backend. - import ForwardDiff, ADTypes - -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); - -julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) -(-2.3378770664093453, [1.0]) -``` -""" -struct LogDensityFunction{ - M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} -} <: AbstractModel - "model used for evaluation" - model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." - getlogdensity::F - "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." - varinfo::V - "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" - adtype::AD - "(internal use only) gradient preparation object for the model" - prep::Union{Nothing,DI.GradientPrep} - - function LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - if adtype === nothing - prep = nothing - else - # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end - end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep - ) - end -end - -""" - LogDensityFunction( - ldf::LogDensityFunction, - adtype::Union{Nothing,ADTypes.AbstractADType} - ) - -Create a new LogDensityFunction using the model and varinfo from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, -pass `nothing` as the second argument. -""" -function LogDensityFunction( - f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} -) - return if adtype === f.adtype - f # Avoid recomputing prep if not needed - else - LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) - end -end - -""" - ldf_default_varinfo(model::Model, getlogdensity::Function) - -Create the default AbstractVarInfo that should be used for evaluating the log density. - -Only the accumulators necesessary for `getlogdensity` will be used. -""" -function ldf_default_varinfo(::Model, getlogdensity::Function) - msg = """ - LogDensityFunction does not know what sort of VarInfo should be used when \ - `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. 
- """ - return error(msg) -end - -ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) - -function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) -end - -function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) - return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) -end - -function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) - return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) -end - -""" - logdensity_at( - x::AbstractVector, - model::Model, - getlogdensity::Function, - varinfo::AbstractVarInfo, - ) - -Evaluate the log density of the given `model` at the given parameter values -`x`, using the given `varinfo`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` -are inserted into it, and its own parameters are discarded. `getlogdensity` is -the function that extracts the log density from the evaluated varinfo. -""" -function logdensity_at( - x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo -) - varinfo_new = unflatten(varinfo, x) - varinfo_eval = last(evaluate!!(model, varinfo_new)) - return getlogdensity(varinfo_eval) -end - -""" - LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( - model::M - getlogdensity::F, - varinfo::V - ) - -A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -getlogdensity, varinfo)`. -""" -struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} - model::M - getlogdensity::F - varinfo::V -end -function (ld::LogDensityAt)(x::AbstractVector) - return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) -end - -### LogDensityProblems interface - -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,Nothing}} -) where {M,F,V} - return LogDensityProblems.LogDensityOrder{0}() -end -function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,F,V,AD}} -) where {M,F,V,AD<:ADTypes.AbstractADType} - return LogDensityProblems.LogDensityOrder{1}() -end -function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) -end -function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,F,V,AD}, x::AbstractVector -) where {M,F,V,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") - x = [val for val in x] # Concretise type - # Make branching statically inferrable, i.e. type-stable (even if the two - # branches happen to return different types) - return if use_closure(f.adtype) - DI.value_and_gradient( - LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x - ) - else - DI.value_and_gradient( - logdensity_at, - f.prep, - f.adtype, - x, - DI.Constant(f.model), - DI.Constant(f.getlogdensity), - DI.Constant(f.varinfo), - ) - end -end - -# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? 
-LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -### Utils - -""" - tweak_adtype( - adtype::ADTypes.AbstractADType, - model::Model, - varinfo::AbstractVarInfo, - ) - -Return an 'optimised' form of the adtype. This is useful for doing -backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating -the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`). -The model is passed as a parameter in case the optimisation depends on the -model. - -By default, this just returns the input unchanged. -""" -tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype - -""" - use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of logdensity(f, x) -with respect to x, where f is the model (in our case LogDensityFunction) and is -a constant. However, DifferentiationInterface generally expects a -single-argument function g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, - as long as we also give it the 'inactive argument' (i.e. the model) wrapped - in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD -backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 - -This function is used to determine whether a given AD backend should use a -closure or a constant. If `use_closure(adtype)` returns `true`, then the -closure approach will be used. By default, this function returns `false`, i.e. -the constant approach will be used. -""" -use_closure(::ADTypes.AbstractADType) = true -use_closure(::ADTypes.AutoEnzyme) = false - -""" - getmodel(f) - -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::DynamicPPL.LogDensityFunction) = f.model - -""" - setmodel(f, model[, adtype]) - -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. -""" -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) -end - -""" - getparams(f::LogDensityFunction) - -Return the parameters of the wrapped varinfo as a vector. -""" -getparams(f::LogDensityFunction) = f.varinfo[:] diff --git a/src/model.jl b/src/model.jl index 2bcfe8f98..2ba0c6cd4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1062,8 +1062,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) - return values_as(x, T) + # TODO(penelopeysm): This can be done with an accumulator instead. For + # T = Dict, ValuesAsInModelAcc can already do it. For T = NamedTuple we + # would just need a similar accumulator that collects into a NamedTuple + # rather than a Dict. 
+    return values_as(VarInfo(rng, model), T)
 end
 
 # Default RNG and type
@@ -1155,12 +1158,115 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
 ```
 """
 function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
-    vi = DynamicPPL.setaccs!!(VarInfo(), ())
-    # Note: we can't use `fix(model, parameters)` because
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1097
-    # Use `nothing` as the fallback to ensure that any missing parameters cause an error
-    ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
-    new_model = setleafcontext(model, ctx)
-    # We can't use new_model() because that overwrites it with an InitContext of its own.
-    return first(evaluate!!(new_model, vi))
+    accs = AccumulatorTuple()
+    retval, _ = DynamicPPL.fast_evaluate!!(model, InitFromParams(parameters, nothing), accs)
+    return retval
+end
+
+"""
+    logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
+
+Return the log joint probability of variables `θ` for the probabilistic `model`.
+
+See [`logprior`](@ref) and [`loglikelihood`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logjoint(demo([1.0]), (m = 100.0, ))
+-9902.33787706641
+
+julia> # Using an `OrderedDict`.
+       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-9902.33787706641
+
+julia> # Truth.
+       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
+-9902.33787706641
+```
+"""
+function logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
+    accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
+    _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs)
+    return getlogjoint(vi)
+end
+
+"""
+    logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
+
+Return the log prior probability of variables `θ` for the probabilistic `model`.
+
+See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logprior(demo([1.0]), (m = 100.0, ))
+-5000.918938533205
+
+julia> # Using an `OrderedDict`.
+       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-5000.918938533205
+
+julia> # Truth.
+       logpdf(Normal(), 100.0)
+-5000.918938533205
+```
+"""
+function logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
+    accs = AccumulatorTuple((LogPriorAccumulator(),))
+    _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs)
+    return getlogprior(vi)
+end
+
+"""
+    loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict})
+
+Return the log likelihood of variables `θ` for the probabilistic `model`.
+
+See also [`logjoint`](@ref) and [`logprior`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       loglikelihood(demo([1.0]), (m = 100.0, ))
+-4901.418938533205
+
+julia> # Using an `OrderedDict`.
+       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-4901.418938533205
+
+julia> # Truth.
+ logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +function Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogLikelihoodAccumulator(),)) + _, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs) + return getloglikelihood(vi) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 434480be6..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,655 +0,0 @@ -""" - $(TYPEDEF) - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. - -# Fields -$(FIELDS) - -# Notes -The major differences between this and `NTVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either - a) no indexing is used in tilde-statements, or - b) the values have been specified with the correct shapes. - -# Examples -## General usage -```jldoctest simplevarinfo-general; setup=:(using Distributions) -julia> using StableRNGs - -julia> @model function demo() - m ~ Normal() - x = Vector{Float64}(undef, 2) - for i in eachindex(x) - x[i] ~ Normal() - end - return x - end -demo (generic function with 2 methods) - -julia> m = demo(); - -julia> rng = StableRNG(42); - -julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variables which are using "containers", e.g. `Array`. - # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); - -julia> # (✓) Vroom, vroom! FAST!!! - vi[@varname(x[1])] -0.4471218424633827 - -julia> # We can also access arbitrary varnames pointing to `x`, e.g. - vi[@varname(x)] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> vi[@varname(x[1:2])] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); -ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` -[...] - -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); - -julia> # (✓) Sort of fast, but only possible at runtime. - vi[@varname(x[1])] --1.019202452456547 - -julia> # In addtion, we can only access varnames as they appear in the model! - vi[@varname(x)] -ERROR: x was not found in the dictionary provided -[...] - -julia> vi[@varname(x[1:2])] -ERROR: x[1:2] was not found in the dictionary provided -[...] -``` - -_Technically_, it's possible to use any implementation of `AbstractDict` in place of -`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening -of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is -the preferred implementation of `AbstractDict` to use here. 
- -You can also sample in _transformed_ space: - -```jldoctest simplevarinfo-general -julia> @model demo_constrained() = x ~ Exponential() -demo_constrained (generic function with 2 methods) - -julia> m = demo_constrained(); - -julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); - -julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ -1.8632965762164932 - -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ --0.21080155351918753 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true - -julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.6225185067787314 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true -``` - -Evaluation in transformed space of course also works: - -```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) Positive probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --1.3678794411714423 - -julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) No probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --Inf -``` - -## Indexing -Using `NamedTuple` as underlying storage. - -```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); - -julia> svi_nt[@varname(m)] -(a = [1.0],) - -julia> svi_nt[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_nt[@varname(m.a[1])] -1.0 - -julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` -[...] -``` - -Using `OrderedDict` as underlying storage. -```jldoctest -julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); - -julia> svi_dict[@varname(m)] -(a = [1.0],) - -julia> svi_dict[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_dict[@varname(m.a[1])] -1.0 - -julia> svi_dict[@varname(m.a[2])] -ERROR: m.a[2] was not found in the dictionary provided -[...] - -julia> svi_dict[@varname(m.b)] -ERROR: m.b was not found in the dictionary provided -[...] 
-``` -""" -struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: - AbstractVarInfo - "underlying representation of the realization represented" - values::NT - "tuple of accumulators for things like log prior and log likelihood" - accs::Accs - "represents whether it assumes variables to be transformed" - transformation::C -end - -function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) - return vi1.values == vi2.values && - vi1.accs == vi2.accs && - vi1.transformation == vi2.transformation -end - -transformation(vi::SimpleVarInfo) = vi.transformation - -function SimpleVarInfo(values, accs) - return SimpleVarInfo(values, accs, NoTransformation()) -end -function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, default_accumulators(T)) -end -function SimpleVarInfo(values) - return SimpleVarInfo{LogProbType}(values) -end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) - return if isempty(values) - # Can't infer from values, so we just use default. - SimpleVarInfo{LogProbType}(values) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) - end -end - -# Using `kwargs` to specify the values. -function SimpleVarInfo{T}(; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs)) -end -function SimpleVarInfo(; kwargs...) - return SimpleVarInfo(NamedTuple(kwargs)) -end - -# Constructor from `Model`. -function SimpleVarInfo{T}( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) where {T<:Real} - return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) -end -function SimpleVarInfo{T}( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) -end -# Constructors without type param -function SimpleVarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return SimpleVarInfo{LogProbType}(rng, model, init_strategy) -end -function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) -end - -# Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} - values = values_as(vi, D) - return SimpleVarInfo(values, copy(getaccs(vi))) -end -function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} - values = values_as(vi, D) - accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) - return SimpleVarInfo(values, accs) -end - -function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(init!!(model, varinfo)) -end - -function typed_simple_varinfo(model::Model) - varinfo = SimpleVarInfo{Float64}() - return last(init!!(model, varinfo)) -end - -function unflatten(svi::SimpleVarInfo, x::AbstractVector) - vals = unflatten(svi.values, x) - # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is - # required but undesireable. - # The below line is finicky for type stability. For instance, assigning the eltype to - # convert to into an intermediate variable makes this unstable (constant propagation) - # fails. Take care when editing. 
- accs = map( - acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) - ) - return SimpleVarInfo(vals, accs, svi.transformation) -end - -function BangBang.empty!!(vi::SimpleVarInfo) - return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) -end -Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) - -getaccs(vi::SimpleVarInfo) = vi.accs -setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs - -""" - keys(vi::SimpleVarInfo) - -Return an iterator of keys present in `vi`. -""" -Base.keys(vi::SimpleVarInfo) = keys(vi.values) -Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) - -function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transformation isa NoTransformation) - print(io, "Transformed ") - end - - return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") -end - -function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn)) -end -function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) - -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(getindex, vi), vns) -end -# HACK: Needed to disambiguate. -Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) - -Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) - -getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `AbstractDict` -function getindex_internal( - vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName -) - return getvalue(vi.values, vn) -end - -Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) - -function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) - # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Accessors.@set vi.values = set!!(vi.values, vn, val) -end - -# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with -# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. -function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) - for (vn, val) in zip(vns, vals) - vi = BangBang.setindex!!(vi, val, vn) - end - return vi -end - -function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) - # For dictlike objects, we treat the entire `vn` as a _key_ to set. - dict = values_as(vi) - # Attempt to split into `parent` and `child` optic. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(dict, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - dict_new = if !issuccess - # Split doesn't exist ⟹ we're working with a new key. - BangBang.setindex!!(dict, val, vn) - else - # Split exists ⟹ trying to set an existing key. 
-        vn_key = VarName{getsym(vn)}(keyoptic)
-        BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
-    end
-    return Accessors.@set vi.values = dict_new
-end
-
-# `NamedTuple`
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution
-) where {sym}
-    return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,)))
-end
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution
-) where {sym}
-    return Accessors.@set vi.values = set!!(vi.values, vn, value)
-end
-
-# `AbstractDict`
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution
-)
-    vi.values[vn] = value
-    return vi
-end
-
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution
-)
-    # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For
-    # SimpleVarInfo, push!! allows the key to exist already; for VarNamedVector it does not.
-    # Hence we need to call setindex!! here, which has the same semantics as push!! does
-    # for SimpleVarInfo.
-    return Accessors.@set vi.values = setindex!!(vi.values, value, vn)
-end
-
-const SimpleOrThreadSafeSimple{T,V,C} = Union{
-    SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}}
-}
-
-# Necessary for `matchingvalue` to work properly.
-Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
-
-# `subset`
-function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
-    return SimpleVarInfo(
-        _subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation
-    )
-end
-
-function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
-    vns_present = collect(keys(x))
-    vns_found = filter(
-        vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present
-    )
-    C = ConstructionBase.constructorof(typeof(x))
-    if isempty(vns_found)
-        return C()
-    else
-        return C(vn => x[vn] for vn in vns_found)
-    end
-end
-
-function _subset(x::NamedTuple, vns)
-    # NOTE: Here we can only handle `vns` that contain `identity` as optic.
-    if any(Base.Fix1(!==, identity) ∘ getoptic, vns)
-        throw(
-            ArgumentError(
-                "Cannot subset `NamedTuple` with non-`identity` `VarName`. " *
-                "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
-            ),
-        )
-    end
-
-    syms = map(getsym, vns)
-    x_syms = filter(Base.Fix2(in, syms), keys(x))
-    return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms)))
-end
-
-_subset(x::VarNamedVector, vns) = subset(x, vns)
-
-# `merge`
-function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
-    values = merge(varinfo_left.values, varinfo_right.values)
-    accs = map(copy, getaccs(varinfo_right))
-    transformation = merge_transformations(
-        varinfo_left.transformation, varinfo_right.transformation
-    )
-    return SimpleVarInfo(values, accs, transformation)
-end
-
-function set_transformed!!(vi::SimpleVarInfo, trans)
-    return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation())
-end
-function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
-    return Accessors.@set vi.transformation = transformation
-end
-function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
-    return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans)
-end
-function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
-    # We keep this method around just to obey the AbstractVarInfo interface.
-    # However, note that this is only a valid operation if it is a no-op,
-    # which we check here.
-    if trans != is_transformed(vi)
-        error(
-            "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
-        )
-    end
-    return vi
-end
-
-is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
-is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi)
-function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName)
-    return is_transformed(vi.varinfo, vn)
-end
-is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo)
-
-values_as(vi::SimpleVarInfo) = vi.values
-values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
-function values_as(vi::SimpleVarInfo, ::Type{Vector})
-    isempty(vi) && return Any[]
-    return mapreduce(tovec, vcat, values(vi.values))
-end
-function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
-    return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
-end
-function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple})
-    return NamedTuple((Symbol(k), v) for (k, v) in vi.values)
-end
-function values_as(vi::SimpleVarInfo, ::Type{T}) where {T}
-    return values_as(vi.values, T)
-end
-
-"""
-    logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log joint probability of variables `θ` for the probabilistic `model`.
-
-See also [`logprior`](@ref) and [`loglikelihood`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       logjoint(demo([1.0]), (m = 100.0, ))
--9902.33787706641
-
-julia> # Using an `OrderedDict`.
-       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--9902.33787706641
-
-julia> # Truth.
-       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
--9902.33787706641
-```
-"""
-logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    logjoint(model, SimpleVarInfo(θ))
-
-"""
-    logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log prior probability of variables `θ` for the probabilistic `model`.
-
-See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       logprior(demo([1.0]), (m = 100.0, ))
--5000.918938533205
-
-julia> # Using an `OrderedDict`.
-       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--5000.918938533205
-
-julia> # Truth.
-       logpdf(Normal(), 100.0)
--5000.918938533205
-```
-"""
-logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    logprior(model, SimpleVarInfo(θ))
-
-"""
-    loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log likelihood of variables `θ` for the probabilistic `model`.
-
-See also [`logjoint`](@ref) and [`logprior`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       loglikelihood(demo([1.0]), (m = 100.0, ))
--4901.418938533205
-
-julia> # Using an `OrderedDict`.
-       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--4901.418938533205
-
-julia> # Truth.
-       logpdf(Normal(100.0, 1.0), 1.0)
--4901.418938533205
-```
-"""
-Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    loglikelihood(model, SimpleVarInfo(θ))
-
-# Allow usage of `Bijectors.NamedTransform` too.
-function link!!(
-    t::StaticTransformation{<:Bijectors.NamedTransform},
-    vi::SimpleVarInfo{<:NamedTuple},
-    ::Model,
-)
-    b = inverse(t.bijector)
-    x = vi.values
-    y, logjac = with_logabsdet_jacobian(b, x)
-    vi_new = Accessors.@set(vi.values = y)
-    if hasacc(vi_new, Val(:LogJacobian))
-        vi_new = acclogjac!!(vi_new, logjac)
-    end
-    return set_transformed!!(vi_new, t)
-end
-
-function invlink!!(
-    t::StaticTransformation{<:Bijectors.NamedTransform},
-    vi::SimpleVarInfo{<:NamedTuple},
-    ::Model,
-)
-    b = t.bijector
-    y = vi.values
-    x, inv_logjac = with_logabsdet_jacobian(b, y)
-    vi_new = Accessors.@set(vi.values = x)
-    # Mildly confusing: we need to _add_ the logjac of the inverse transform,
-    # because we are trying to remove the logjac of the forward transform
-    # that was previously accumulated when linking.
-    if hasacc(vi_new, Val(:LogJacobian))
-        vi_new = acclogjac!!(vi_new, inv_logjac)
-    end
-    return set_transformed!!(vi_new, NoTransformation())
-end
-
-# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.
-from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
-from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity
-# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`?
-from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
-function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
-    return invlink_transform(dist)
-end
-
-has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector
diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl
index cb949464e..bd6caa93b 100644
--- a/src/test_utils/model_interface.jl
+++ b/src/test_utils/model_interface.jl
@@ -89,10 +89,10 @@ function logprior_true_with_logabsdet_jacobian end
 Return a collection of `VarName` as they are expected to appear in the model.
 
 Even though it is recommended to implement this by hand for a particular `Model`,
-a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
+a default implementation using [`VarInfo`](@ref) is provided.
 """
 function varnames(model::Model)
-    return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict())))))
+    return collect(keys(VarInfo(model)))
 end
 
 """
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 26e2aa7ca..0f74da3ae 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -32,24 +32,12 @@ function setup_varinfos(
     vi_typed_metadata = DynamicPPL.typed_varinfo(model)
     vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model)
 
-    # SimpleVarInfo
-    svi_typed = SimpleVarInfo(example_values)
-    svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
-    svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
-
-    varinfos = map((
-        vi_untyped_metadata,
-        vi_untyped_vnv,
-        vi_typed_metadata,
-        vi_typed_vnv,
-        svi_typed,
-        svi_untyped,
-        svi_vnv,
-    )) do vi
-        # Set them all to the same values and evaluate logp.
-        vi = update_values!!(vi, example_values, varnames)
-        last(DynamicPPL.evaluate!!(model, vi))
-    end
+    varinfos =
+        map((vi_untyped_metadata, vi_untyped_vnv, vi_typed_metadata, vi_typed_vnv)) do vi
+            # Set them all to the same values and evaluate logp.
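+            # (`update_values!!` overwrites each varinfo's entries for `varnames` with
+            # the corresponding `example_values`; `evaluate!!` then re-runs the model so
+            # that the accumulated log densities match those values.)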
+ vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi)) + end if include_threadsafe varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) diff --git a/src/utils.jl b/src/utils.jl index 75fb805dc..2d7b0404f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,9 +5,6 @@ const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. VarNameTuple = NTuple{N,VarName} where {N} -# TODO(mhauru) This is currently used in the transformation functions of NoDist, -# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in -# SimpleVarInfo and maybe other places. """ The type for all log probability variables. diff --git a/test/ad.jl b/test/ad.jl deleted file mode 100644 index 0236c232f..000000000 --- a/test/ad.jl +++ /dev/null @@ -1,137 +0,0 @@ -using DynamicPPL: LogDensityFunction -using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest - -@testset "Automatic differentiation" begin - # Used as the ground truth that others are compared against. - ref_adtype = AutoForwardDiff() - - test_adtypes = [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - - @testset "Correctness" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11_or_1_12 = v"1.11" <= VERSION < v"1.13" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. 
- if is_mooncake && is_1_11_or_1_12 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end - end - end - end - - # Test that various different ways of specifying array types as arguments work with all - # ADTypes. - @testset "Array argument types" begin - test_m = randn(2, 3) - - function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) - return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) - end - - @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} - m = Matrix{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_matrix_model_reference = eval_logp_and_grad( - scalar_matrix_model, test_m, ref_adtype - ) - - @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) - - @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} - m = Array{T}(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - scalar_array_model_reference = eval_logp_and_grad( - scalar_array_model, test_m, ref_adtype - ) - - @model function array_model(::Type{T}=Array{Float64}) where {T} - m = T(undef, 2, 3) - return m ~ filldist(MvNormal(zeros(2), I), 3) - end - - array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) - - @testset "$adtype" for adtype in test_adtypes - scalar_matrix_model_logp_and_grad = eval_logp_and_grad( - scalar_matrix_model, test_m, adtype - ) - @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] - @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] - matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) - @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] - @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] - scalar_array_model_logp_and_grad = eval_logp_and_grad( - scalar_array_model, test_m, adtype - ) - @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] - @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] - array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) - @test array_model_logp_and_grad[1] ≈ array_model_reference[1] - @test array_model_logp_and_grad[2] ≈ array_model_reference[2] - end - end -end diff --git a/test/compiler.jl b/test/compiler.jl index b1309254e..0da1f13fb 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -604,13 +604,13 @@ module Issue537 end # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. 
@model demo() = return __varinfo__ - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) - @test svi == SimpleVarInfo() + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) + @test vi == VarInfo() if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi + @test retval isa DynamicPPL.ThreadSafeVarInfo{<:VarInfo} + @test retval.varinfo == vi else - @test retval == svi + @test retval == vi end # We should not be altering return-values other than at top-level. @@ -620,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index ae7332a43..71f2f13b6 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -422,8 +422,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() "typed+VNV", DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), ), - ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), ] @model function test_init_model() diff --git a/test/fasteval.jl b/test/fasteval.jl index db2333711..a75441c93 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -1,4 +1,4 @@ -module DynamicPPLFastLDFTests +module DynamicPPLFastEvalTests using AbstractPPL: AbstractPPL using Chairmarks @@ -6,7 +6,6 @@ using DynamicPPL using Distributions using DistributionsAD: filldist using ADTypes -using DynamicPPL.Experimental: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest using LinearAlgebra: I using Test @@ -14,14 +13,9 @@ using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff -# Need to include this block here in case we run this test file standalone -@static if VERSION < v"1.12" - using Pkg - Pkg.add("Mooncake") - using Mooncake: Mooncake -end +using Mooncake: Mooncake -@testset "FastLDF: Correctness" begin +@testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS @testset "$varinfo_func" for varinfo_func in [ DynamicPPL.untyped_varinfo, @@ -36,7 +30,7 @@ end else unlinked_vi end - nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi) params = [x for x in vi[:]] # Iterate over all variables for vn in keys(vi) @@ -52,26 +46,6 @@ end # Check that the link status is correct @test range_with_linked.is_linked == islinked end - - # Compare results of FastLDF vs ordinary LogDensityFunction. These tests - # can eventually go once we replace LogDensityFunction with FastLDF, but - # for now it helps to have this check! (Eventually we should just check - # against manually computed log-densities). - # - # TODO(penelopeysm): I think we need to add tests for some really - # pathological models here. 
- @testset "$getlogdensity" for getlogdensity in ( - DynamicPPL.getlogjoint_internal, - DynamicPPL.getlogjoint, - DynamicPPL.getloglikelihood, - DynamicPPL.getlogprior_internal, - DynamicPPL.getlogprior, - ) - ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi) - fldf = FastLDF(m, getlogdensity, vi) - @test LogDensityProblems.logdensity(ldf, params) ≈ - LogDensityProblems.logdensity(fldf, params) - end end end end @@ -86,7 +60,7 @@ end end N = 100 model = threaded(zeros(N)) - ldf = DynamicPPL.Experimental.FastLDF(model) + ldf = DynamicPPL.LogDensityFunction(model) xs = [1.0] @test LogDensityProblems.logdensity(ldf, xs) ≈ @@ -95,10 +69,10 @@ end end end -@testset "FastLDF: performance" begin +@testset "Fast evaluation: performance" begin if Threads.nthreads() == 1 - # Evaluating these three models should not lead to any allocations (but only when - # not using TSVI). + # Evaluating these three models with OnlyAccsVarInfo should not lead to any + # allocations (but only when not using TSVI). @model function f() x ~ Normal() return 1.0 ~ Normal(x) @@ -116,38 +90,56 @@ end y ~ Normal(params.m, params.s) return 1.0 ~ Normal(y) end + @testset for model in (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner())) - vi = VarInfo(model) - fldf = DynamicPPL.Experimental.FastLDF( - model, DynamicPPL.getlogjoint_internal, vi - ) - x = vi[:] - bench = median(@be LogDensityProblems.logdensity(fldf, x)) - @test iszero(bench.allocs) + @testset "LogDensityFunction" begin + # Performance tests on LogDensityFunction. + vi = VarInfo(model) + fldf = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi + ) + x = vi[:] + bench = median(@be LogDensityProblems.logdensity(fldf, x)) + @test iszero(bench.allocs) + end + + # And for returned/logp evaluation functions. + @testset "$func" for func in (returned, logprior, loglikelihood, logjoint) + if model.f !== submodel_outer + # submodel_outer contains nested parameters, so the NamedTuple + # representation doesn't work. One day, we'll fix rand(NamedTuple, + # model) to 'work' with nested parameters. But this will require us to + # figure out submodels properly... + params_nt = rand(NamedTuple, model) + bench = median(@be func(model, params_nt)) + @test iszero(bench.allocs) + end + + # Thank goodness Dicts work... + params_dict = rand(Dict, model) + bench = median(@be func(model, params_dict)) + @test iszero(bench.allocs) + end end end end -@testset "AD with FastLDF" begin +@testset "AD with LogDensityFunction" begin # Used as the ground truth that others are compared against. 
ref_adtype = AutoForwardDiff() - test_adtypes = @static if VERSION < v"1.12" - [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] - else - [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] - end + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS varinfo = VarInfo(m) linked_varinfo = DynamicPPL.link(varinfo, m) - f = FastLDF(m, getlogjoint_internal, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = [p for p in linked_varinfo[:]] # Calculate reference logp + gradient of logp using ForwardDiff @@ -173,7 +165,7 @@ end test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = FastLDF(model(); adtype=adtype) + ldf = LogDensityFunction(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..b017c658d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -1,11 +1,12 @@ using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad +using DynamicPPL: OrderedDict using ADTypes: AutoEnzyme using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( +ADTYPES = OrderedDict( "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), "EnzymeReverse" => diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl deleted file mode 100644 index fbd868f71..000000000 --- a/test/logdensityfunction.jl +++ /dev/null @@ -1,49 +0,0 @@ -using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff - -@testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model - end -end - -@testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - - vi = first(varinfos) - theta = vi[:] - ldf_joint = DynamicPPL.LogDensityFunction(model) - @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) - ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) - @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) - ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) - @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ - loglikelihood(model, vi) - - @testset "$(varinfo)" for varinfo in varinfos - # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... - logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) - θ = varinfo[:] - # ... 
because it has to match with `logjoint(model, vi)`, which always returns
-            # the unlinked value
-            @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo)
-            @test LogDensityProblems.dimension(logdensity) == length(θ)
-        end
-    end
-
-    @testset "capabilities" begin
-        model = DynamicPPL.TestUtils.DEMO_MODELS[1]
-        ldf = DynamicPPL.LogDensityFunction(model)
-        @test LogDensityProblems.capabilities(typeof(ldf)) ==
-            LogDensityProblems.LogDensityOrder{0}()
-
-        ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
-        @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) ==
-            LogDensityProblems.LogDensityOrder{1}()
-    end
-end
diff --git a/test/model.jl b/test/model.jl
index 6da5ea246..2830a131e 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -27,7 +27,6 @@ end
 
 is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false
 is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true
-is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
 
 const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
 
@@ -314,7 +313,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         @test logjoint(model, x) !=
             DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
         # Ensure `varnames` is implemented.
-        vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}())))
+        vi = last(DynamicPPL.init!!(model, VarInfo()))
         @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model))
         # Ensure `posterior_mean` is implemented.
         @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
@@ -492,12 +491,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         end
 
         model = product_dirichlet()
-        varinfos = [
-            DynamicPPL.untyped_varinfo(model),
-            DynamicPPL.typed_varinfo(model),
-            DynamicPPL.typed_simple_varinfo(model),
-            DynamicPPL.untyped_simple_varinfo(model),
-        ]
+        varinfos = [DynamicPPL.untyped_varinfo(model), DynamicPPL.typed_varinfo(model)]
         @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
             logjoint = getlogjoint(varinfo) # unlinked space
             varinfo_linked = DynamicPPL.link(varinfo, model)
diff --git a/test/runtests.jl b/test/runtests.jl
index 1474b426a..47cff58c2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,7 +13,7 @@
 using ForwardDiff
 using LogDensityProblems
 using MacroTools
 using MCMCChains
+using Mooncake
 using StableRNGs
 using ReverseDiff
-using Mooncake
@@ -54,12 +55,13 @@ include("test_util.jl")
         include("compiler.jl")
         include("varnamedvector.jl")
         include("varinfo.jl")
-        include("simple_varinfo.jl")
         include("model.jl")
         include("distribution_wrappers.jl")
-        include("logdensityfunction.jl")
         include("linking.jl")
         include("serialization.jl")
+    end
+
+    if GROUP == "All" || GROUP == "Group2"
         include("pointwise_logdensities.jl")
         include("lkj.jl")
         include("contexts.jl")
@@ -69,9 +71,7 @@ include("test_util.jl")
         include("submodels.jl")
         include("chains.jl")
         include("bijector.jl")
-    end
-
-    if GROUP == "All" || GROUP == "Group2"
+        include("fasteval.jl")
        @testset "extensions" begin
            include("ext/DynamicPPLMCMCChainsExt.jl")
            include("ext/DynamicPPLJETExt.jl")
@@ -80,8 +80,6 @@ include("test_util.jl")
        @testset "ad" begin
            include("ext/DynamicPPLForwardDiffExt.jl")
            include("ext/DynamicPPLMooncakeExt.jl")
-            include("ad.jl")
-            include("fasteval.jl")
        end
        @testset "prob and logprob macro" begin
            @test_throws ErrorException prob"..."
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
deleted file mode 100644
index 488cb8941..000000000
--- a/test/simple_varinfo.jl
+++ /dev/null
@@ -1,320 +0,0 @@
-@testset "simple_varinfo.jl" begin
-    @testset "constructor & indexing" begin
-        @testset "NamedTuple" begin
-            svi = SimpleVarInfo(; m=1.0)
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(; m=[1.0])
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(; m=(a=[1.0],))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-
-            svi = SimpleVarInfo{Float32}(; m=1.0)
-            @test getlogjoint(svi) isa Float32
-
-            svi = SimpleVarInfo((m=1.0,))
-            svi = accloglikelihood!!(svi, 1.0)
-            @test getlogjoint(svi) == 1.0
-        end
-
-        @testset "Dict" begin
-            svi = SimpleVarInfo(Dict(@varname(m) => 1.0))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(Dict(@varname(m) => [1.0]))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],)))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-
-            svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0]))
-            # Now we only have a variable `m.a` which is subsumed by `m`,
-            # but we can't guarantee that we have the "entire" `m`.
-            @test !haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-        end
-
-        @testset "VarNamedVector" begin
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0]))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0]))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-            # The implementation of haskey and getvalue for VarNamedVector is incomplete;
-            # the next test is here to remind us of that.
-            svi = SimpleVarInfo(
-                push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0])
-            )
-            @test_broken !haskey(svi, @varname(m.a.b.c.d))
-        end
-    end
-
-    @testset "link!! & invlink!!
on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), - ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), - ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), - ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), - ) - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) - end - vi = last(DynamicPPL.evaluate!!(model, vi)) - - # Calculate ground truth - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_constrained... - ) - - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_unlinked = getlogjoint(vi_linked) - lp_linked = getlogjoint_internal(vi_linked) - @test lp_linked ≈ lp_linked_true - @test lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_linked) ≈ lp_unlinked - - # `invlink!!` - vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_unlinked = getlogjoint(vi_invlinked) - also_lp_unlinked = getlogjoint_internal(vi_invlinked) - @test lp_unlinked ≈ lp_unlinked_true - @test also_lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_invlinked) ≈ lp_unlinked - - # Should result in same values. - @test all( - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ) - end - end - - @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - # We might need to pre-allocate for the variable `m`, so we need - # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) - svi_dict = SimpleVarInfo(VarInfo(model), Dict) - vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) - vnv = push!!(vnv, VarName{k}() => v) - end - svi_vnv = SimpleVarInfo(vnv) - - @testset "$name" for (name, svi) in ( - ("NamedTuple", svi_nt), - ("Dict", svi_dict), - ("VarNamedVector", svi_vnv), - # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), - ) - # Random seed is set in each `@testset`, so we need to sample - # a new realization for `m` here. - retval = model() - - ### Sampling ### - # Sample a new varinfo! - _, svi_new = DynamicPPL.init!!(model, svi) - - # Realization for `m` should be different wp. 1. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_new[vn] != get(retval, vn) - end - - # Logjoint should be non-zero wp. 1. - @test getlogjoint(svi_new) != 0 - - ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.is_transformed(svi) - _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - # Make sure that these two computation paths provide the same - # transformed values. - @test values_eval == _values_prior - else - logpri_true = DynamicPPL.TestUtils.logprior_true( - model, values_eval_constrained... 
- ) - logπ_true = DynamicPPL.TestUtils.logjoint_true( - model, values_eval_constrained... - ) - values_eval = values_eval_constrained - end - - # No logabsdet-jacobian correction needed for the likelihood. - loglik_true = DynamicPPL.TestUtils.loglikelihood_true( - model, values_eval_constrained... - ) - - # Update the realizations in `svi_new`. - svi_eval = svi_new - for vn in DynamicPPL.TestUtils.varnames(model) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) - end - - # Reset the logp accumulators. - svi_eval = DynamicPPL.resetaccs!!(svi_eval) - - # Compute `logjoint` using the varinfo. - logπ = logjoint(model, svi_eval) - logpri = logprior(model, svi_eval) - loglik = loglikelihood(model, svi_eval) - - # Values should not have changed. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_eval[vn] == get(values_eval, vn) - end - - # Compare log-probability computations. - @test logpri ≈ logpri_true - @test loglik ≈ loglik_true - @test logπ ≈ logπ_true - end - end - - @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - - # Initialize. - svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.init!!(model, svi_nt)) - svi_vnv = DynamicPPL.set_transformed!!( - SimpleVarInfo(DynamicPPL.VarNamedVector()), true - ) - svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) - - for svi in (svi_nt, svi_vnv) - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` - - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) - - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 - end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint_internal(svi) - # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 - @test lp ≈ lp_true atol = 1.2e-5 - end - end - end - - @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() - - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Initialize varinfo and link. - vi_linked = DynamicPPL.link!!(vi, model) - - # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.is_transformed( - DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) - ) - - # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) - @test !DynamicPPL.is_transformed(vi_result) - - # Set the values to something that is out of domain if we're in constrained space. - for vn in keys(vi) - vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) - end - - # NOTE: Evaluating a linked VarInfo, **specifically when the transformation - # is static**, will result in an invlinked VarInfo. This is because of - # `maybe_invlink_before_eval!`, which only invlinks if the transformation - # is static. 
(src/abstract_varinfo.jl) - retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ - DynamicPPL.tovec(retval.s) # `s` is unconstrained in original - @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) - ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result - - # `m` should not be transformed. - @test vi_linked[@varname(m)] == retval.m - @test vi_unlinked_again[@varname(m)] == retval.m - - # Get ground truths - retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.s, retval.m - ) - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ - DynamicPPL.tovec(retval_unconstrained.s) - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ - DynamicPPL.tovec(retval_unconstrained.m) - - # The unlinked varinfo should hold the unlinked logp. - lp_unlinked = getlogjoint(vi_unlinked_again) - @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true - end - end -end diff --git a/test/test_util.jl b/test/test_util.jl index 94fdbd744..911de1079 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -25,20 +25,6 @@ function short_varinfo_name(vi::DynamicPPL.NTVarInfo) end short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" -function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) - return "SimpleVarInfo{<:NamedTuple,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) - return "SimpleVarInfo{<:OrderedDict,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) - return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -end -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" -end # convenient functions for testing model.jl # function to modify the representation of values based on their length diff --git a/test/varinfo.jl b/test/varinfo.jl index a1a1b370f..f9ce7171f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,17 +1,7 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `identity`. - # So we just check that the original keys are present. - for vn in vns - # Should have all the original keys. - @test haskey(varinfo, vn) - end - else - vns_varinfo = keys(varinfo) - # Should be equivalent. - @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) - end + vns_varinfo = keys(varinfo) + # Should be equivalent. 
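+    # (The union and intersection of two collections coincide exactly when they
+    # contain the same elements, so this is an order-insensitive set-equality check.)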
+ @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) end """ @@ -100,9 +90,6 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) - test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict{VarName,Any}())) - test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "get/set/acclogp" begin @@ -129,9 +116,6 @@ end vi = VarInfo() test_varinfo_logp!(vi) test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) - test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(Dict())) - test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "logp accumulators" begin @@ -444,19 +428,6 @@ end vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) - - ### `SimpleVarInfo` - ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -514,20 +485,6 @@ end model, value_true, varnames; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - if DynamicPPL.has_varnamedvector(varinfo) && mutating # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. @test_broken false @@ -591,12 +548,6 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the inconcrete `SimpleVarInfo` types, since checking for type - # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} - continue - end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) end end @@ -618,9 +569,6 @@ end model, model(), vns; include_threadsafe=true ) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -648,33 +596,18 @@ end [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. - vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. 
- vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, VarName[]) @test isempty(varinfo_subset) end @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, vns_subset) # Should now only contain the variables in `vns_subset`. check_varinfo_keys(varinfo_subset, vns_subset) @@ -709,7 +642,7 @@ end end @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in - vns_supported + vns_supported_standard varinfo_subset = subset(varinfo, vns_subset) vns_subset_reversed = reverse(vns_subset) varinfo_subset_reversed = subset(varinfo, vns_subset_reversed) @@ -718,15 +651,6 @@ end @test varinfo_subset[:] == ground_truth end end - - # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end end @testset "merge" begin