diff --git a/HISTORY.md b/HISTORY.md index 0a0abd543..04edb7395 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,175 @@ # DynamicPPL Changelog +## 0.40 + +### `VarNamedTuple` + +DynamicPPL now exports a new type, called `VarNamedTuple`, which stores values keyed by `VarName`s. +Alongside it, a few new functions for working with it are exported: `map_values!!`, `map_pairs!!`, and `apply!!`. +Our documentation's Internals section now has a page about `VarNamedTuple`, how it works, and what it's good for. + +`VarNamedTuple` is now used internally in many different places: in `VarInfo`, in `values_as_in_model`, in `LogDensityFunction`, etc. +Almost all of the changes below are consequences of switching over to using `VarNamedTuple` for various features internally. + +### Overhaul of `VarInfo` + +DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types. +Previously, there were many variants: `VarInfo`, both "typed" and "untyped", and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. +These have all been replaced by a rewritten implementation of `VarInfo`. +While the basics of the `VarInfo` interface remain the same, this brings with it many changes: + +#### No more multiple `AbstractVarInfo` types + +`SimpleVarInfo`, `untyped_varinfo`, `typed_varinfo`, and many other constructors, some exported, some not, have been removed. +The remaining constructor is `VarInfo(...)`, which can take a model or a collection of values. +See the docstring for details. + +Some related types and functions that weren't exported, but may have been in use, have also been removed, most notably `VarNamedVector` and its associated functions like `loosen_types!!` and `tighten_types!!`. + +#### Setting and getting values + +Previously, the various `AbstractVarInfo` types had a multitude of functions for setting values: +`push!!`, `push!`, `setindex!`, `update!`, `update_internal!`, `insert_internal!`, `reset!`, etc.
+These have all been replaced by three functions: + + - `setindex!!` is the one to use for simply setting a variable in `VarInfo` to a known value. It works regardless of whether the variable already exists. + - `setindex_internal!!` is the one to use for setting the internal, vectorised representation of a variable. See the docstring for details. + - `setindex_with_dist!!` is to be used when you want to set a value, but choose the internal representation based on the distribution that the value is a sample from. + +The order of the arguments for some of these functions has also changed, and now more closely matches the usual convention for `setindex!!`. + +Note that `setindex!` (with a single `!`) is not defined, and thus you can't do `varinfo[varname] = new_value`. + +`unflatten` works as before, but has been renamed to `unflatten!!`, since it may mutate the first argument and alias memory with the second argument (it uses views rather than copies of the input vector). + +#### Linking is now safer + +`link!!` and `invlink!!` on `VarInfo` used to assume that the prior distribution of a variable didn't change from one execution to another (as it does in e.g. `truncated(dist; lower=x)` where `x` is a random variable). +This assumption is no longer made, so linking should be safer. +The cost is that calls to `link!!` and `invlink!!` (and the non-mutating versions) now trigger a model evaluation, to determine the correct priors to use. + +#### Other miscellanea + + - The `Experimental` module had functions like `Experimental.determine_suitable_varinfo` for determining which `AbstractVarInfo` type was suitable for a given model. This is now redundant and has been removed. + - `Bijectors.bijector(::Model)`, which creates a bijector from the vectorised variable space of the model to the linked variable space of the model, now has slightly different optional arguments. See the docstring for details.
+ - `NamedDist` no longer allows variable names with `Colon`s in them, such as `x[:]`. + +There are probably also changes to the `VarInfo` interface that we've neglected to document here, since the overhaul of `VarInfo` has been quite extensive. +If anything related to `VarInfo` is behaving unexpectedly, e.g. the arguments or return type of a function seem to have changed, please check the docstring, which should be comprehensive. + +#### Performance benefits + +The purpose of this overhaul of `VarInfo` is to simplify the code and improve performance. + +TODO(mhauru) Add some basic summary of what has gotten faster by how much. + +### Changes to indexing random variables with square brackets + +0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. ones that use indexing with square brackets. +Most of this is invisible to users, but it has some effects that show up on the surface. +The gist of the changes is that any indexing by square brackets is now implicitly assumed to be indexing into a regular `Base.Array`, with 1-based indexing. +In general, the new rules about what is and isn't allowed are stricter: some syntax that used to be allowed is now forbidden, but what remains allowed is guaranteed to work correctly. +(Previously there were some sharp edges around these sorts of variable names.) + +#### No more linear indexing of multidimensional arrays + +Previously you could do this: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1] ~ Normal() +x[1, 1] ~ Normal() +``` + +Now you can't; this will error. +If you first create a variable like `x[1]`, DynamicPPL from then on assumes that this variable only takes a single index (like a `Vector`). +It will then error if you try to index the same variable with any other number of indices.
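+
+Under the new rules, the safe pattern is to pick one indexing shape per variable and stick to it. A hedged sketch (the model name `g` is hypothetical; this pattern was valid before 0.40 as well):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function g()
+    x = Vector{Float64}(undef, 4)
+    # x is always indexed with exactly one index, matching its Vector shape,
+    # so none of these statements conflict under the new rules.
+    for i in 1:4
+        x[i] ~ Normal()
+    end
+end
+```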
+ +The same logic also bans this, which likewise was previously allowed: + +```julia +x = Array{Float64,2}(undef, (2, 2)) +x[1, 1, 1] ~ Normal() +x[1, 1] ~ Normal() +``` + +This relied on Julia allowing trailing indices of `1`. + +Note that the above models were previously quite dangerous and easy to misuse, because DynamicPPL was oblivious to the fact that e.g. `x[1]` and `x[1,1]` refer to the same element. +Both of the above examples previously created 2-dimensional models, with two distinct random variables, one of which effectively overwrote the other in the model body. + +TODO(mhauru) This may cause surprising issues when using `eachindex`, which is generally encouraged, e.g. + +```julia +x = Array{Float64,2}(undef, (3, 3)) +for i in eachindex(x) + x[i] ~ Normal() +end +``` + +Maybe we should fix linear indexing before releasing? + +#### No more square bracket indexing with arbitrary keys + +Previously you could do this: + +```julia +x = Dict() +x["a"] ~ Normal() +``` + +Now you can't; this will error. +This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `"a"` is not a valid index. +You can still use a dictionary on the left-hand side of a `~` statement as long as the indices are valid indices into an `Array`, e.g. integers. + +#### No more unusually indexed arrays, such as `OffsetArrays` + +Previously you could do this: + +```julia +using OffsetArrays +x = OffsetArray(Vector{Float64}(undef, 3), -3) +x[-2] ~ Normal() +0.0 ~ Normal(x[-2]) +``` + +Now you can't; this will error. +This is because DynamicPPL now assumes that if you are indexing with square brackets, you are dealing with an `Array`, for which `-2` is not a valid index. + +#### The above limitations are not fundamental + +These new restrictions on what sorts of variable names are allowed aren't fundamental. +With some effort we could add support for linear indexing, this time done properly, so that e.g.
`x[1,1]` and `x[1]` would be the same variable. +Likewise, we could manually add structures to support indexing into dictionaries or `OffsetArrays`. +If this would be useful to you, let us know. + +#### This only affects `~` statements + +You can still use any arbitrary indexing within your model in statements that don't involve `~`. +For instance, you can use `OffsetArray`s, or linear indexing, as long as you don't put them on the left-hand side of a `~`. + +#### Performance benefits + +The upside of all these new limitations is that models that use square bracket indexing are now faster. +For instance, take the following model: + +```julia +@model function f() + x = Vector{Float64}(undef, 1000) + for i in eachindex(x) + x[i] ~ Normal() + end + return 0.0 ~ Normal(sum(x)) +end +``` + +Evaluating the log joint for this model has become about 3 times faster in v0.40. + +#### Robustness benefits + +TODO(mhauru) Add an example here for how this improves `condition`ing, once `condition` uses `VarNamedTuple`. + ## 0.39.12 When constructing an `MCMCChains.Chains`, sampler statistics that are not `Union{Real,Missing}` are dropped from the chain (previously this would cause chain construction to fail).
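+
+As an illustration of the `setindex!!` API introduced in 0.40, here is a hedged sketch based on the description above (exact signatures and argument order may differ; see the docstrings):
+
+```julia
+using DynamicPPL, Distributions
+
+@model demo() = x ~ Normal()
+
+vi = VarInfo(demo())
+# setindex!! works whether or not `x` already has a value stored.
+# Note that `vi[@varname(x)] = 1.5` (single-! setindex!) is not supported.
+vi = setindex!!(vi, 1.5, @varname(x))
+vi[@varname(x)]
+```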
diff --git a/Project.toml b/Project.toml index 5be210c50..14637330a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.39.12" +version = "0.40" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -30,7 +30,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" @@ -40,7 +39,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] -DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMooncakeExt = ["Mooncake"] @@ -48,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5.10" -AbstractPPL = "0.13.1" +AbstractPPL = "0.14" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.15.11" @@ -62,7 +60,6 @@ DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" InteractiveUtils = "1" -JET = "0.9, 0.10, 0.11" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 523889a7a..95d905da6 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -24,7 +24,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" Chairmarks = "1.3.1" Distributions = "0.25.117" -DynamicPPL = "0.39" +DynamicPPL = "0.40" Enzyme = "0.13" ForwardDiff = "1" JSON = "1.3.0" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e8ffa7e0b..5be32fdef 100644 --- 
a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -9,9 +9,7 @@ using StableRNGs: StableRNG rng = StableRNG(23) -colnames = [ - "Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)" -] +colnames = ["Model", "Dim", "AD Backend", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"] function print_results(results_table; to_json=false) if to_json # Print to the given file as JSON @@ -58,31 +56,26 @@ function run(; to_json=false) end # Specify the combinations to test: - # (Model Name, model instance, VarInfo choice, AD backend, linked) + # (Model Name, model instance, AD backend, linked) chosen_combinations = [ ( "Simple assume observe", Models.simple_assume_observe(randn(rng)), - :typed, :forwarddiff, false, ), - ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), - ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), - ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), - ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), - ("Multivariate 10k", multivariate10k, :typed, :mooncake, true), - ("Dynamic", Models.dynamic(), :typed, :mooncake, true), - ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), - ("LDA", lda_instance, :typed, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, false), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, true), + ("Smorgasbord", 
smorgasbord_instance, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :mooncake, true), + ("Smorgasbord", smorgasbord_instance, :enzyme, true), + ("Loop univariate 1k", loop_univariate1k, :mooncake, true), + ("Multivariate 1k", multivariate1k, :mooncake, true), + ("Loop univariate 10k", loop_univariate10k, :mooncake, true), + ("Multivariate 10k", multivariate10k, :mooncake, true), + ("Dynamic", Models.dynamic(), :mooncake, true), + ("Submodel", Models.parent(randn(rng)), :mooncake, true), + ("LDA", lda_instance, :reversediff, true), ] # Time running a model-like function that does not use DynamicPPL, as a reference point. @@ -94,13 +87,13 @@ function run(; to_json=false) @info "Reference evaluation time: $(reference_time) seconds" results_table = Tuple{ - String,Int,String,String,Bool,Union{Float64,Missing},Union{Float64,Missing} + String,Int,String,Bool,Union{Float64,Missing},Union{Float64,Missing} }[] - for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" + for (model_name, model, adbackend, islinked) in chosen_combinations + @info "Running benchmark for $model_name, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try - results = benchmark(model, varinfo_choice, adbackend, islinked) + results = benchmark(model, adbackend, islinked) @info " t(eval) = $(results.primal_time)" @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), @@ -115,7 +108,6 @@ function run(; to_json=false) model_name, model_dimension(model, islinked), string(adbackend), - string(varinfo_choice), islinked, relative_eval_time, relative_ad_eval_time, @@ -131,9 +123,8 @@ struct TestCase model_name::String dim::Integer ad_backend::String - varinfo::String linked::Bool - TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:5])...) + TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:4])...) 
end function combine(head_filename::String, base_filename::String) head_results = try @@ -148,23 +139,22 @@ function combine(head_filename::String, base_filename::String) Dict{String,Any}[] end @info "Loaded $(length(base_results)) results from $base_filename" - # Identify unique combinations of (Model, Dim, AD Backend, VarInfo, Linked) + # Identify unique combinations of (Model, Dim, AD Backend, Linked) head_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in head_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in head_results ) base_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in base_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in base_results ) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.ad_backend)) ) results_table = Tuple{ String, Int, String, - String, Bool, String, String, @@ -179,12 +169,12 @@ function combine(head_filename::String, base_filename::String) sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ - EmptyCells(5), + EmptyCells(4), MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., sublabels..., sublabels..., sublabels...], + [colnames[1:4]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -211,7 +201,6 @@ function combine(head_filename::String, base_filename::String) c.model_name, c.dim, c.ad_backend, - c.varinfo, c.linked, sprint_float(base_eval), sprint_float(head_eval), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 0dc7ece6e..6bb8672c9 100644 --- 
a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, VarName using DynamicPPL: DynamicPPL using DynamicPPL.TestUtils.AD: run_ad, NoTest using ADTypes: ADTypes @@ -23,7 +23,7 @@ Return the dimension of `model`, accounting for linking, if any. """ function model_dimension(model, islinked) vi = VarInfo() - model(StableRNG(23), vi) + vi = last(DynamicPPL.init!!(StableRNG(23), model, vi)) if islinked vi = DynamicPPL.link(vi, model) end @@ -52,53 +52,24 @@ function to_backend(x::Union{AbstractString,Symbol}) end """ - benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) + benchmark(model, adbackend::Symbol, islinked::Bool) -Benchmark evaluation and gradient calculation for `model` using the selected varinfo type -and AD backend. - -Available varinfo choices: - • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` - • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) +Benchmark evaluation and gradient calculation for `model` using the selected AD backend. The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). `islinked` determines whether to link the VarInfo for evaluation. 
""" -function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) +function benchmark(model, adbackend::Symbol, islinked::Bool) rng = StableRNG(23) - + vi = VarInfo(rng, model) adbackend = to_backend(adbackend) - - vi = if varinfo_choice == :untyped - DynamicPPL.untyped_varinfo(rng, model) - elseif varinfo_choice == :typed - DynamicPPL.typed_varinfo(rng, model) - elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) - elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) - elseif varinfo_choice == :typed_vector - DynamicPPL.typed_vector_varinfo(rng, model) - elseif varinfo_choice == :untyped_vector - DynamicPPL.untyped_vector_varinfo(rng, model) - else - error("Unknown varinfo choice: $varinfo_choice") - end - - adbackend = to_backend(adbackend) - if islinked vi = DynamicPPL.link(vi, model) end - return run_ad( model, adbackend; varinfo=vi, benchmark=true, test=NoTest(), verbose=false ) end -end # module +end diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl index 2c881aa95..76d4b2e93 100644 --- a/benchmarks/src/Models.jl +++ b/benchmarks/src/Models.jl @@ -2,7 +2,7 @@ Models for benchmarking Turing.jl. Each model returns a NamedTuple of all the random variables in the model that are not -observed (this is used for constructing SimpleVarInfos). +observed. 
""" module Models diff --git a/docs/Project.toml b/docs/Project.toml index 10a4a5c8a..cad7c2c36 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,6 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" @@ -16,15 +15,14 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.14" Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.39" +DynamicPPL = "0.40" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" -JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" MCMCChains = "5, 6, 7" MarginalLogDensities = "0.4" diff --git a/docs/src/api.md b/docs/src/api.md index 686549e9b..5cd94fccd 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -333,25 +333,18 @@ Please see the documentation of [AbstractPPL.jl](https://github.com/TuringLang/A ### Data Structures of Variables -DynamicPPL provides different data structures used in for storing samples and accumulation of the log-probabilities, all of which are subtypes of [`AbstractVarInfo`](@ref). +DynamicPPL provides a data structure for storing samples and accumulating log-probabilities, called [`VarInfo`](@ref). +The interface that `VarInfo` respects is described by the abstract type [`AbstractVarInfo`](@ref). +Internally, DynamicPPL also uses a couple of other subtypes of `AbstractVarInfo`. ```@docs AbstractVarInfo ``` -But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.
- -#### `VarInfo` - ```@docs VarInfo -``` - -```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo -DynamicPPL.untyped_vector_varinfo -DynamicPPL.typed_vector_varinfo +DynamicPPL.TransformedValue +DynamicPPL.setindex_with_dist!! ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). @@ -363,14 +356,18 @@ is_transformed set_transformed!! ``` -```@docs -Base.empty! -``` +#### `VarNamedTuple`s -#### `SimpleVarInfo` +`VarInfo` is only a thin wrapper around [`VarNamedTuple`](@ref), which stores arbitrary data keyed by `VarName`s. +For more details on `VarNamedTuple`, see the Internals section of our documentation. ```@docs -SimpleVarInfo +DynamicPPL.VarNamedTuples.VarNamedTuple +DynamicPPL.VarNamedTuples.vnt_size +DynamicPPL.VarNamedTuples.apply!! +DynamicPPL.VarNamedTuples.map_pairs!! +DynamicPPL.VarNamedTuples.map_values!! +DynamicPPL.VarNamedTuples.PartialArray ``` ### Accumulators @@ -416,19 +413,10 @@ accloglikelihood!! ```@docs keys getindex -push!! empty!! isempty DynamicPPL.getindex_internal -DynamicPPL.setindex_internal! -DynamicPPL.update_internal! -DynamicPPL.insert_internal! -DynamicPPL.length_internal -DynamicPPL.reset! -DynamicPPL.update! -DynamicPPL.insert! -DynamicPPL.loosen_types!! -DynamicPPL.tighten_types!! +DynamicPPL.setindex_internal!! ``` ```@docs @@ -461,7 +449,7 @@ DynamicPPL.maybe_invlink_before_eval!! ```@docs Base.merge(::AbstractVarInfo) DynamicPPL.subset -DynamicPPL.unflatten +DynamicPPL.unflatten!! 
``` ### Evaluation Contexts @@ -546,15 +534,6 @@ init get_param_eltype ``` -### Choosing a suitable VarInfo - -There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: - -```@docs -DynamicPPL.Experimental.determine_suitable_varinfo -DynamicPPL.Experimental.is_suitable_varinfo -``` - ### Converting VarInfos to/from chains It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index b04913aaf..6d87e5edc 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -8,293 +8,45 @@ VarInfo It contains - - a `logp` field for accumulation of the log-density evaluation, and - - a `metadata` field for storing information about the realizations of the different variables. + - a `VarNamedTuple` field called `values`, + - an `AccumulatorTuple` field called `accs`, to hold accumulators. -Representing `logp` is fairly straight-forward: we'll just use a `Real` or an array of `Real`, depending on the context. +`values` stores information related to the values of the individual random variables, while `accs` keeps track of information that is accumulated over the course of a model evaluation. -**Representing `metadata` is a bit trickier**. This is supposed to contain all the necessary information for each `VarName` to enable the different executions of the model + extraction of different properties of interest after execution, e.g. the realization / value corresponding to a variable `@varname(x)`. +Variables are recognised by their `VarName`.
+We want to work with `VarName`s rather than something like `Symbol` or `String`, as `VarName` contains additional structural information. +For instance, a `Symbol("x[1]")` can result from either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. +`VarName`s also allow things such as setting values for `x[1]` and `x[2]` and getting a value for `x` as a whole. -!!! note - - We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information, e.g. a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. +To ensure that `VarInfo` is simple and intuitive to work with, we want it to replicate the following functionality of `Dict`: -To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`, and hence the underlying `metadata`, to replicate the following functionality of `Dict`: + - `keys(::VarInfo)`: return all the `VarName`s present. + - `haskey(::VarInfo, ::VarName)`: check if a particular `VarName` is present. + - `getindex(::VarInfo, ::VarName)`: return the realization corresponding to a particular `VarName`. + - `setindex!!(::VarInfo, val, ::VarName)`: set the realization corresponding to a particular `VarName`. + - `empty!!(::VarInfo)`: delete all data. + - `merge(::VarInfo, ::VarInfo)`: merge two containers according to similar rules as `Dict`. - - `keys(::Dict)`: return all the `VarName`s present in `metadata`. - - `haskey(::Dict)`: check if a particular `VarName` is present in `metadata`. - - `getindex(::Dict, ::VarName)`: return the realization corresponding to a particular `VarName`. - - `setindex!(::Dict, val, ::VarName)`: set the realization corresponding to a particular `VarName`. - - `push!(::Dict, ::Pair)`: add a new key-value pair to the container. - - `delete!(::Dict, ::VarName)`: delete the realization corresponding to a particular `VarName`.
- - `empty!(::Dict)`: delete all realizations in `metadata`. - - `merge(::Dict, ::Dict)`: merge two `metadata` structures according to similar rules as `Dict`. +Note that we only define the BangBang methods such as `setindex!!`, rather than the mutating ones like `setindex!`. +This is due to the design of `VarNamedTuple`, which is explained on its own page in these docs. -*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. One can access a vectorised version of a variable's value with the following vector-like functions: +*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. +One can access a vectorised version of a variable's value with the following vector-like functions: - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable. - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables. - - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values - - `setindex_internal!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. - - `setindex_internal!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values - - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`. + - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised. -Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space.
`getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able sample in unconstrained space. One can also manually set a transformation by giving `setindex_internal!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable. +Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space. +`getindex_internal` and `setindex_internal!!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able to sample in unconstrained space. -Finally, we want want the underlying representation used in `metadata` to have a few performance-related properties: +Finally, we want the underlying storage to have a few performance-related properties: 1. Type-stable when possible, but functional when not. 2. Efficient storage and iteration when possible, but functional when not. The "but functional when not" is important as we want to support arbitrary models, which means that we can't always have these performance properties. -In the following sections, we'll outline how we achieve this in [`VarInfo`](@ref). - -## Type-stability - -Ensuring type-stability is somewhat non-trivial to address since we want this to be the case even when models mix continuous (typically `Float64`) and discrete (typically `Int`) variables. - -Suppose we have an implementation of `metadata` which implements the functionality outlined in the previous section. The way we approach this in `VarInfo` is to use a `NamedTuple` with a separate `metadata` *for each distinct `Symbol` used*.
For example, if we have a model of the form - -```@example varinfo-design -using DynamicPPL, Distributions, FillArrays - -@model function demo() - x ~ product_distribution(Fill(Bernoulli(0.5), 2)) - y ~ Normal(0, 1) - return nothing -end -``` - -then we construct a type-stable representation by using a `NamedTuple{(:x, :y), Tuple{Vx, Vy}}` where - - - `Vx` is a container with `eltype` `Bool`, and - - `Vy` is a container with `eltype` `Float64`. - -Since `VarName` contains the `Symbol` used in its type, something like `getindex(varinfo, @varname(x))` can be resolved to `getindex(varinfo.metadata.x, @varname(x))` at compile-time. - -For example, with the model above we have - -```@example varinfo-design -# Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) -typeof(varinfo_untyped.metadata) -``` - -```@example varinfo-design -# Type-stable `VarInfo` -varinfo_typed = DynamicPPL.typed_varinfo(demo()) -typeof(varinfo_typed.metadata) -``` - -They both work as expected but one results in concrete typing and the other does not: - -```@example varinfo-design -varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)] -``` - -```@example varinfo-design -varinfo_typed[@varname(x)], varinfo_typed[@varname(y)] -``` - -Notice that the untyped `VarInfo` uses `Vector{Real}` to store the boolean entries while the typed uses `Vector{Bool}`. This is because the untyped version needs the underlying container to be able to handle both the `Bool` for `x` and the `Float64` for `y`, while the typed version can use a `Vector{Bool}` for `x` and a `Vector{Float64}` for `y` due to its usage of `NamedTuple`. - -!!! warning - - Of course, this `NamedTuple` approach is *not* necessarily going to help us in scenarios where the `Symbol` does not correspond to a unique type, e.g. 
- - ```julia - x[1] ~ Bernoulli(0.5) - x[2] ~ Normal(0, 1) - ``` - - In this case we'll end up with a `NamedTuple((:x,), Tuple{Vx})` where `Vx` is a container with `eltype` `Union{Bool, Float64}` or something worse. This is *not* type-stable but will still be functional. - - In practice, we rarely observe such mixing of types, therefore in DynamicPPL, and more widely in Turing.jl, we use a `NamedTuple` approach for type-stability with great success. - -!!! warning - - Another downside with such a `NamedTuple` approach is that if we have a model with lots of tilde-statements, e.g. `a ~ Normal()`, `b ~ Normal()`, ..., `z ~ Normal()` will result in a `NamedTuple` with 27 entries, potentially leading to long compilation times. - - For these scenarios it can be useful to fall back to "untyped" representations. - -Hence we obtain a "type-stable when possible"-representation by wrapping it in a `NamedTuple` and partially resolving the `getindex`, `setindex!`, etc. methods at compile-time. When type-stability is *not* desired, we can simply use a single `metadata` for all `VarName`s instead of a `NamedTuple` wrapping a collection of `metadata`s. - -## Efficient storage and iteration - -Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`DynamicPPL.VarNamedVector`](@ref): - -```@docs -DynamicPPL.VarNamedVector -``` - -In a [`DynamicPPL.VarNamedVector{<:VarName,T}`](@ref), we achieve the desiderata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. - -This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: - - - `varnames::Vector{<:VarName}`: the `VarName`s in the order they appear in the `Vector{T}`. 
- - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. - - `transforms::Vector`: the transforms associated with each `VarName`. - -Mutating functions, e.g. `setindex_internal!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: - - 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. - - 2. If `vn` is already present in `vnv`: - - 1. If `val` has the *same length* as the existing value for `vn`: replace existing value. - 2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field. - 3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. - -This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. - -For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example: - -```julia -# Construct a `VarInfo` with types inferred from `model`. -varinfo = VarInfo(model) - -# Repeatedly sample from `model`. -for _ in 1:num_samples - rand!(rng, model, varinfo) - - # Do something with `varinfo`. - # ... -end -``` - -There are typically a few scenarios where we encounter changing representation sizes of a random variable `x`: - - 1. We're working with a transformed version `x` which is represented in a lower-dimensional space, e.g. transforming a `x ~ LKJ(2, 1)` to unconstrained `y = f(x)` takes us from 2-by-2 `Matrix{Float64}` to a 1-length `Vector{Float64}`. 
- 2. `x` has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of `x` can vary widly between every realization of the `Model`. - -In scenario (1), we're usually *shrinking* the representation of `x`, and so we end up not making any allocations for the underlying `Vector{T}` but instead just marking the redundant part as "inactive". - -In scenario (2), we end up increasing the allocated memory for the randomly sized `x`, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand. - -To help with this, we have the following functions: - -```@docs -DynamicPPL.has_inactive -DynamicPPL.num_inactive -DynamicPPL.num_allocated -DynamicPPL.is_contiguous -DynamicPPL.contiguify! -``` - -For example, one might encounter the following scenario: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - println( - "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - if DynamicPPL.num_allocated(vnv) > 10 - DynamicPPL.contiguify!(vnv) - end - println( - "After insertion #$(i) of length $(length(x)): 
number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. - -!!! note - - Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. - -Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: - -```@example varinfo-design -# Type-unstable -varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) -varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] -``` - -```@example varinfo-design -# Type-stable -varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) -varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] -``` - -If we now try to `delete!` `@varname(x)` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -```@example varinfo-design -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -# `delete!` -DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x)) -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -Or insert a differently-sized value for `@varname(x)` - -```@example varinfo-design -DynamicPPL.insert!(varinfo_untyped_vnv.metadata, fill(true, 1), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` 
- -```@example varinfo-design -DynamicPPL.update!(varinfo_untyped_vnv.metadata, fill(true, 4), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` - -### Performance summary - -In the end, we have the following "rough" performance characteristics for `VarNamedVector`: - -| Method | Is blazingly fast? | -|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| -| `getindex` | ${\color{green} \checkmark}$ | -| `setindex!` on a new `VarName` | ${\color{green} \checkmark}$ | -| `delete!` | ${\color{red} \times}$ | -| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | -| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | - -## Other methods - -```@docs -DynamicPPL.replace_raw_storage(::DynamicPPL.VarNamedVector, vals::AbstractVector) -``` - -```@docs; canonical=false -DynamicPPL.values_as(::DynamicPPL.VarNamedVector) -``` +To understand how these are achieved, we refer the reader to the documentation on `VarNamedTuple`, which underpins `VarInfo`. diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md new file mode 100644 index 000000000..daa062d2d --- /dev/null +++ b/docs/src/internals/varnamedtuple.md @@ -0,0 +1,184 @@ +# `VarNamedTuple` + +In DynamicPPL there is often a need to store data keyed by `VarName`s. +This comes up when getting conditioned variable values from the user, when tracking values of random variables in the model outputs or inputs, etc. +Historically we've had several different approaches to this: Dictionaries, `NamedTuple`s, vectors with subranges corresponding to different `VarName`s, and various combinations thereof. 
+
+The purpose of `VarNamedTuple`, or VNT for short, is to unify the treatment of these use cases and to handle them all in a robust and performant way.
+It's a data structure that can store arbitrary data, indexed by (nearly) arbitrary `VarName`s, in a type-stable and performant manner.
+
+`VarNamedTuple` consists of nested `NamedTuple`s and `PartialArray`s.
+Let's first talk about the `NamedTuple` part.
+This is what is needed for handling `PropertyLens`es in `VarName`s, that is, `VarName`s consisting of nested symbols, like in `@varname(a.b.c)`.
+In a `VarNamedTuple` each level of such nesting of `PropertyLens`es corresponds to a level of nested `NamedTuple`s, with the `Symbol`s of the lenses as keys.
+For instance, the `VarNamedTuple` mapping `@varname(x) => 1, @varname(y.z) => 2` would be stored as
+
+```
+VarNamedTuple(; x=1, y=VarNamedTuple(; z=2))
+```
+
+where `VarNamedTuple(; x=a, y=b)` is just a thin wrapper around the `NamedTuple` `(; x=a, y=b)`.
+
+It's often handy to think of this as a tree, with each node being a `VarNamedTuple`, like so:
+
+```
+      VNT
+  x /     \ y
+   1      VNT
+             \ z
+              2
+```
+
+If all `VarName`s consisted of only `PropertyLens`es we would be done designing the data structure.
+However, recall that `VarName`s allow three different kinds of lenses: `PropertyLens`es, `IndexLens`es, and `identity` (the trivial lens).
+The `identity` lens presents no complications, and in fact in the above example there was an implicit identity lens in e.g. `@varname(x) => 1`.
+It is the `IndexLens`es that require more structure.
+
+An `IndexLens` is the square bracket indexing part in `VarName`s like `@varname(x[1])`, `@varname(x[1].a.b[2:3])` and `@varname(x[:].b[1,2,3].c[1:5,:])`.
+`VarNamedTuple` cannot deal with `IndexLens`es in their full generality, for reasons we'll discuss below.
+Instead we restrict ourselves to `IndexLens`es where the indices are integers, explicit ranges with end points, like `1:5`, or tuples thereof.
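+
+Before moving on to `IndexLens`es, here is a sketch of how the `PropertyLens` nesting described above is built up in practice (assuming the `setindex!!`-based API shown later on this page):
+
+```julia
+julia> vnt = VarNamedTuple();
+
+julia> vnt = setindex!!(vnt, 1, @varname(x));
+
+julia> vnt = setindex!!(vnt, 2, @varname(y.z));
+
+julia> print(vnt)
+VarNamedTuple(; x=1, y=VarNamedTuple(; z=2))
+```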
+
+When storing data in a `VarNamedTuple`, we recursively go through the nested lenses in the `VarName`, inserting a new `VarNamedTuple` for every `PropertyLens`.
+When we meet an `IndexLens`, we instead insert into the tree something called a `PartialArray`.
+
+A `PartialArray` is like a regular `Base.Array`, but with some elements possibly unset.
+It is a data structure we define ourselves for use within `VarNamedTuple`s.
+A `PartialArray` has an element type and a number of dimensions, both known at compile time, but it does not have a size, and thus is not an `AbstractArray`.
+This is because if we set the elements `x[1,2]` and `x[14,10]` in a `PartialArray` called `x`, this does not mean that 14 and 10 are the ends of their respective dimensions.
+The typical use of this structure in DynamicPPL is that the user may define values for elements in an array-like structure one by one, and we do not always know how large these arrays are.
+
+This is also the reason why `PartialArray`, and by extension `VarNamedTuple`, do not support indexing by `Colon()`, i.e. `:`, as in `x[:]`.
+A `Colon()` says that we should get or set all the values along that dimension, but a `PartialArray` does not know how many values there may be.
+If `x[1]` and `x[4]` have been set, asking for `x[:]` is not a well-posed question.
+Note, however, that concretising the `VarName` resolves this ambiguity, and makes the `VarName` fine as a key to a `VarNamedTuple`.
+
+Compared to the full indexing syntax of Julia, `PartialArray`s have other restrictions as well:
+They do not support linear indexing into multidimensional arrays (as in `rand(3,3)[8]`), nor indexing with arrays of indices (as in `rand(4)[[1,3]]`), nor indexing with boolean mask arrays (as in `rand(4)[[true, false, true, false]]`).
+This is mostly because we haven't seen a need to support them, and implementing them would complicate the codebase for little gain.
+We may add support for them later if needed.
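+
+The sparse behaviour described above can be sketched as follows (a hypothetical example, assuming `haskey` returns `false` for elements that have not been set):
+
+```julia
+julia> vnt = VarNamedTuple();
+
+julia> vnt = setindex!!(vnt, 1.0, @varname(x[1, 2]));
+
+julia> vnt = setindex!!(vnt, 2.0, @varname(x[14, 10]));
+
+julia> haskey(vnt, @varname(x[14, 10]))
+true
+
+julia> haskey(vnt, @varname(x[2, 2]))
+false
+```
+
+Note that no size is implied by these insertions: the `PartialArray` only records which elements have been set.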
+ +`PartialArray`s can hold any values, just like `Base.Array`s, and in particular they can hold `VarNamedTuple`s. +Thus we nest them with `VarNamedTuple`s to support storing `VarName`s with arbitrary combinations of `PropertyLens`es and `IndexLens`es. +A code example illustrates this the best: + +```julia +julia> vnt = VarNamedTuple(); + +julia> vnt = setindex!!(vnt, 1.0, @varname(a)); + +julia> vnt = setindex!!(vnt, [2.0, 3.0], @varname(b.c)); + +julia> vnt = setindex!!(vnt, [:hip, :hop], @varname(d.e[2].f[3:4])); + +julia> print(vnt) +VarNamedTuple(; a=1.0, b=VarNamedTuple(; c=[2.0, 3.0]), d=VarNamedTuple(; e=PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))))) +``` + +The output there may be a bit hard to parse, so to illustrate: + +```julia +julia> vnt[@varname(b)] +VarNamedTuple(; c=[2.0, 3.0]) + +julia> vnt[@varname(b.c[1])] +2.0 + +julia> vnt[@varname(d.e)] +PartialArray{VarNamedTuple{(:f,), Tuple{DynamicPPL.VarNamedTuples.PartialArray{Symbol, 1}}},1}((2,) => VarNamedTuple(; f=PartialArray{Symbol,1}((3,) => hip, (4,) => hop))) + +julia> vnt[@varname(d.e[2].f)] +PartialArray{Symbol,1}((3,) => hip, (4,) => hop) +``` + +Or as a tree drawing, where `PA` marks a `PartialArray`: + +``` + /----VNT------\ +a / | b \ d + 1 VNT VNT + | c | e + [2.0, 3.0] PA(2 => VNT) + | f + PA(3 => :hip, 4 => :hop) +``` + +The above code also highlights how setting indices in a `VarNamedTuple` is done using `BangBang.setindex!!`. +We do not define a method for `Base.setindex!` at all, `setindex!!` is the only way. +This is because `VarNamedTuple` mixes mutable and immutable data structures. +It is also for user convenience: +One does not ever have to think about whether the value that one is inserting into a `VarNamedTuple` is of the right type to fit in. 
+Rather, the containers will flex to fit it, keeping element types concrete when possible, but making them abstract if needed.
+`VarNamedTuple`, or more precisely `PartialArray`, even explicitly concretises element types whenever possible.
+For instance, one can make an abstractly typed `VarNamedTuple` like so:
+
+```julia
+julia> vnt = VarNamedTuple();
+
+julia> vnt = setindex!!(vnt, 1.0, @varname(a[1]));
+
+julia> vnt = setindex!!(vnt, "hello", @varname(a[2]));
+
+julia> print(vnt)
+VarNamedTuple(; a=PartialArray{Any,1}((1,) => 1.0, (2,) => hello))
+```
+
+Note the element type of `PartialArray{Any}`.
+But if one changes the values to make them homogeneous, the element type is automatically made concrete again:
+
+```julia
+julia> vnt = setindex!!(vnt, "me here", @varname(a[1]));
+
+julia> print(vnt)
+VarNamedTuple(; a=PartialArray{String,1}((1,) => me here, (2,) => hello))
+```
+
+This approach is at the core of why `VarNamedTuple` is performant:
+As long as one does not store inhomogeneous types within a single `PartialArray`, by assigning different types to `VarName`s like `@varname(a[1])` and `@varname(a[2])`, different variables in a `VarNamedTuple` can have different types, and all `getindex` and `setindex!!` operations remain type-stable.
+Note that assigning a value to `@varname(a[1].b)` but not to `@varname(a[2].b)` has the same effect as assigning values of different types to `@varname(a[1])` and `@varname(a[2])`, and also causes a loss of type stability for `getindex` and `setindex!!`.
+However, this only affects `getindex` and `setindex!!` on sub-`VarName`s of `@varname(a)`;
+You can still use the same `VarNamedTuple` to store information about an unrelated `@varname(c)` in a type-stable manner.
+
+Note that if you `setindex!!` a new value into a `VarNamedTuple` with an `IndexLens`, this causes a `PartialArray` to be created.
+However, if there already is a regular `Base.Array` stored in a `VarNamedTuple`, you can index into it with `IndexLens`es without involving `PartialArray`s.
+That is, if you do `vnt = setindex!!(vnt, [1.0, 2.0], @varname(a))`, you can then get the values with e.g. `vnt[@varname(a[1])]`, which returns `1.0`.
+You can also set the elements with `vnt = setindex!!(vnt, 3.0, @varname(a[1]))`, and this will modify the existing `Base.Array`.
+At this point you cannot set any new values in that array that would fall outside of its range, with something like `vnt = setindex!!(vnt, 5.0, @varname(a[5]))`.
+The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is.
+
+## Non-Array blocks with `IndexLens`es
+
+The above is all that is needed for setting regular scalar values.
+However, in DynamicPPL we also have a particular need for something slightly odd:
+We sometimes need to do calls like `setindex!!(vnt, val, @varname(a[1:5]))` on a `val` that is _not_ an `AbstractArray`, or even iterable at all.
+Normally this would error: a scalar value with size `()` is the wrong size to be set with `@varname(a[1:5])`, which clearly wants something with size `(5,)`.
+However, we want to allow this when `val` is some object for which `size` is well-defined and `size(val) == (5,)`, even if `val` is not iterable.
+In DynamicPPL this comes up when storing e.g. the priors of a model, where a random variable like `@varname(a[1:5])` may be associated with a prior that is a 5-dimensional distribution.
+
+Internally, a `PartialArray` is just a regular `Array` with a mask saying which elements have been set.
+Hence we can't store `val` directly in the same `PartialArray`:
+We need it to take up a sub-block of the array, in our example case a sub-block of length 5.
+To this end, internally, `PartialArray` uses a wrapper type called `ArrayLikeWrapper`, which stores `val` together with the indices that are being used to set it.
+All the corresponding elements of the `PartialArray`, in our example elements 1, 2, 3, 4, and 5, point to the same wrapper object.
+
+While such blocks can be stored using a wrapper like this, some care must be taken in indexing into these blocks.
+For instance, after setting a block with `setindex!!(vnt, val, @varname(a[1:5]))`, we can't `getindex(vnt, @varname(a[1]))`, since we can't return "the first element of five in `val`", because `val` may not be indexable in any way.
+Similarly, if we next call `setindex!!(vnt, some_other_value, @varname(a[1]))`, that should invalidate/delete the elements `@varname(a[2:5])`, since the block only makes sense as a whole.
+For these reasons, setting and getting blocks of well-defined size like this is allowed with `VarNamedTuple`s, but _only by always using the full range_.
+For instance, if `setindex!!(vnt, val, @varname(a[1:5]))` has been called, then the only valid `getindex` key to access `val` is `@varname(a[1:5])`;
+Not `@varname(a[1:10])`, nor `@varname(a[3])`, nor anything else that overlaps with `@varname(a[1:5])`.
+`haskey` likewise only returns `true` for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element.
+
+The size of a value, for the purposes of inserting it into a `PartialArray`, is determined by a call to `vnt_size`.
+`vnt_size` falls back to calling `Base.size`.
+The reason we define a distinct function is to be able to control its behaviour, if necessary, without type piracy.
+
+## Limitations
+
+This design has several benefits, for performance and generality, but it also has limitations:
+
+ 1. The lack of support for `Colon`s in `VarName`s.
+ 2. The lack of support for some other indexing syntaxes supported by Julia, such as linear indexing and boolean indexing.
+ 3.
`VarNamedTuple` cannot store indices with different numbers of dimensions in the same value, so for instance `@varname(a[1])` and `@varname(a[1,1])` cannot be stored in the same `VarNamedTuple`. + 4. There is an asymmetry between storing arrays with `setindex!!(vnt, array, @varname(a))` and elements of arrays with `setindex!!(vnt, element, @varname(a[i]))`. + The former stores the whole array, which can then be indexed with both `@varname(a)` and `@varname(a[i])`. + The latter stores only individual elements, and even if all elements have been set, one still can't get the value associated with `@varname(a)` as a regular `Base.Array`. diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 12b816c60..37c9444b3 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -16,6 +16,4 @@ ChainRulesCore.@non_differentiable BangBang.push!!( # No need + causes issues for some AD backends, e.g. Zygote. ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) -ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges) - end # module diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 222d6a3f6..cdeb6c8e6 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -12,7 +12,7 @@ using EnzymeCore ) = nothing # Enzyme errors on Gibbs sampling without this one. 
@inline EnzymeCore.EnzymeRules.inactive( - ::typeof(Base.haskey), ::DynamicPPL.NTVarInfo, ::DynamicPPL.VarName + ::typeof(Base.haskey), ::DynamicPPL.VarInfo, ::DynamicPPL.VarName ) = nothing end diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl deleted file mode 100644 index cb35c5ffb..000000000 --- a/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,56 +0,0 @@ -module DynamicPPLJETExt - -using DynamicPPL: DynamicPPL -using JET: JET - -function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_dppl::Bool=true -) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) - # If specified, we only check errors originating somewhere in the DynamicPPL.jl. - # This way we don't just fall back to untyped if the user's code is the issue. - result = if only_dppl - JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) - else - JET.report_call(f, argtypes) - end - return length(JET.get_reports(result)) == 0, result -end - -function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model; only_dppl::Bool=true -) - # Generate a typed varinfo to test model type stability with - varinfo = DynamicPPL.typed_varinfo(model) - - # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !eval_issuccess - @debug "Evaluation with typed varinfo failed with the following issues:" - @debug eval_result - end - - # Check type stability of initialisation (i.e. 
InitContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !init_issuccess - @debug "Initialisation with typed varinfo failed with the following issues:" - @debug init_result - end - - # If neither of them failed, we can return the typed varinfo as it's type stable. - return if (eval_issuccess && init_issuccess) - varinfo - else - # Warn the user that we can't use the type stable one. - @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model) - end -end - -end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 6c40b4c94..09dc2bad7 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,6 +1,7 @@ module DynamicPPLMCMCChainsExt using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random +using BangBang: setindex!! using MCMCChains: MCMCChains function getindex_varname( @@ -94,7 +95,7 @@ end """ AbstractMCMC.to_samples( ::Type{DynamicPPL.ParamsWithStats}, - chain::MCMCChains.Chains + chain::MCMCChains.Chains, ) Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`. 
@@ -107,11 +108,11 @@ function AbstractMCMC.to_samples( idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) # Get parameters params_matrix = map(idxs) do (sample_idx, chain_idx) - d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() + vnt = DynamicPPL.VarNamedTuple() for vn in get_varnames(chain) - d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) + vnt = setindex!!(vnt, getindex_varname(chain, sample_idx, vn, chain_idx), vn) end - d + vnt end # Statistics stats_matrix = if :internals in MCMCChains.sections(chain) @@ -176,8 +177,8 @@ end fallback=nothing, ) -Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`, -returning an matrix of `(retval, updated_at)` tuples. +Re-evaluate `model` for each sample in `chain` using the accumulators provided in `accs`, +returning a matrix of `(retval, updated_at)` tuples. This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the initialisation strategy when re-evaluating the model. For many usecases the fallback should diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index ffb5baf25..8e53d8709 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -105,11 +105,9 @@ function DynamicPPL.marginalize( ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. 
varindices = mapreduce(vcat, marginalized_varnames) do vn - if DynamicPPL.getoptic(vn) === identity - ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range - else - ldf._varname_ranges[vn].range - end + # The type assertion helps in cases where the model is type unstable and thus + # `varname_ranges` may have an abstract element type. + (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( LogDensityFunctionWrapper(ldf, varinfo), @@ -214,7 +212,7 @@ function DynamicPPL.VarInfo( if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end - return DynamicPPL.unflatten(original_vi, full_params) + return DynamicPPL.unflatten!!(original_vi, full_params) end end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 3b847ec15..3d6af8c3b 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -10,7 +10,7 @@ Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ typeof(DynamicPPL._get_range_and_linked),Vararg } Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ - typeof(Base.haskey),DynamicPPL.NTVarInfo,DynamicPPL.VarName + typeof(Base.haskey),DynamicPPL.VarInfo,DynamicPPL.VarName } end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fda428eaa..9961125e2 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,7 +46,10 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - SimpleVarInfo, + VarNamedTuple, + map_pairs!!, + map_values!!, + apply!!, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -178,29 +181,28 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). 
+See also: [`VarInfo`](@ref) """ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") +include("varnamedtuple.jl") +using .VarNamedTuples: VarNamedTuples, VarNamedTuple, map_pairs!!, map_values!!, apply!! include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") -include("contexts/transformation.jl") include("contexts/prefix.jl") include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") -include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") @@ -208,7 +210,6 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") -include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -220,27 +221,6 @@ include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) function __init__() - # Better error message if users forget to load JET.jl - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ - requires_jet = - exc.f === DynamicPPL.Experimental._determine_varinfo_jet && - length(argtypes) >= 2 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext - requires_jet |= - exc.f === DynamicPPL.Experimental.is_suitable_varinfo && - length(argtypes) >= 3 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext && - argtypes[3] <: AbstractVarInfo - if requires_jet - print( - io, - "\n$(exc.f) requires JET.jl to be loaded. 
Please run `using JET` before calling $(exc.f).", - ) - end - end - # Same for MarginalLogDensities.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_mld = diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 898b6caf9..51341e3d4 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -32,6 +32,9 @@ in the execution of a given `Model`. This is in constrast to `StaticTransformation` which transforms all variables _before_ the execution of a given `Model`. +Different VarInfo types should implement their own methods for `link!!` and `invlink!!` for +`DynamicTransformation`. + See also: [`StaticTransformation`](@ref). """ struct DynamicTransformation <: AbstractTransformation end @@ -53,23 +56,6 @@ struct StaticTransformation{F} <: AbstractTransformation bijector::F end -""" - merge_transformations(transformation_left, transformation_right) - -Merge two transformations. - -The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref). -""" -function merge_transformations(::NoTransformation, ::NoTransformation) - return NoTransformation() -end -function merge_transformations(::DynamicTransformation, ::DynamicTransformation) - return DynamicTransformation() -end -function merge_transformations(left::StaticTransformation, right::StaticTransformation) - return StaticTransformation(merge_bijectors(left.bijector, right.bijector)) -end - function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) end @@ -502,69 +488,19 @@ If no `Type` is provided, return values as stored in `varinfo`. 
# Examples -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`VarInfo` with `NamedTuple` of `Metadata`: - ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; +julia> vi = DynamicPPL.setindex!!(vi, 1.0, @varname(s)); -julia> # For the sake of brevity, let's just check the type. 
- md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} -true +julia> vi = DynamicPPL.setindex!!(vi, 2.0, @varname(m)); julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: +OrderedDict{Any, Any} with 2 entries: s => 1.0 m => 2.0 @@ -573,32 +509,6 @@ julia> values_as(vi, Vector) 1.0 2.0 ``` - -`VarInfo` with `Metadata`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa Union{DynamicPPL.Metadata, Vector} -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 - -julia> values_as(vi, Vector) -2-element Vector{Real}: - 1.0 - 2.0 -``` """ function values_as end @@ -625,13 +535,6 @@ function Base.eltype(vi::AbstractVarInfo) return eltype(T) end -""" - has_varnamedvector(varinfo::VarInfo) - -Returns `true` if `varinfo` uses `VarNamedVector` as metadata. -""" -has_varnamedvector(vi::AbstractVarInfo) = false - # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert # the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which # might result in a `Vector{Any}`. 
@@ -655,20 +558,20 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> varinfo = VarInfo(model); +julia> vi = VarInfo(model); -julia> keys(varinfo) +julia> keys(vi) 4-element Vector{VarName}: s m x[1] x[2] -julia> for (i, vn) in enumerate(keys(varinfo)) - varinfo[vn] = i +julia> for (i, vn) in enumerate(keys(vi)) + vi = DynamicPPL.setindex!!(vi, Float64(i), vn) end -julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 @@ -676,59 +579,59 @@ julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4.0 julia> # Extract one with only `m`. - varinfo_subset1 = subset(varinfo, [@varname(m),]); + vi_subset1 = subset(vi, [@varname(m),]); -julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, typeof(identity)}}: +julia> keys(vi_subset1) +1-element Vector{VarName}: m -julia> varinfo_subset1[@varname(m)] +julia> vi_subset1[@varname(m)] 2.0 julia> # Extract one with both `s` and `x[2]`. - varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + vi_subset2 = subset(vi, [@varname(s), @varname(x[2])]); -julia> keys(varinfo_subset2) +julia> keys(vi_subset2) 2-element Vector{VarName}: s x[2] -julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +julia> vi_subset2[[@varname(s), @varname(x[2])]] 2-element Vector{Float64}: 1.0 4.0 ``` -`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref) +`subset` is particularly useful when combined with [`merge(vi::AbstractVarInfo)`](@ref) ```jldoctest varinfo-subset julia> # Merge the two. 
- varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + vi_subset_merged = merge(vi_subset1, vi_subset2); -julia> keys(varinfo_subset_merged) +julia> keys(vi_subset_merged) 3-element Vector{VarName}: m s x[2] -julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +julia> vi_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] 3-element Vector{Float64}: 1.0 2.0 4.0 julia> # Merge the two with the original. - varinfo_merged = merge(varinfo, varinfo_subset_merged); + vi_merged = merge(vi, vi_subset_merged); -julia> keys(varinfo_merged) +julia> keys(vi_merged) 4-element Vector{VarName}: s m x[1] x[2] -julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 @@ -824,33 +727,9 @@ See also: [`default_transformation`](@ref), [`invlink!!`](@ref). function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link!!(vi::AbstractVarInfo, vns, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation - model = setleafcontext(model, DynamicTransformationContext{false}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, t) -end -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user - # interface, and it's also dangerous for any AbstractVarInfo types that may not respect - # a particular ordering, such as SimpleVarInfo{Dict}. 
- b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - # Set parameters and add the logjac term. - vi = unflatten(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return set_transformed!!(vi, t) -end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -867,7 +746,7 @@ See also: [`default_transformation`](@ref), [`invlink`](@ref). function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link(vi::AbstractVarInfo, vns, model::Model) return link(default_transformation(model, vi), vi, vns, model) end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) @@ -890,32 +769,9 @@ See also: [`default_transformation`](@ref), [`link!!`](@ref). function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink!!(vi::AbstractVarInfo, vns, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation - model = setleafcontext(model, DynamicTransformationContext{true}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, NoTransformation()) -end -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = t.bijector - y = vi[:] - x, inv_logjac = with_logabsdet_jacobian(b, y) - - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. 
- vi = unflatten(vi, x) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, inv_logjac) - end - return set_transformed!!(vi, NoTransformation()) -end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -933,7 +789,7 @@ See also: [`default_transformation`](@ref), [`link`](@ref). function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink(vi::AbstractVarInfo, vns, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) @@ -980,12 +836,12 @@ julia> # Change the `default_transformation` for our model to be a julia> model = demo(); -julia> vi = SimpleVarInfo(x=1.0) -SimpleVarInfo((x = 1.0,), 0.0) +julia> vi = setindex!!(VarInfo(), 1.0, @varname(x)); + +julia> vi[@varname(x)] +1.0 -julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` - vi_linked = link!!(vi, model) -Transformed SimpleVarInfo((x = 1.0,), 0.0) +julia> vi_linked = link!!(vi, model); julia> # Now performs a single `invlink!!` before model evaluation. logjoint(model, vi_linked) @@ -1013,11 +869,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo, x::AbstractVector) + unflatten!!(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. """ -function unflatten end +function unflatten!! end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/accumulators.jl b/src/accumulators.jl index ed5f28ec2..b39ab120c 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -118,8 +118,7 @@ See also: [`split`](@ref) """ function combine end -# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in -# src/varinfo.jl. 
+# TODO(mhauru) The existence of this function makes me sad. See comment in src/model.jl.

"""
    convert_eltype(::Type{T}, acc::AbstractAccumulator)

diff --git a/src/bijector.jl b/src/bijector.jl
index 31fe7cd88..576205641 100644
--- a/src/bijector.jl
+++ b/src/bijector.jl
@@ -1,60 +1,65 @@
+struct BijectorAccumulator <: AbstractAccumulator
+    bijectors::Vector{Any}
+    sizes::Vector{Int}
+end
-"""
-    bijector(model::Model[, sym2ranges = Val(false)])
+BijectorAccumulator() = BijectorAccumulator(Bijectors.Bijector[], Int[])
+
+function Base.:(==)(acc1::BijectorAccumulator, acc2::BijectorAccumulator)
+    return (acc1.bijectors == acc2.bijectors && acc1.sizes == acc2.sizes)
+end
+
+function Base.copy(acc::BijectorAccumulator)
+    return BijectorAccumulator(copy(acc.bijectors), copy(acc.sizes))
+end
+
+accumulator_name(::Type{<:BijectorAccumulator}) = :Bijector
+
+function _zero(acc::BijectorAccumulator)
+    return BijectorAccumulator(empty(acc.bijectors), empty(acc.sizes))
+end
+reset(acc::BijectorAccumulator) = _zero(acc)
+split(acc::BijectorAccumulator) = _zero(acc)
+function combine(acc1::BijectorAccumulator, acc2::BijectorAccumulator)
+    return BijectorAccumulator(
+        vcat(acc1.bijectors, acc2.bijectors), vcat(acc1.sizes, acc2.sizes)
+    )
+end
+
+function accumulate_assume!!(acc::BijectorAccumulator, val, logjac, vn, right)
+    bijector = _compose_no_identity(
+        to_linked_vec_transform(right), from_vec_transform(right)
+    )
+    push!(acc.bijectors, bijector)
+    push!(acc.sizes, prod(output_size(to_vec_transform(right), right); init=1))
+    return acc
+end
+
+accumulate_observe!!(acc::BijectorAccumulator, right, left, vn) = acc

-Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
-denoting the dimensionality of the latent variables.
""" -function Bijectors.bijector( - model::DynamicPPL.Model, - (::Val{sym2ranges})=Val(false); - varinfo=DynamicPPL.VarInfo(model), -) where {sym2ranges} - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) - - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] - end + bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - bs = map(tuple(dists...)) do d - b = Bijectors.bijector(d) - if d isa Distributions.UnivariateDistribution - b - else - # Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)` - # and produces a vector of length `prod(Bijectors.output(f, in_size))`. - in_size = size(d) - vec_in_length = prod(in_size) - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) - out_size = Bijectors.output_size(b, in_size) - vec_out_length = prod(out_size) - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) - reshape_outer ∘ b ∘ reshape_inner - end - end +Returns a `Stacked <: Bijector` which maps from constrained to unconstrained space. + +The input to the bijector is a vector of values for the whole model, like the input to +`unflatten!!`. These are in constrained space, i.e., respecting variable constraints. +The output is a vector of unconstrained values. 
- if sym2ranges - return ( - Bijectors.Stacked(bs, ranges), - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), - ) - else - return Bijectors.Stacked(bs, ranges) +`init_strategy` is passed to `DynamicPPL.init!!` to determine what values the model is +evaluated with. This may affect the results if the prior distributions or constraints of +variables are dependent on other variables. +""" +function Bijectors.bijector( + model::DynamicPPL.Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + vi = OnlyAccsVarInfo((BijectorAccumulator(),)) + vi = last(DynamicPPL.init!!(model, vi, init_strategy)) + acc = getacc(vi, Val(:Bijector)) + ranges = foldl(acc.sizes; init=UnitRange{Int}[]) do cumulant, sz + last_index = length(cumulant) > 0 ? last(cumulant).stop : 0 + push!(cumulant, (last_index + 1):(last_index + sz)) + return cumulant end + return Bijectors.Stacked(acc.bijectors, ranges) end diff --git a/src/chains.jl b/src/chains.jl index 8ce4979c6..cfd27d87a 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -5,7 +5,7 @@ A struct which contains parameter values extracted from a `VarInfo`, along with statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are optional. """ -struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple} +struct ParamsWithStats{P<:VarNamedTuple,S<:NamedTuple} params::P stats::S end @@ -38,7 +38,6 @@ function ParamsWithStats( include_colon_eq::Bool=true, include_log_probs::Bool=true, ) - varinfo = maybe_to_typed_varinfo(varinfo) accs = if include_log_probs ( DynamicPPL.LogPriorAccumulator(), @@ -64,13 +63,6 @@ function ParamsWithStats( return ParamsWithStats(params, stats) end -# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to -# convert it to a typed varinfo first, hence this method. 
-# https://github.com/TuringLang/Turing.jl/issues/2604 -maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) -maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) -maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi - """ ParamsWithStats( varinfo::AbstractVarInfo, @@ -121,7 +113,7 @@ Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided `param_vector`. This method is intended to replace the old method of obtaining parameters and statistics -via `unflatten` plus re-evaluation. It is faster for two reasons: +via `unflatten!!` plus re-evaluation. It is faster for two reasons: 1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent @@ -136,10 +128,7 @@ function ParamsWithStats( include_log_probs::Bool=true, ) where {Tlink} strategy = InitFromParams( - VectorWithRanges{Tlink}( - ldf._iden_varname_ranges, ldf._varname_ranges, param_vector - ), - nothing, + VectorWithRanges{Tlink}(ldf._varname_ranges, param_vector), nothing ) accs = if include_log_probs ( diff --git a/src/compiler.jl b/src/compiler.jl index 84a9a4857..daf4f59ad 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,22 +1,32 @@ const INTERNALNAMES = (:__model__, :__varinfo__) +drop_escape(x) = x +function drop_escape(expr::Expr) + Meta.isexpr(expr, :escape) && return drop_escape(expr.args[1]) + return Expr(expr.head, map(x -> drop_escape(x), expr.args)...) +end + +get_top_level_symbol(expr::Symbol) = expr +function get_top_level_symbol(expr::Expr) + # TODO(penelopeysm): what about Meta.isexpr(expr, :$)? + if Meta.isexpr(expr, :ref) + return get_top_level_symbol(expr.args[1]) + elseif Meta.isexpr(expr, :.) + return get_top_level_symbol(expr.args[1]) + else + error("unreachable") + end +end + """ need_concretize(expr) -Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or -requires a dynamic optic. 
+Determine whether `expr` defines a VarName that needs to be concretised.

-# Examples
+Note that, although we parse VarNames using our own lenses, Accessors.need_dynamic_optic is
+actually still 'good enough' to determine whether we need to concretise or not.

-```jldoctest; setup=:(using Accessors)
-julia> DynamicPPL.need_concretize(:(x[1, :]))
-true
-
-julia> DynamicPPL.need_concretize(:(x[1, end]))
-true
-
-julia> DynamicPPL.need_concretize(:(x[1, 1]))
-false
+Eventually, we hope to never need to concretise anything.
"""
function need_concretize(expr)
    return Accessors.need_dynamic_optic(expr) || begin
@@ -32,13 +42,17 @@ end
"""
    make_varname_expression(expr)

-Return a `VarName` based on `expr`, concretizing it if necessary.
+Return a `VarName` based on `expr`.
"""
function make_varname_expression(expr)
-    # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
-    # that in DynamicPPL we the entire function body. Instead we should be
-    # more selective with our escape. Until that's the case, we remove them all.
-    return AbstractPPL.drop_escape(varname(expr, need_concretize(expr)))
+    # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact that in
+    # DynamicPPL we escape the entire function body. Instead we should be more selective
+    # with our escape. Until that's the case, we remove them all.
+    # TODO(penelopeysm): We still concretise things, because PartialArray does not
+    # understand dynamic indices. This is not necessarily a bad thing for performance, but
+    # it would be nice to not NEED to have to do it. That would require shadow arrays. See
+    # #1194.
+    return drop_escape(AbstractPPL.varname(expr, need_concretize(expr)))
end

"""
@@ -55,10 +69,9 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
-If `vn` is specified, it will be assumed to refer to a expression which
-evaluates to a `VarName`, and this will be used in the subsequent checks.
-If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
-used in its place.
+If `vn` is specified, it will be assumed to refer to an expression which evaluates to a
+`VarName`, and this will be used in the subsequent checks. If `vn` is not specified,
+`(@varname \$expr)` will be used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
    return quote
@@ -221,9 +234,6 @@ variables.

# Example

```jldoctest; setup=:(using Distributions, LinearAlgebra)
-julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
-x[:, 2]
-
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
x[1, 2]

@@ -241,31 +251,20 @@ end
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)
    return unwrap_right_left_vns(right.dist, left, right.name)
end
-function unwrap_right_left_vns(
-    right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
-)
-    # This an expression such as `x .~ MvNormal()` which we interpret as
-    #     x[:, i] ~ MvNormal()
-    # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
-    # and we therefore add the `Colon()` below.
- vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) ∘ vn, left) - end - return unwrap_right_left_vns(right, left, vns) -end function unwrap_right_left_vns( right::Union{Distribution,AbstractArray{<:Distribution}}, left::AbstractArray, vn::VarName, ) vns = map(CartesianIndices(left)) do i - return Accessors.IndexLens(Tuple(i)) ∘ vn + sym, optic = getsym(vn), getoptic(vn) + return VarName{sym}(AbstractPPL.Index(Tuple(i), (;), AbstractPPL.Iden()) ∘ optic) end return unwrap_right_left_vns(right, left, vns) end resolve_varnames(vn::VarName, _) = vn -resolve_varnames(vn::VarName, dist::NamedDist) = dist.name +resolve_varnames(::VarName, dist::NamedDist) = dist.name ################# # Main Compiler # @@ -434,14 +433,14 @@ end function generate_assign(left, right) # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for - # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. + # ValuesAsInModel then in addition we push!! the pair of `x` and `y` to the accumulator. @gensym acc right_val vn return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) __varinfo__ = $(map_accumulator!!)( - $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + $acc -> push!!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) end $left = $right_val @@ -463,9 +462,18 @@ function generate_tilde_literal(left, right) end end -assign_or_set!!(lhs::Symbol, rhs) = AbstractPPL.drop_escape(:($lhs = $rhs)) -function assign_or_set!!(lhs::Expr, rhs) - return AbstractPPL.drop_escape(:($BangBang.@set!! 
$lhs = $rhs)) +assign_or_set!!(lhs::Symbol, rhs, vn) = drop_escape(:($lhs = $rhs)) +function assign_or_set!!(lhs::Expr, rhs, vn) + left_top_sym = get_top_level_symbol(lhs) + return drop_escape( + :( + $left_top_sym = $(Accessors.set)( + $left_top_sym, + $(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)), + $rhs, + ) + ), + ) end """ @@ -487,12 +495,13 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) # $left may not be a simple varname, it might be x.a or x[1], in which case we - # need to use BangBang.@set!! to safely set it. + # need to use Accessors.set to safely set it. $(assign_or_set!!( left, :($(DynamicPPL.getfixed_nested)( __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )), + vn, )) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -520,7 +529,7 @@ function generate_tilde(left, right) $vn, __varinfo__, ) - $(assign_or_set!!(left, value)) + $(assign_or_set!!(left, value, vn)) $value end end @@ -531,11 +540,17 @@ function generate_tilde_assume(left, right, vn) # with multiple arguments on the LHS, we need to capture the return-values # and then update the LHS variables one by one. @gensym value - expr = :($left = $value) - if left isa Expr - expr = AbstractPPL.drop_escape( - Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true) + expr = if left isa Expr # as opposed to Symbol + left_top_sym = get_top_level_symbol(left) + :( + $left_top_sym = $(Accessors.set)( + $left_top_sym, + $(AbstractPPL.with_mutation)($(AbstractPPL.getoptic)($vn)), + $value, + ) ) + else + :($left = $value) end return quote diff --git a/src/contexts.jl b/src/contexts.jl index 46c5b8855..0eccf7b53 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -25,18 +25,18 @@ Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. 
# Examples -```jldoctest -julia> using DynamicPPL: DynamicTransformationContext, ConditionContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: InitContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior())); julia> DynamicPPL.childcontext(ctx_prior) -DynamicTransformationContext{true}() +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) ``` """ setchildcontext @@ -60,8 +60,8 @@ in which case effectively append `right` to `left`, dropping the original leaf context of `left`. # Examples -```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, InitContext julia> struct ParentContext{C} <: AbstractParentContext context::C @@ -77,8 +77,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) -DynamicTransformationContext{true}() + leafcontext(setleafcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior()))) +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 80a494c23..9899fa47a 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -56,12 +56,13 @@ used to determine whether the float type needs to be modified). 
In case that wasn't enough: in fact, even the above is not always true. Firstly, the accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments -in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable -of automatically promoting the types on its own. Secondly, the promotion only matters if you -are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar -tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to -tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which -also does the promotion for you. For the gory details, see the following issues: +in `DynamicPPL.unflatten!!` for more details. For non-threadsafe evaluation, Julia is +capable of automatically promoting the types on its own. Secondly, the promotion only +matters if you are trying to directly assign into a `Vector{Float64}` with a +`ForwardDiff.Dual` or similar tracer type, for example using `xs[i] = MyDual`. This doesn't +actually apply to tilde-statements like `xs[i] ~ ...` because those use `Accessors.set` +under the hood, which also does the promotion for you. For the gory details, see the +following issues: - https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types - https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion @@ -169,7 +170,7 @@ InitFromParams(params) = InitFromParams(params, InitFromPrior()) function init( rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} -) where {P<:Union{AbstractDict{<:VarName},NamedTuple}} +) where {P<:Union{AbstractDict{<:VarName},NamedTuple,VarNamedTuple}} # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -206,17 +207,20 @@ an unlinked value. 
$(TYPEDFIELDS) """ -struct RangeAndLinked +struct RangeAndLinked{T<:Tuple} # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} # whether it's linked is_linked::Bool + # original size of the variable before vectorisation + original_size::T end +VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size + """ VectorWithRanges{Tlink}( - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, vect::AbstractVector{<:Real}, ) @@ -228,26 +232,14 @@ this `VectorWithRanges` are linked/not linked, or `nothing` if either the linkin not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not affect functionality or correctness, but causes more work to be done at runtime, with possible impacts on type stability and performance. - -In the simplest case, this could be accomplished only with a single dictionary mapping -VarNames to ranges and link status. However, for performance reasons, we separate out -VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All -non-identity-optic VarNames are stored in the `varname_ranges` Dict. - -It would be nice to improve the NamedTuple and Dict approach. See, e.g. -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. 
""" -struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} - # This NamedTuple stores the ranges for identity VarNames - iden_varname_ranges::N - # This Dict stores the ranges for all other VarNames - varname_ranges::Dict{VarName,RangeAndLinked} +struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} + # Ranges for all VarNames + varname_ranges::VNT # The full parameter vector which we index into to get variable values vect::T - function VectorWithRanges{Tlink}( - iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T - ) where {Tlink,N,T} + function VectorWithRanges{Tlink}(varname_ranges::VNT, vect::T) where {Tlink,VNT,T} if !(Tlink isa Union{Bool,Nothing}) throw( ArgumentError( @@ -255,17 +247,17 @@ struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} ), ) end - return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + return new{Tlink,VNT,T}(varname_ranges, vect) end end -function _get_range_and_linked( - vr::VectorWithRanges, ::VarName{sym,typeof(identity)} -) where {sym} - return vr.iden_varname_ranges[sym] -end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) - return vr.varname_ranges[vn] + # The type assertion does nothing if VectorWithRanges has concrete element types, as is + # the case for all type stable models. However, if the model is not type stable, + # vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to + # assert that it is a RangeAndLinked, because even though it remains non-concrete, + # it'll allow the compiler to infer the types of `range` and `is_linked`. 
+ return vr.varname_ranges[vn]::RangeAndLinked end function init( ::Random.AbstractRNG, @@ -317,67 +309,11 @@ end function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) - in_varinfo = haskey(vi, vn) val, transform = init(ctx.rng, vn, dist, ctx.strategy) - x, inv_logjac = with_logabsdet_jacobian(transform, val) - # Determine whether to insert a transformed value into the VarInfo. - # If the VarInfo alrady had a value for this variable, we will - # keep the same linked status as in the original VarInfo. If not, we - # check the rest of the VarInfo to see if other variables are linked. - # is_transformed(vi) returns true if vi is nonempty and all variables in vi - # are linked. - insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) - val_to_insert, logjac = if insert_transformed_value - # Calculate the forward logjac and sum them up. - y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) - # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian - # calculation wastes a lot of time going from linked vectorised -> unlinked -> - # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. - # - # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which - # case this branch is never hit (since `in_varinfo` will always be false). It does - # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, - # linked, VarInfo will be very slow. That should never really be used, though. So - # (at least for now) we can leave this branch in for full generality with other - # combinations of init strategies / VarInfo. - # - # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue - # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, - # which is NOT the same as `inverse(link_transform)` (because there is an additional - # vectorisation step). 
We need `init` and `tilde_assume!!` to share this information - # but it's not clear right now how to do this. In my opinion, there are a couple of - # potential ways forward: - # - # 1. Just remove metadata entirely so that there is never any need to construct - # a linked vectorised value again. This would require us to use VAIMAcc as the only - # way of getting values. I consider this the best option, but it might take a long - # time. - # - # 2. Clean up the behaviour of bijectors so that we can have a complete separation - # between the linking and vectorisation parts of it. That way, `x` can either be - # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of - # which it is, we should only need to apply at most one linking and one - # vectorisation transform. Doing so would allow us to remove the first call to - # `with_logabsdet_jacobian`, and instead compose and/or uncompose the - # transformations before calling `with_logabsdet_jacobian` once. - y, -inv_logjac + fwd_logjac - else - x, -inv_logjac - end - # Add the new value to the VarInfo. `push!!` errors if the value already - # exists, hence the need for setindex!!. - if in_varinfo - vi = setindex!!(vi, val_to_insert, vn) - else - vi = push!!(vi, vn, val_to_insert, dist) - end - # Neither of these set the `trans` flag so we have to do it manually if - # necessary. - if insert_transformed_value - vi = set_transformed!!(vi, true, vn) - end + x, init_logjac = with_logabsdet_jacobian(transform, val) + vi, logjac = setindex_with_dist!!(vi, x, dist, vn) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + vi = accumulate_assume!!(vi, x, init_logjac + logjac, vn, dist) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. 
return x, vi diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl deleted file mode 100644 index c2eee2863..000000000 --- a/src/contexts/transformation.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - struct DynamicTransformationContext{isinverse} <: AbstractContext - -When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to -constrained space if `isinverse` or unconstrained if `!isinverse`. - -Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the -`DynamicTransformationContext` methods with more efficient implementations. -`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know -how to do the transformation, used by e.g. `SimpleVarInfo`. -""" -struct DynamicTransformationContext{isinverse} <: AbstractContext end - -function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, - right::Distribution, - vn::VarName, - vi::AbstractVarInfo, -) where {isinverse} - # vi[vn, right] always provides the value in unlinked space. - x = vi[vn, right] - - if is_transformed(vi, vn) - isinverse || @warn "Trying to link an already transformed variable ($vn)" - else - isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" - end - - transform = isinverse ? 
identity : link_transform(right) - y, logjac = with_logabsdet_jacobian(transform, x) - vi = accumulate_assume!!(vi, x, logjac, vn, right) - vi = setindex!!(vi, y, vn) - return x, vi -end - -function tilde_observe!!( - ::DynamicTransformationContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - return tilde_observe!!(DefaultContext(), right, left, vn, vi) -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 8810b9819..79e625e36 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -3,7 +3,6 @@ module DebugUtils using ..DynamicPPL using Random: Random -using Accessors: Accessors using InteractiveUtils: InteractiveUtils using DocStringExtensions diff --git a/src/experimental.jl b/src/experimental.jl deleted file mode 100644 index 8c82dca68..000000000 --- a/src/experimental.jl +++ /dev/null @@ -1,98 +0,0 @@ -module Experimental - -using DynamicPPL: DynamicPPL - -# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. -""" - is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) - -Check if the `model` supports evaluation using the provided `varinfo`. - -!!! warning - Loading JET.jl is required before calling this function. - -# Arguments -- `model`: The model to verify the support for. -- `varinfo`: The varinfo to verify the support for. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. - -# Returns -- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. -- `report`: The result of `report_call` from JET.jl. -""" -function is_suitable_varinfo end - -# Internal hook for JET.jl to overload. -function _determine_varinfo_jet end - -""" - determine_suitable_varinfo(model; only_dppl::Bool=true) - -Return a suitable varinfo for the given `model`. 
- -See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). - -!!! warning - For full functionality, this requires JET.jl to be loaded. - If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. - -# Arguments -- `model`: The model for which to determine the varinfo. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports within DynamicPPL.jl. - -# Examples - -```jldoctest -julia> using DynamicPPL.Experimental: determine_suitable_varinfo - -julia> using JET: JET # needs to be loaded for full functionality - -julia> @model function model_with_random_support() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end -model_with_random_support (generic function with 2 methods) - -julia> model = model_with_random_support(); - -julia> # Typed varinfo cannot handle this random support model properly - # as using a single execution of the model will not see all random variables. - # Hence, this this model requires untyped varinfo. - vi = determine_suitable_varinfo(model); -┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. -└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 - -julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) -true - -julia> # In contrast, a simple model with no random support can be handled by typed varinfo. - @model model_with_static_support() = x ~ Normal() -model_with_static_support (generic function with 2 methods) - -julia> vi = determine_suitable_varinfo(model_with_static_support()); - -julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) -true -``` -""" -function determine_suitable_varinfo(model::DynamicPPL.Model; only_dppl::Bool=true) - # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. - return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model; only_dppl) - else - # Warn the user. 
- @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." - # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). - DynamicPPL.typed_varinfo(model, context) - end -end - -end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 8c7b5f7db..def2b7756 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -121,8 +121,7 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - varinfo = VarInfo() - varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) + varinfo = OnlyAccsVarInfo((PriorDistributionAccumulator(),)) varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 5a61eb531..930e89ba3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -13,8 +13,6 @@ using DynamicPPL: OnlyAccsVarInfo, RangeAndLinked, VectorWithRanges, - Metadata, - VarNamedVector, default_accumulators, float_type_with_fallback, getlogjoint, @@ -98,11 +96,12 @@ from: Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a given set of parameters: -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. +1. With `unflatten!!` + `evaluate!!` with `DefaultContext`: this stores a vector of + parameters inside a VarInfo's metadata, then reads parameter values from the VarInfo + during evaluation. -2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores - them inside a VarInfo's metadata. +2. 
With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and + stores them inside a VarInfo's metadata. In general, both of these approaches work fine, but the fact that they modify the VarInfo's metadata can often be quite wasteful. In particular, it is very common that the only outputs @@ -124,7 +123,7 @@ In particular, it is not clear: - which parts of the vector correspond to which random variables, and - whether the variables are linked or unlinked. -Traditionally, this problem has been solved by `unflatten`, because that function would +Traditionally, this problem has been solved by `unflatten!!`, because that function would place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. However, we want to avoid using a metadata. Thus, here, we _extract this information from @@ -132,21 +131,16 @@ the VarInfo_ a single time when constructing a `LogDensityFunction` object. Insi LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with link status. -For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all -other VarNames, this is stored in a Dict. The internal data structure used to represent this -could almost certainly be optimised further. See e.g. the discussion in -https://github.com/TuringLang/DynamicPPL.jl/issues/1116. - -When evaluating the model, this allows us to combine the parameter vector together with those -ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly read -parameter values from the vector. +When evaluating the model, this allows us to combine the parameter vector together with +those ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quickly +read parameter values from the vector. 
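[Reviewer note: the fast path described above can be sketched from user code roughly as follows. This is an illustrative sketch only; it assumes the public `LogDensityFunction(model, getlogdensity, varinfo)` constructor keeps this argument order, and the `gdemo` model is hypothetical.]

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model function gdemo()
    m ~ Normal()
    x ~ Normal(m, 1.0)
end

model = gdemo()
# Ranges and link status are extracted from this VarInfo once, at construction time.
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi)

# Each call wraps the raw vector in a VectorWithRanges via InitFromParams and
# evaluates with an OnlyAccsVarInfo, so no metadata is written during evaluation.
LogDensityProblems.logdensity(ldf, [0.1, 0.2])
```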
Note that this assumes that the ranges and link status are static throughout the lifetime of the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle models which have variable numbers of parameters, or models which may visit random variables in different orders depending on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a general limitation of vectorised parameters: the original -`unflatten` + `evaluate!!` approach also fails with such models. +`unflatten!!` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ # true if all variables are linked; false if all variables are unlinked; nothing if @@ -155,7 +149,7 @@ struct LogDensityFunction{ M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F, - N<:NamedTuple, + VNT<:VarNamedTuple, ADP<:Union{Nothing,DI.GradientPrep}, # type of the vector passed to logdensity functions X<:AbstractVector, @@ -164,8 +158,7 @@ struct LogDensityFunction{ model::M adtype::AD _getlogdensity::F - _iden_varname_ranges::N - _varname_ranges::Dict{VarName,RangeAndLinked} + _varname_ranges::VNT _adprep::ADP _dim::Int _accs::AC @@ -185,14 +178,11 @@ struct LogDensityFunction{ ) # Figure out which variable corresponds to which index, and # which variables are linked. 
- all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + all_ranges = get_ranges_and_linked(varinfo) # Figure out if all variables are linked, unlinked, or mixed link_statuses = Bool[] - for ral in all_iden_ranges - push!(link_statuses, ral.is_linked) - end - for (_, ral) in all_ranges - push!(link_statuses, ral.is_linked) + for vn in keys(all_ranges) + push!(link_statuses, all_ranges[vn].is_linked) end Tlink = if all(link_statuses) true @@ -211,7 +201,7 @@ struct LogDensityFunction{ else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - args = (model, getlogdensity, all_iden_ranges, all_ranges, accs) + args = (model, getlogdensity, all_ranges, accs) if _use_closure(adtype) DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x) else @@ -229,12 +219,12 @@ struct LogDensityFunction{ typeof(model), typeof(adtype), typeof(getlogdensity), - typeof(all_iden_ranges), + typeof(all_ranges), typeof(prep), typeof(x), typeof(accs), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs + model, adtype, getlogdensity, all_ranges, prep, dim, accs ) end end @@ -269,8 +259,7 @@ ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulato ::Val{Tlink}, model::Model, getlogdensity::Any, - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, ) where {Tlink} Calculate the log density at the given `params`, using the provided @@ -281,13 +270,10 @@ function logdensity_at( ::Val{Tlink}, model::Model, getlogdensity::Any, - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, accs::AccumulatorTuple, ) where {Tlink} - strategy = InitFromParams( - VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing - ) + strategy = InitFromParams(VectorWithRanges{Tlink}(varname_ranges, params), nothing) _, vi = DynamicPPL.init!!(model, 
OnlyAccsVarInfo(accs), strategy) return getlogdensity(vi) end @@ -296,8 +282,7 @@ end LogDensityAt{Tlink}( model::Model, getlogdensity::Any, - iden_varname_ranges::NamedTuple, - varname_ranges::Dict{VarName,RangeAndLinked}, + varname_ranges::VarNamedTuple, accs::AccumulatorTuple, ) where {Tlink} @@ -305,34 +290,21 @@ A callable struct that behaves in the same way as `logdensity_at`, but stores th other information internally. Having two separate functions/structs allows for better performance with AD backends. """ -struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple,A<:AccumulatorTuple} +struct LogDensityAt{Tlink,M<:Model,F,V<:VarNamedTuple,A<:AccumulatorTuple} model::M getlogdensity::F - iden_varname_ranges::N - varname_ranges::Dict{VarName,RangeAndLinked} + varname_ranges::V accs::A function LogDensityAt{Tlink}( - model::M, - getlogdensity::F, - iden_varname_ranges::N, - varname_ranges::Dict{VarName,RangeAndLinked}, - accs::A, - ) where {Tlink,M,F,N,A} - return new{Tlink,M,F,N,A}( - model, getlogdensity, iden_varname_ranges, varname_ranges, accs - ) + model::M, getlogdensity::F, varname_ranges::V, accs::A + ) where {Tlink,M,F,V,A} + return new{Tlink,M,F,V,A}(model, getlogdensity, varname_ranges, accs) end end function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} return logdensity_at( - params, - Val{Tlink}(), - f.model, - f.getlogdensity, - f.iden_varname_ranges, - f.varname_ranges, - f.accs, + params, Val{Tlink}(), f.model, f.getlogdensity, f.varname_ranges, f.accs ) end @@ -340,13 +312,7 @@ function LogDensityProblems.logdensity( ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} ) where {Tlink} return logdensity_at( - params, - Val{Tlink}(), - ldf.model, - ldf._getlogdensity, - ldf._iden_varname_ranges, - ldf._varname_ranges, - ldf._accs, + params, Val{Tlink}(), ldf.model, ldf._getlogdensity, ldf._varname_ranges, ldf._accs ) end @@ -359,11 +325,7 @@ function LogDensityProblems.logdensity_and_gradient( return if 
_use_closure(ldf.adtype) DI.value_and_gradient( LogDensityAt{Tlink}( - ldf.model, - ldf._getlogdensity, - ldf._iden_varname_ranges, - ldf._varname_ranges, - ldf._accs, + ldf.model, ldf._getlogdensity, ldf._varname_ranges, ldf._accs ), ldf._adprep, ldf.adtype, @@ -378,7 +340,6 @@ function LogDensityProblems.logdensity_and_gradient( DI.Constant(Val{Tlink}()), DI.Constant(ldf.model), DI.Constant(ldf._getlogdensity), - DI.Constant(ldf._iden_varname_ranges), DI.Constant(ldf._varname_ranges), DI.Constant(ldf._accs), ) @@ -457,73 +418,29 @@ _use_closure(::ADTypes.AbstractADType) = false # Helper functions to extract ranges and link status # ###################################################### -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. """ get_ranges_and_linked(varinfo::VarInfo) Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter representation, along with whether each variable is linked or unlinked. -This function should return a tuple containing: - -- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` -- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +This function returns a VarNamedTuple mapping all VarNames to their corresponding +`RangeAndLinked`. 
""" -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_iden_ranges = merge(all_iden_ranges, this_md_iden) - all_ranges = merge(all_ranges, this_md_others) - end - return all_iden_ranges, all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_iden, all_others -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += length(range) - end - return all_iden_ranges, all_ranges, offset +function get_ranges_and_linked(vi::VarInfo) + vnt, _ = mapreduce( + identity, + function ((vnt, offset), 
pair) + vn, tv = pair + val = tv.val + range = offset:(offset + length(val) - 1) + offset += length(val) + ral = RangeAndLinked(range, is_transformed(tv), tv.size) + vnt = setindex!!(vnt, ral, vn) + return vnt, offset + end, + vi.values; + init=(VarNamedTuple(), 1), + ) + return vnt end diff --git a/src/model.jl b/src/model.jl index b7f797944..f144880d3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -501,19 +501,19 @@ true julia> # Since we conditioned on `a.m`, it is not treated as a random variable. # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> conditioned(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: +Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry: a.m => 1.0 julia> # Now `a.x` will be sampled. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x ``` """ @@ -833,25 +833,25 @@ julia> # Returns all the variables we have fixed on + their values. (x = 100.0, m = 1.0) julia> # The rest of this is the same as the `condition` example above. 
- cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); + fm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); -julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +julia> Set(keys(fixed(fm))) == Set([@varname(a.m), @varname(x)]) true -julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +julia> keys(VarInfo(fm)) +1-element Vector{VarName}: a.x -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # We can also fix `a.m` _outside_ of the PrefixContext: + fm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> fixed(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: +julia> fixed(fm) +Dict{VarName{:a, AbstractPPL.Property{:m, AbstractPPL.Iden}}, Float64} with 1 entry: a.m => 1.0 julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + keys(VarInfo(fm)) +1-element Vector{VarName}: a.x ``` """ @@ -1087,7 +1087,9 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) + vi = VarInfo() + vi = setaccs!!(vi, DynamicPPL.AccumulatorTuple()) + x = last(init!!(rng, model, vi)) return values_as(x, T) end @@ -1240,6 +1242,117 @@ function Distributions.loglikelihood(model::Model, params) return getloglikelihood(last(init!!(model, vi, ctx))) end +""" + logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log joint probability of variables `values` for the probabilistic `model`. + +See [`logprior`](@ref) and [`loglikelihood`](@ref). 
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logjoint(demo([1.0]), (m = 100.0, ))
+-9902.33787706641
+
+julia> # Using an `OrderedDict`.
+       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-9902.33787706641
+
+julia> # Truth.
+       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
+-9902.33787706641
+```
+"""
+function logjoint(model::Model, values::Union{NamedTuple,AbstractDict})
+    accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
+    vi = OnlyAccsVarInfo(accs)
+    _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing))
+    return getlogjoint(vi)
+end
+
+"""
+    logprior(model::Model, values::Union{NamedTuple,AbstractDict})
+
+Return the log prior probability of variables `values` for the probabilistic `model`.
+
+See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logprior(demo([1.0]), (m = 100.0, ))
+-5000.918938533205
+
+julia> # Using an `OrderedDict`.
+       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-5000.918938533205
+
+julia> # Truth.
+       logpdf(Normal(), 100.0)
+-5000.918938533205
+```
+"""
+function logprior(model::Model, values::Union{NamedTuple,AbstractDict})
+    accs = AccumulatorTuple((LogPriorAccumulator(),))
+    vi = OnlyAccsVarInfo(accs)
+    _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing))
+    return getlogprior(vi)
+end
+
+"""
+    loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict})
+
+Return the log likelihood of variables `values` for the probabilistic `model`.
+
+See also [`logjoint`](@ref) and [`logprior`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       loglikelihood(demo([1.0]), (m = 100.0, ))
+-4901.418938533205
+
+julia> # Using an `OrderedDict`.
+       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-4901.418938533205
+
+julia> # Truth.
+       logpdf(Normal(100.0, 1.0), 1.0)
+-4901.418938533205
+```
+"""
+function Distributions.loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict})
+    accs = AccumulatorTuple((LogLikelihoodAccumulator(),))
+    vi = OnlyAccsVarInfo(accs)
+    _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing))
+    return getloglikelihood(vi)
+end
+
 # Implemented & documented in DynamicPPLMCMCChainsExt
 function predict end
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
deleted file mode 100644
index ec02f4c94..000000000
--- a/src/simple_varinfo.jl
+++ /dev/null
@@ -1,549 +0,0 @@
-"""
-    $(TYPEDEF)
-
-A simple wrapper of the parameters with a `logp` field for
-accumulation of the logdensity.
-
-Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`.
-
-# Fields
-$(FIELDS)
-
-# Notes
-The major differences between this and `NTVarInfo` are:
-1. `SimpleVarInfo` does not require linearization.
-2. `SimpleVarInfo` can use more efficient bijectors.
-3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
-    a) no indexing is used in tilde-statements, or
-    b) the values have been specified with the correct shapes.
- -# Examples -## General usage -```jldoctest simplevarinfo-general; setup=:(using Distributions) -julia> using StableRNGs - -julia> @model function demo() - m ~ Normal() - x = Vector{Float64}(undef, 2) - for i in eachindex(x) - x[i] ~ Normal() - end - return x - end -demo (generic function with 2 methods) - -julia> m = demo(); - -julia> rng = StableRNG(42); - -julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variables which are using "containers", e.g. `Array`. - # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); - -julia> # (✓) Vroom, vroom! FAST!!! - vi[@varname(x[1])] -0.4471218424633827 - -julia> # We can also access arbitrary varnames pointing to `x`, e.g. - vi[@varname(x)] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> vi[@varname(x[1:2])] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); -ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` -[...] - -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); - -julia> # (✓) Sort of fast, but only possible at runtime. - vi[@varname(x[1])] --1.019202452456547 - -julia> # In addtion, we can only access varnames as they appear in the model! - vi[@varname(x)] -ERROR: x was not found in the dictionary provided -[...] - -julia> vi[@varname(x[1:2])] -ERROR: x[1:2] was not found in the dictionary provided -[...] -``` - -_Technically_, it's possible to use any implementation of `AbstractDict` in place of -`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening -of the values in the varinfo, are consistent between evaluations. 
Hence `OrderedDict` is -the preferred implementation of `AbstractDict` to use here. - -You can also sample in _transformed_ space: - -```jldoctest simplevarinfo-general -julia> @model demo_constrained() = x ~ Exponential() -demo_constrained (generic function with 2 methods) - -julia> m = demo_constrained(); - -julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); - -julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ -1.8632965762164932 - -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ --0.21080155351918753 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true - -julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.6225185067787314 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true -``` - -Evaluation in transformed space of course also works: - -```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) Positive probability mass on negative numbers! 
- getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --1.3678794411714423 - -julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) No probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --Inf -``` - -## Indexing -Using `NamedTuple` as underlying storage. - -```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); - -julia> svi_nt[@varname(m)] -(a = [1.0],) - -julia> svi_nt[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_nt[@varname(m.a[1])] -1.0 - -julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` -[...] -``` - -Using `OrderedDict` as underlying storage. -```jldoctest -julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); - -julia> svi_dict[@varname(m)] -(a = [1.0],) - -julia> svi_dict[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_dict[@varname(m.a[1])] -1.0 - -julia> svi_dict[@varname(m.a[2])] -ERROR: m.a[2] was not found in the dictionary provided -[...] - -julia> svi_dict[@varname(m.b)] -ERROR: m.b was not found in the dictionary provided -[...] 
-``` -""" -struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: - AbstractVarInfo - "underlying representation of the realization represented" - values::NT - "tuple of accumulators for things like log prior and log likelihood" - accs::Accs - "represents whether it assumes variables to be transformed" - transformation::C -end - -function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) - return vi1.values == vi2.values && - vi1.accs == vi2.accs && - vi1.transformation == vi2.transformation -end - -transformation(vi::SimpleVarInfo) = vi.transformation - -function SimpleVarInfo(values, accs) - return SimpleVarInfo(values, accs, NoTransformation()) -end -function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, default_accumulators(T)) -end -function SimpleVarInfo(values) - return SimpleVarInfo{LogProbType}(values) -end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) - return if isempty(values) - # Can't infer from values, so we just use default. - SimpleVarInfo{LogProbType}(values) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) - end -end - -# Using `kwargs` to specify the values. -function SimpleVarInfo{T}(; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs)) -end -function SimpleVarInfo(; kwargs...) - return SimpleVarInfo(NamedTuple(kwargs)) -end - -# Constructor from `Model`. 
-function SimpleVarInfo{T}(
-    rng::Random.AbstractRNG,
-    model::Model,
-    init_strategy::AbstractInitStrategy=InitFromPrior(),
-) where {T<:Real}
-    return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy))
-end
-function SimpleVarInfo{T}(
-    model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
-) where {T<:Real}
-    return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy)
-end
-# Constructors without type param
-function SimpleVarInfo(
-    rng::Random.AbstractRNG,
-    model::Model,
-    init_strategy::AbstractInitStrategy=InitFromPrior(),
-)
-    return SimpleVarInfo{LogProbType}(rng, model, init_strategy)
-end
-function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
-    return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy)
-end
-
-# Constructor from `VarInfo`.
-function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
-    values = values_as(vi, D)
-    return SimpleVarInfo(values, copy(getaccs(vi)))
-end
-function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
-    values = values_as(vi, D)
-    accs = map(acc -> convert_eltype(T, acc), getaccs(vi))
-    return SimpleVarInfo(values, accs)
-end
-
-function untyped_simple_varinfo(model::Model)
-    varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
-    return last(init!!(model, varinfo))
-end
-
-function typed_simple_varinfo(model::Model)
-    varinfo = SimpleVarInfo{Float64}()
-    return last(init!!(model, varinfo))
-end
-
-function unflatten(svi::SimpleVarInfo, x::AbstractVector)
-    vals = unflatten(svi.values, x)
-    return SimpleVarInfo(vals, svi.accs, svi.transformation)
-end
-
-function BangBang.empty!!(vi::SimpleVarInfo)
-    return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values))
-end
-Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)
-
-getaccs(vi::SimpleVarInfo) = vi.accs
-setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
-
-"""
-    keys(vi::SimpleVarInfo)
-
-Return an iterator of keys present in `vi`.
-"""
-Base.keys(vi::SimpleVarInfo) = keys(vi.values)
-Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))
-
-function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo)
-    if !(svi.transformation isa NoTransformation)
-        print(io, "Transformed ")
-    end
-
-    return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")")
-end
-
-function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
-    return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn))
-end
-function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
-    vals_linked = mapreduce(vcat, vns) do vn
-        getindex(vi, vn, dist)
-    end
-    return recombine(dist, vals_linked, length(vns))
-end
-
-Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn)
-
-# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than
-# just `Vector`.
-function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName})
-    return map(Base.Fix1(getindex, vi), vns)
-end
-# HACK: Needed to disambiguate.
-Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns)
-
-Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector)
-
-getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn)
-# `AbstractDict`
-function getindex_internal(
-    vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName
-)
-    return getvalue(vi.values, vn)
-end
-
-Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)
-
-function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
-    # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set.
-    return Accessors.@set vi.values = set!!(vi.values, vn, val)
-end
-
-# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with
-# same symbol and same type of, say, `IndexLens`, for improved `.~` performance.
-function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName})
-    for (vn, val) in zip(vns, vals)
-        vi = BangBang.setindex!!(vi, val, vn)
-    end
-    return vi
-end
-
-function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName)
-    # For dictlike objects, we treat the entire `vn` as a _key_ to set.
-    dict = values_as(vi)
-    # Attempt to split into `parent` and `child` optic.
-    parent, child, issuccess = splitoptic(getoptic(vn)) do optic
-        o = optic === nothing ? identity : optic
-        haskey(dict, VarName{getsym(vn)}(o))
-    end
-    # When combined with `VarInfo`, `nothing` is equivalent to `identity`.
-    keyoptic = parent === nothing ? identity : parent
-
-    dict_new = if !issuccess
-        # Split doesn't exist ⟹ we're working with a new key.
-        BangBang.setindex!!(dict, val, vn)
-    else
-        # Split exists ⟹ trying to set an existing key.
-        vn_key = VarName{getsym(vn)}(keyoptic)
-        BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key)
-    end
-    return Accessors.@set vi.values = dict_new
-end
-
-# `NamedTuple`
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution
-) where {sym}
-    return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,)))
-end
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution
-) where {sym}
-    return Accessors.@set vi.values = set!!(vi.values, vn, value)
-end
-
-# `AbstractDict`
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution
-)
-    vi.values[vn] = value
-    return vi
-end
-
-function BangBang.push!!(
-    vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution
-)
-    # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For
-    # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not.
-    # Hence we need to call update!! here, which has the same semantics as push!! does for
-    # SimpleVarInfo.
-    return Accessors.@set vi.values = setindex!!(vi.values, value, vn)
-end
-
-const SimpleOrThreadSafeSimple{T,V,C} = Union{
-    SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}}
-}
-
-# Necessary for `matchingvalue` to work properly.
-Base.eltype(svi::SimpleVarInfo) = infer_nested_eltype(typeof(svi.values))
-Base.eltype(tsvi::ThreadSafeVarInfo{<:SimpleVarInfo}) = eltype(tsvi.varinfo)
-
-# `subset`
-function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
-    return SimpleVarInfo(
-        _subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation
-    )
-end
-
-function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
-    vns_present = collect(keys(x))
-    vns_found = filter(
-        vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present
-    )
-    C = ConstructionBase.constructorof(typeof(x))
-    if isempty(vns_found)
-        return C()
-    else
-        return C(vn => x[vn] for vn in vns_found)
-    end
-end
-
-function _subset(x::NamedTuple, vns)
-    # NOTE: Here we can only handle `vns` that contain `identity` as optic.
-    if any(Base.Fix1(!==, identity) ∘ getoptic, vns)
-        throw(
-            ArgumentError(
-                "Cannot subset `NamedTuple` with non-`identity` `VarName`. " *
-                "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
-            ),
-        )
-    end
-
-    syms = map(getsym, vns)
-    x_syms = filter(Base.Fix2(in, syms), keys(x))
-    return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms)))
-end
-
-_subset(x::VarNamedVector, vns) = subset(x, vns)
-
-# `merge`
-function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
-    values = merge(varinfo_left.values, varinfo_right.values)
-    accs = map(copy, getaccs(varinfo_right))
-    transformation = merge_transformations(
-        varinfo_left.transformation, varinfo_right.transformation
-    )
-    return SimpleVarInfo(values, accs, transformation)
-end
-
-function set_transformed!!(vi::SimpleVarInfo, trans)
-    return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation())
-end
-function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
-    return Accessors.@set vi.transformation = transformation
-end
-function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
-    return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans)
-end
-function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
-    # We keep this method around just to obey the AbstractVarInfo interface.
-    # However, note that this would only be a valid operation if it would be a
-    # no-op, which we check here.
-    if trans != is_transformed(vi)
-        error(
-            "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
-        )
-    end
-    return vi
-end
-
-is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
-is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi)
-function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName)
-    return is_transformed(vi.varinfo, vn)
-end
-is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo)
-
-values_as(vi::SimpleVarInfo) = vi.values
-values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
-function values_as(vi::SimpleVarInfo, ::Type{Vector})
-    isempty(vi) && return Any[]
-    return mapreduce(tovec, vcat, values(vi.values))
-end
-function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
-    return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
-end
-function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple})
-    return NamedTuple((Symbol(k), v) for (k, v) in vi.values)
-end
-function values_as(vi::SimpleVarInfo, ::Type{T}) where {T}
-    return values_as(vi.values, T)
-end
-
-# Allow usage of `NamedBijector` too.
-function link!!(
-    t::StaticTransformation{<:Bijectors.NamedTransform},
-    vi::SimpleVarInfo{<:NamedTuple},
-    ::Model,
-)
-    b = inverse(t.bijector)
-    x = vi.values
-    y, logjac = with_logabsdet_jacobian(b, x)
-    vi_new = Accessors.@set(vi.values = y)
-    if hasacc(vi_new, Val(:LogJacobian))
-        vi_new = acclogjac!!(vi_new, logjac)
-    end
-    return set_transformed!!(vi_new, t)
-end
-
-function invlink!!(
-    t::StaticTransformation{<:Bijectors.NamedTransform},
-    vi::SimpleVarInfo{<:NamedTuple},
-    ::Model,
-)
-    b = t.bijector
-    y = vi.values
-    x, inv_logjac = with_logabsdet_jacobian(b, y)
-    vi_new = Accessors.@set(vi.values = x)
-    # Mildly confusing: we need to _add_ the logjac of the inverse transform,
-    # because we are trying to remove the logjac of the forward transform
-    # that was previously accumulated when linking.
-    if hasacc(vi_new, Val(:LogJacobian))
-        vi_new = acclogjac!!(vi_new, inv_logjac)
-    end
-    return set_transformed!!(vi_new, NoTransformation())
-end
-
-# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.
-from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
-from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity
-# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`?
-from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity
-function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
-    return invlink_transform(dist)
-end
-
-has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector
diff --git a/src/test_utils.jl b/src/test_utils.jl
index f584055b3..ebb516844 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -1,6 +1,7 @@
 module TestUtils
 
 using AbstractMCMC
+using AbstractPPL: AbstractPPL
 using DynamicPPL
 using LinearAlgebra
 using Distributions
diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl
index a030b479e..6bcd9547e 100644
--- a/src/test_utils/ad.jl
+++ b/src/test_utils/ad.jl
@@ -242,7 +242,7 @@ Everything else is optional, and can be categorised into several groups:
     Finally, note that these only reflect the parameters used for _evaluating_
    the gradient. If you also want to control the parameters used for
    _preparing_ the gradient, then you need to manually set these parameters in
-    the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
+    the VarInfo object, for example using `vi = DynamicPPL.unflatten!!(vi,
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl
index c48d2ddfd..7182f511e 100644
--- a/src/test_utils/contexts.jl
+++ b/src/test_utils/contexts.jl
@@ -36,16 +36,12 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP
     # varinfos.) Thus we only test evaluation with VarInfos that are already
     # filled with values.
     @testset "evaluation" begin
-        # Generate a new filled untyped varinfo
-        _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
-        typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
+        # Generate a new filled varinfo
+        _, vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo())
         # Set the test context as the new leaf context
         new_model = DynamicPPL.setleafcontext(model, context)
-        # Check that evaluation works
-        for vi in [untyped_vi, typed_vi]
-            _, vi = DynamicPPL.evaluate!!(new_model, vi)
-            @test vi isa DynamicPPL.VarInfo
-        end
+        _, vi = DynamicPPL.evaluate!!(new_model, vi)
+        @test vi isa DynamicPPL.VarInfo
     end
 end
 
@@ -53,7 +49,7 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic
     @testset "get/set leaf and child contexts" begin
         # Ensure we're using a different leaf context than the current.
         leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
-            DynamicPPL.DynamicTransformationContext{false}()
+            DynamicPPL.InitContext(Random.MersenneTwister(1234), InitFromPrior())
         else
             DefaultContext()
         end
@@ -73,13 +69,12 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic
 
     @testset "initialisation and evaluation" begin
         new_model = contextualize(model, context)
-        for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())]
-            # Initialisation
-            _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
-            @test vi isa DynamicPPL.VarInfo
-            # Evaluation
-            _, vi = DynamicPPL.evaluate!!(new_model, vi)
-            @test vi isa DynamicPPL.VarInfo
-        end
+        vi = DynamicPPL.VarInfo()
+        # Initialisation
+        _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo())
+        @test vi isa DynamicPPL.VarInfo
+        # Evaluation
+        _, vi = DynamicPPL.evaluate!!(new_model, vi)
+        @test vi isa DynamicPPL.VarInfo
     end
 end
diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl
index e7fb16fbe..50e13f912 100644
--- a/src/test_utils/model_interface.jl
+++ b/src/test_utils/model_interface.jl
@@ -89,10 +89,10 @@ function logprior_true_with_logabsdet_jacobian end
 Return a collection of `VarName` as they are expected to appear in the model.
 
 Even though it is recommended to implement this by hand for a particular `Model`,
-a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided.
+a default implementation using [`VarInfo`](@ref) is provided.
 """
 function varnames(model::Model)
-    result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict())))))
+    result = collect(keys(last(DynamicPPL.init!!(model, VarInfo()))))
     # Concretise the element type.
     return [x for x in result]
 end
@@ -104,7 +104,7 @@ Return a `NamedTuple` compatible with `varnames(model)` where the values represe
 the posterior mean under `model`.
 
 "Compatible" means that a `varname` from `varnames(model)` can be used to extract the
-corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
+corresponding value using e.g. `AbstractPPL.getvalue(posterior_mean(model), varname)`.
 """
 function posterior_mean end
diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl
index 84e1f10d8..e244f956f 100644
--- a/src/test_utils/models.jl
+++ b/src/test_utils/models.jl
@@ -7,6 +7,26 @@
 #
 # Some additionally contain an implementation of `rand_prior_true`.
 
+"""
+    varnames(model::Model)
+
+Return the VarNames defined in `model`, as a Vector.
+"""
+function varnames end
+
+# TODO(mhauru) The fact that the below function exists is a sign that we are inconsistent in
+# how we handle IndexLenses. This should hopefully be resolved once we consistently use
+# VarNamedTuple rather than dictionaries everywhere.
+"""
+    varnames_split(model::Model)
+
+Return the VarNames in `model`, with any ranges or colons split into individual indices.
+
+The default implementation is to just return `varnames(model)`. If something else is needed,
+this should be defined separately.
+"""
+varnames_split(model::Model) = varnames(model)
+
 """
     demo_dynamic_constraint()
 
@@ -77,6 +97,9 @@ end
 function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)})
     return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])]
 end
+function varnames_split(model::Model{typeof(demo_one_variable_multiple_constraints)})
+    return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4]), @varname(x[5])]
+end
 function logprior_true_with_logabsdet_jacobian(
     model::Model{typeof(demo_one_variable_multiple_constraints)}, x
 )
@@ -565,6 +588,67 @@ function varnames(model::Model{typeof(demo_assume_matrix_observe_matrix_index)})
     return [@varname(s), @varname(m)]
 end
 
+@model function demo_nested_colons(
+    x=(; data=[(; subdata=transpose([1.5 2.0;]))]), ::Type{TV}=Array{Float64}
+) where {TV}
+    n = length(x.data[1].subdata)
+    d = n ÷ 2
+    s = (; params=[(; subparams=TV(undef, (d, 1, 2)))])
+    s.params[1].subparams[:, 1, :] ~ reshape(
+        product_distribution(fill(InverseGamma(2, 3), n)), d, 2
+    )
+    s_vec = vec(s.params[1].subparams)
+    # TODO(mhauru) The below element type concretisation is because of
+    # https://github.com/JuliaFolds2/BangBang.jl/issues/39
+    # which causes, when this is evaluated with an untyped VarInfo, s_vec to be an
+    # Array{Any}.
+    s_vec = [x for x in s_vec]
+    m ~ MvNormal(zeros(n), Diagonal(s_vec))
+
+    x.data[1].subdata[:, 1] ~ MvNormal(m, Diagonal(s_vec))
+
+    return (; s=s, m=m, x=x)
+end
+function logprior_true(model::Model{typeof(demo_nested_colons)}, s, m)
+    n = length(model.args.x.data[1].subdata)
+    # TODO(mhauru) We need to enforce a convention on whether this function gets called
+    # with the parameters as the model returns them, or with the parameters "unpacked".
+    # Currently different tests do different things.
+    s_vec = if s isa NamedTuple
+        vec(s.params[1].subparams)
+    else
+        vec(s)
+    end
+    return loglikelihood(InverseGamma(2, 3), s_vec) +
+        logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m)
+end
+function loglikelihood_true(model::Model{typeof(demo_nested_colons)}, s, m)
+    # TODO(mhauru) We need to enforce a convention on whether this function gets called
+    # with the parameters as the model returns them, or with the parameters "unpacked".
+    # Currently different tests do different things.
+    s_vec = if s isa NamedTuple
+        vec(s.params[1].subparams)
+    else
+        vec(s)
+    end
+    return loglikelihood(MvNormal(m, Diagonal(s_vec)), model.args.x.data[1].subdata)
+end
+function logprior_true_with_logabsdet_jacobian(
+    model::Model{typeof(demo_nested_colons)}, s, m
+)
+    return _demo_logprior_true_with_logabsdet_jacobian(model, s.params[1].subparams, m)
+end
+function varnames(::Model{typeof(demo_nested_colons)})
+    return [@varname(s.params[1].subparams[:, 1, :]), @varname(m)]
+end
+function varnames_split(::Model{typeof(demo_nested_colons)})
+    return [
+        @varname(s.params[1].subparams[1, 1, 1]),
+        @varname(s.params[1].subparams[1, 1, 2]),
+        @varname(m),
+    ]
+end
+
 const UnivariateAssumeDemoModels = Union{
     Model{typeof(demo_assume_dot_observe)},
     Model{typeof(demo_assume_dot_observe_literal)},
@@ -615,8 +699,8 @@ function likelihood_optima(model::MultivariateAssumeDemoModels)
     vals = rand_prior_true(model)
 
     # NOTE: These are "as close to zero as we can get".
-    vals.s[1] = 1e-32
-    vals.s[2] = 1e-32
+    vals.s[1] = floatmin()
+    vals.s[2] = floatmin()
     vals.m[1] = 1.5
     vals.m[2] = 2.0
 
@@ -668,8 +752,8 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels)
     vals = rand_prior_true(model)
 
     # NOTE: These are "as close to zero as we can get".
-    vals.s[1, 1] = 1e-32
-    vals.s[1, 2] = 1e-32
+    vals.s[1, 1] = floatmin()
+    vals.s[1, 2] = floatmin()
     vals.m[1] = 1.5
     vals.m[2] = 2.0
 
@@ -701,6 +785,51 @@ function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemo
     return vals
 end
 
+function posterior_mean(model::Model{typeof(demo_nested_colons)})
+    # Get some containers to fill.
+    vals = rand_prior_true(model)
+
+    vals.s.params[1].subparams[1, 1, 1] = 19 / 8
+    vals.m[1] = 3 / 4
+
+    vals.s.params[1].subparams[1, 1, 2] = 8 / 3
+    vals.m[2] = 1
+
+    return vals
+end
+function likelihood_optima(model::Model{typeof(demo_nested_colons)})
+    # Get some containers to fill.
+    vals = rand_prior_true(model)
+
+    # NOTE: These are "as close to zero as we can get".
+    vals.s.params[1].subparams[1, 1, 1] = floatmin()
+    vals.s.params[1].subparams[1, 1, 2] = floatmin()
+
+    vals.m[1] = 1.5
+    vals.m[2] = 2.0
+
+    return vals
+end
+function posterior_optima(model::Model{typeof(demo_nested_colons)})
+    # Get some containers to fill.
+    vals = rand_prior_true(model)
+
+    # TODO: Figure out the exact value for `s[1]`.
+    vals.s.params[1].subparams[1, 1, 1] = 0.890625
+    vals.s.params[1].subparams[1, 1, 2] = 1
+    vals.m[1] = 3 / 4
+    vals.m[2] = 1
+
+    return vals
+end
+function rand_prior_true(rng::Random.AbstractRNG, ::Model{typeof(demo_nested_colons)})
+    svec = rand(rng, InverseGamma(2, 3), 2)
+    return (;
+        s=(; params=[(; subparams=reshape(svec, (1, 1, 2)))]),
+        m=rand(rng, MvNormal(zeros(2), Diagonal(svec))),
+    )
+end
+
 """
 A collection of models corresponding to the posterior distribution defined by
 the generative process
@@ -749,6 +878,7 @@ const DEMO_MODELS = (
     demo_dot_assume_observe_submodel(),
     demo_dot_assume_observe_matrix_index(),
     demo_assume_matrix_observe_matrix_index(),
+    # demo_nested_colons(),
 )
 
 """
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 6483b29e8..146917dc9 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -10,7 +10,7 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in
 """
 function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
     for vn in vns
-        val = get(vals, vn)
+        val = AbstractPPL.getvalue(vals, vn)
        # TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
        # Remove once the fix is all Julia versions we support.
        if val isa Cholesky
@@ -33,34 +33,14 @@ of the varinfo instances.
 function setup_varinfos(
     model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
 )
-    # VarInfo
-    vi_untyped_metadata = DynamicPPL.untyped_varinfo(model)
-    vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model)
-    vi_typed_metadata = DynamicPPL.typed_varinfo(model)
-    vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model)
-
-    # SimpleVarInfo
-    svi_typed = SimpleVarInfo(example_values)
-    svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
-    svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
-
-    varinfos = map((
-        vi_untyped_metadata,
-        vi_untyped_vnv,
-        vi_typed_metadata,
-        vi_typed_vnv,
-        svi_typed,
-        svi_untyped,
-        svi_vnv,
-    )) do vi
-        # Set them all to the same values and evaluate logp.
-        vi = update_values!!(vi, example_values, varnames)
-        last(DynamicPPL.evaluate!!(model, vi))
+    vi = DynamicPPL.VarInfo(model)
+    vi = update_values!!(vi, example_values, varnames)
+    vi = last(DynamicPPL.evaluate!!(model, vi))
+
+    varinfos = if include_threadsafe
+        (vi, DynamicPPL.ThreadSafeVarInfo(deepcopy(vi)))
+    else
+        (vi,)
+    end
-
-    if include_threadsafe
-        varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...)
-    end
-
     return varinfos
 end
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index c7ab106a2..547dd6a1e 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -36,6 +36,9 @@ function getacc(vi::ThreadSafeVarInfo, accname::Val)
     return foldl(combine, other_accs; init=main_acc)
 end
 
+function Base.copy(vi::ThreadSafeVarInfo)
+    return ThreadSafeVarInfo(copy(vi.varinfo), deepcopy(vi.accs_by_thread))
+end
 hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname)
 acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo)
 
@@ -62,12 +65,6 @@ function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo)
     return vi
 end
 
-has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)
-
-function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution)
-    return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
-end
-
 syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
 setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
 
@@ -85,61 +82,6 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...)
     return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...)
 end
 
-function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return Accessors.@set vi.varinfo = link(t, vi.varinfo, model)
-end
-
-function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model)
-end
-
-function link(
-    t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
-)
-    return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model)
-end
-
-function invlink(
-    t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model
-)
-    return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model)
-end
-
-# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
-# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
-# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates
-# to define `getacc(vi)`.
-function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    model = setleafcontext(model, DynamicTransformationContext{false}())
-    return set_transformed!!(last(evaluate!!(model, vi)), t)
-end
-
-function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    model = setleafcontext(model, DynamicTransformationContext{true}())
-    return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation())
-end
-
-function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return link!!(t, deepcopy(vi), model)
-end
-
-function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
-    return invlink!!(t, deepcopy(vi), model)
-end
-
-# These two StaticTransformation methods needed to resolve ambiguities
-function link!!(
-    t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
-)
-    return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model)
-end
-
-function invlink!!(
-    t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
-)
-    return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model)
-end
-
 function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
     # Defer to the wrapped `AbstractVarInfo` object.
     # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the
@@ -159,6 +101,11 @@ function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::D
     return getindex(vi.varinfo, vns, dist)
 end
 
+function setindex_with_dist!!(vi::ThreadSafeVarInfo, val, dist::Distribution, vn::VarName)
+    vi_inner, logjac = setindex_with_dist!!(vi.varinfo, val, dist, vn)
+    return Accessors.@set(vi.varinfo = vi_inner), logjac
+end
+
 function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName)
     return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn)
 end
@@ -195,8 +142,8 @@ end
 
 getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn)
 
-function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
-    return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x)
+function unflatten!!(vi::ThreadSafeVarInfo, x::AbstractVector)
+    return Accessors.@set vi.varinfo = unflatten!!(vi.varinfo, x)
 end
 
 function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
diff --git a/src/utils.jl b/src/utils.jl
index 75fb805dc..2a35db779 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,13 +1,7 @@
-# singleton for indicating if no default arguments are present
-struct NoDefault end
-const NO_DEFAULT = NoDefault()
+# subset is defined here to avoid circular dependencies between files. Methods for it are
+# defined in other files.
+function subset end
 
-# A short-hand for a type commonly used in type signatures for VarInfo methods.
-VarNameTuple = NTuple{N,VarName} where {N}
-
-# TODO(mhauru) This is currently used in the transformation functions of NoDist,
-# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in
-# SimpleVarInfo and maybe other places.
 """
 The type for all log probability variables.
@@ -49,6 +43,7 @@ function typed_identity end
 @inline typed_identity(x) = x
 @inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) =
     (x, zero(LogProbType))
+@inline Bijectors.inverse(::typeof(typed_identity)) = typed_identity
 
 """
     @addlogprob!(ex)
@@ -401,7 +396,7 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
 from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform()
 from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist))
 
-struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
+struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} <: Bijectors.Bijector
     dists::T
     # The `i`-th input range corresponds to the segment of the input vector
     # that belongs to the `i`-th distribution.
@@ -434,13 +429,30 @@ end
     return expr
 end
 
+@generated function (inv_trf::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform{names}})(
+    x::NamedTuple{names}
+) where {names}
+    exprs = Expr[]
+    for name in names
+        push!(exprs, :(to_vec_transform(inv_trf.orig.dists.$name)(x.$name)))
+    end
+    return :(vcat($(exprs...)))
+end
+
 function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
     return ProductNamedTupleUnvecTransform(dist)
 end
+
 function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x)
     return f(x), zero(LogProbType)
 end
 
+function Bijectors.with_logabsdet_jacobian(
+    inv_f::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform}, x
+)
+    return inv_f(x), zero(LogProbType)
+end
+
 # This function returns the length of the vector that the function from_vec_transform
 # expects. This helps us determine which segment of a concatenated vector belongs to which
 # variable.
@@ -484,8 +496,6 @@ end
 # UnivariateDistributions need to be handled as a special case, because size(dist) is (),
 # which makes the usual machinery think we are dealing with a 0-dim array, whereas in
 # actuality we are dealing with a scalar.
-# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and
-# VarNamedVector takes over from Metadata.
 function from_linked_vec_transform(dist::UnivariateDistribution)
     f_invlink = invlink_transform(dist)
     f_vec = from_vec_transform(inverse(f_invlink), size(dist))
@@ -528,275 +538,6 @@ tovec(t::Tuple) = mapreduce(tovec, vcat, t)
 tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt))
 tovec(C::Cholesky) = tovec(Matrix(C.UL))
 
-"""
-    recombine(dist::Union{UnivariateDistribution,MultivariateDistribution}, vals::AbstractVector, n::Int)
-
-Recombine `vals`, representing a batch of samples from `dist`, so that it's a compatible with `dist`.
-
-!!! warning
-    This only supports `UnivariateDistribution` and `MultivariateDistribution`, which are the only two
-    distribution types which are allowed on the right-hand side of a `.~` statement in a model.
-"""
-function recombine(::UnivariateDistribution, val::AbstractVector, ::Int)
-    # This is just a no-op, since we're trying to convert a vector into a vector.
-    return copy(val)
-end
-function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int)
-    # Here `val` is of the length `length(d) * n` and so we need to reshape it.
-    return copy(reshape(val, length(d), n))
-end
-
-#######################
-# Convenience methods #
-#######################
-"""
-    collect_maybe(x)
-
-Return `x` if `x` is an array, otherwise return `collect(x)`.
-""" -collect_maybe(x) = collect(x) -collect_maybe(x::AbstractArray) = x - -####################### -# BangBang.jl related # -####################### -function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) - opticmut = BangBang.prefermutation(optic) - return Accessors.set(obj, opticmut, value) -end -function set!!(obj, vn::VarName{sym}, value) where {sym} - optic = BangBang.prefermutation( - AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - ) - return Accessors.set(obj, optic, value) -end - -############################# -# AbstractPPL.jl extensions # -############################# -# This is preferable to `haskey` because the order of arguments is different, and -# we're more likely to specialize on the key in these settings rather than the container. -# TODO: I'm not sure about this name. -""" - canview(optic, container) - -Return `true` if `optic` can be used to view `container`, and `false` otherwise. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) -julia> canview(@o(_.a), (a = 1.0, )) -true - -julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist -false - -julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) -true - -julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds -false -``` -""" -canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(optic::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) -end - -# `IndexLens`: only relevant if `x` supports indexing. -canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) -end - -# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::Accessors.ComposedOptic, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) -end - -""" - parent(vn::VarName) - -Return the parent `VarName`. 
- -# Examples -```julia-repl; setup=:(using DynamicPPL: parent) -julia> parent(@varname(x.a[1])) -x.a - -julia> (parent ∘ parent)(@varname(x.a[1])) -x - -julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) -x -``` -""" -function parent(vn::VarName) - p = parent(getoptic(vn)) - return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p) -end - -""" - parent(optic) - -Return the parent optic. If `optic` doesn't have a parent, -`nothing` is returned. - -See also: [`parent_and_child`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) -julia> parent(@o(_.a[1])) -(@o _.a) - -julia> # Parent of optic without parents results in `nothing`. - (parent ∘ parent)(@o(_.a[1])) === nothing -true -``` -""" -parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) - -""" - parent_and_child(optic) - -Return a 2-tuple of optics `(parent, child)` where `parent` is the -parent optic of `optic` and `child` is the child optic of `optic`. - -If `optic` does not have a parent, we return `(nothing, optic)`. - -See also: [`parent`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) -julia> parent_and_child(@o(_.a[1])) -((@o _.a), (@o _[1])) - -julia> parent_and_child(@o(_.a)) -(nothing, (@o _.a)) -``` -""" -parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) -function parent_and_child(optic::Accessors.ComposedOptic) - p, child = parent_and_child(optic.outer) - parent = p === nothing ? optic.inner : p ∘ optic.inner - return parent, child -end - -""" - splitoptic(condition, optic) - -Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. - -If `issuccess` is `false`, then no such split could be found. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) -julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent - # Succeeds! 
- parent == @o(_.a) - end -((@o _.a), (@o _[1]), true) - -julia> c ∘ p -(@o _.a[1]) - -julia> splitoptic(@o(_.a[1])) do parent - # Fails! - parent == @o(_.b) - end -(nothing, (@o _.a[1]), false) -``` -""" -function splitoptic(condition, optic) - current_parent, current_child = parent_and_child(optic) - # We stop if either a) `condition` is satisfied, or b) we reached the root. - while !condition(current_parent) && current_parent !== nothing - current_parent, c = parent_and_child(current_parent) - current_child = current_child ∘ c - end - - return current_parent, current_child, condition(current_parent) -end - -""" - remove_parent_optic(vn_parent::VarName, vn_child::VarName) - -Remove the parent optic `vn_parent` from `vn_child`. - -# Examples -```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) -julia> remove_parent_optic(@varname(x), @varname(x.a)) -(@o _.a) - -julia> remove_parent_optic(@varname(x), @varname(x.a[1])) -(@o _.a[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) -(@o _[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) -(@o _[1].b) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a)) -ERROR: Could not find x.a in x.a - -julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) -ERROR: Could not find x.a[2] in x.a[1] -``` -""" -function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitoptic(getoptic(vn_child)) do optic - o = optic === nothing ? identity : optic - o == getoptic(vn_parent) - end - - issuccess || error("Could not find $vn_parent in $vn_child") - return child -end - -# HACK(torfjelde): This makes it so it works on iterators, etc. by default. -# TODO(torfjelde): Do better. -""" - unflatten(original, x::AbstractVector) - -Return instance of `original` constructed from `x`. 
-""" -function unflatten(original, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return map(zip(original, lengths, end_indices)) do (v, l, end_idx) - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end - -unflatten(::Real, x::Real) = x -unflatten(::Real, x::AbstractVector) = only(x) -unflatten(::AbstractVector{<:Real}, x::Real) = vcat(x) -unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x -unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) - -function unflatten(original::Tuple, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return ntuple(length(original)) do i - v = original[i] - l = lengths[i] - end_idx = end_indices[i] - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end -function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} - return NamedTuple{names}(unflatten(values(original), x)) -end -function unflatten(original::AbstractDict, x::AbstractVector) - D = ConstructionBase.constructorof(typeof(original)) - return D(zip(keys(original), unflatten(collect(values(original)), x))) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) @@ -804,7 +545,7 @@ Return instance similar to `vi` but with `vns` set to values from `vals`. """ function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) for vn in vns - vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) + vi = DynamicPPL.setindex!!(vi, AbstractPPL.getvalue(vals, vn), vn) end return vi end @@ -900,52 +641,26 @@ _merge(::NamedTuple{()}, right::AbstractDict) = right _merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) """ - unique_syms(vns::T) where {T<:NTuple{N,VarName}} - -Return the unique symbols of the variables in `vns`. 
- 
-Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike
-`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant
-propagating the result, which is possible only when the argument and the return value are
-`Tuple`s.
-"""
-@generated function unique_syms(::T) where {T<:VarNameTuple}
-    retval = Expr(:tuple)
-    syms = [first(vn.parameters) for vn in T.parameters]
-    for sym in unique(syms)
-        push!(retval.args, QuoteNode(sym))
-    end
-    return retval
-end
+    basetypeof(x)

+Return `typeof(x)` stripped of its type parameters.
"""
-    group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N}
-
-Return a `NamedTuple` of the variables in `vns` grouped by symbol.
-
-Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to
-be type stable.
-
-Example:
-```julia
-julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]))
-(x, y[1], x.a, z[15], y[2])
+basetypeof(x::T) where {T} = Base.typename(T).wrapper

-julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])])
-(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]])
+const MaybeTypedIdentity = Union{typeof(typed_identity),typeof(identity)}

-julia> group_varnames_by_symbol(vns_tuple) == vns_nt
-```
+# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if
+# ReshapeTransforms are composed with each other or with an UnwrapSingletonTransform, only
+# the latter one would be kept.
"""
-function group_varnames_by_symbol(vns::VarNameTuple)
-    syms = unique_syms(vns)
-    elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...))
-    return NamedTuple{syms}(elements)
-end
+    _compose_no_identity(f, g)

-"""
-    basetypeof(x)
+Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted.

-Return `typeof(x)` stripped of its type parameters.
+This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type +conflicts. """ -basetypeof(x::T) where {T} = Base.typename(T).wrapper +_compose_no_identity(f, g) = f ∘ g +_compose_no_identity(::MaybeTypedIdentity, g) = g +_compose_no_identity(f, ::MaybeTypedIdentity) = f +_compose_no_identity(::MaybeTypedIdentity, ::MaybeTypedIdentity) = typed_identity diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 71baebe92..9ee622424 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -10,14 +10,14 @@ wants to extract the realization of a model in a constrained space. # Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelAccumulator <: AbstractAccumulator +struct ValuesAsInModelAccumulator{VNT<:VarNamedTuple} <: AbstractAccumulator "values that are extracted from the model" - values::OrderedDict{<:VarName} + values::VNT "whether to extract variables on the LHS of :=" include_colon_eq::Bool end function ValuesAsInModelAccumulator(include_colon_eq) - return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) + return ValuesAsInModelAccumulator(VarNamedTuple(), include_colon_eq) end function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) @@ -30,6 +30,9 @@ end accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel +# TODO(mhauru) We could start using reset!!, which could call empty!! on the VarNamedTuple. +# This would create VarNamedTuples that share memory with the original one, saving +# allocations but also making them not capable of taking in any arbitrary VarName. 
function _zero(acc::ValuesAsInModelAccumulator)
    return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq)
end
@@ -45,8 +48,11 @@ function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumula
    )
end

-function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
-    setindex!(acc.values, deepcopy(val), vn)
+function BangBang.push!!(acc::ValuesAsInModelAccumulator, vn::VarName, val)
+    # TODO(mhauru) The deepcopy here is quite unfortunate. It is needed so that the model
+    # body can go mutating the object without that retroactively affecting the value in the
+    # accumulator, which should be as it was at `~` time. Could there be a way around this?
+    Accessors.@reset acc.values = setindex!!(acc.values, deepcopy(val), vn)
    return acc
end

@@ -56,7 +62,7 @@ function is_extracting_values(vi::AbstractVarInfo)
end

function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right)
-    return push!(acc, vn, val)
+    return push!!(acc, vn, val)
end

accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc

@@ -75,6 +81,8 @@ working in unconstrained space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.

+Returns a `VarNamedTuple`.
+
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
@@ -99,35 +107,30 @@ julia> @model function model_changing_support()

julia> model = model_changing_support();

-julia> # Construct initial type-stable `VarInfo`.
+julia> # Construct initial `VarInfo`.
       varinfo = VarInfo(rng, model);

julia> # Link it so it works in unconstrained space.
       varinfo_linked = DynamicPPL.link(varinfo, model);

-julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
+julia> # Perform computations in unconstrained space, e.g. changing the values of `vals`.
       # Flip `x` so we hit the other support of `y`.
- θ = [!varinfo[@varname(x)], rand(rng)]; + vals = [!varinfo[@varname(x)], rand(rng)]; julia> # Update the `VarInfo` with the new values. - varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, vals); julia> # Determine the expected support of `y`. - lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) + lb, ub = vals[1] == 1 ? (0, 1) : (11, 12) (0, 1) julia> # Approach 1: Convert back to constrained space using `invlink` and extract. varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); -julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions - # used in the very first model evaluation, hence the support of `y` - # is not updated even though `x` has changed. - lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub -false +julia> lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub +true julia> # Approach 2: Extract realizations using `values_as_in_model`. - # (✓) `values_as_in_model` will re-run the model and extract - # the correct realization of `y` given the new values of `x`. lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub true ``` diff --git a/src/varinfo.jl b/src/varinfo.jl index f78fbe01b..191537ad8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,1810 +1,504 @@ -#### -#### Types for typed and untyped VarInfo -#### - -#################### -# VarInfo metadata # -#################### - """ -The `Metadata` struct stores some metadata about the parameters of the model. This helps -query certain information about a variable, such as its distribution, which samplers -sample this variable, its value and whether this value is transformed to real space or -not. - -Let `md` be an instance of `Metadata`: -- `md.vns` is the vector of all `VarName` instances. -- `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, and `md.is_transformed`. -- `md.vns[md.idcs[vn]] == vn`. 
-- `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. -- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.is_transformed` is a BitVector of true/false flags for whether a variable has been - transformed. `md.is_transformed[md.idcs[vn]]` is the value corresponding to `vn`. - -To make `md::Metadata` type stable, all the `md.vns` must have the same symbol -and distribution type. However, one can have a Julia variable, say `x`, that is a -matrix or a hierarchical array sampled in partitions, e.g. -`x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`, and is managed by -a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the -same type. Type unstable `Metadata` will still work but will have inferior performance. -When sampling, the first iteration uses a type unstable `Metadata` for all the -variables then a specialized `Metadata` is used for each symbol along with a function -barrier to make the rest of the sampling type stable. 
-""" -struct Metadata{ - TIdcs<:Dict{<:VarName,Int}, - TDists<:AbstractVector{<:Distribution}, - TVN<:AbstractVector{<:VarName}, - TVal<:AbstractVector{<:Real}, -} - # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` - idcs::TIdcs # Dict{<:VarName,Int} - - # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn` - vns::TVN # AbstractVector{<:VarName} - - # Vector of index ranges in `vals` corresponding to `vns` - # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals` - ranges::Vector{UnitRange{Int}} - - # Vector of values of all the univariate, multivariate and matrix variables - # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals::TVal # AbstractVector{<:Real} - - # Vector of distributions correpsonding to `vns` - dists::TDists # AbstractVector{<:Distribution} - - is_transformed::BitVector -end + VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo -function Base.:(==)(md1::Metadata, md2::Metadata) - return ( - md1.idcs == md2.idcs && - md1.vns == md2.vns && - md1.ranges == md2.ranges && - md1.vals == md2.vals && - md1.dists == md2.dists && - md1.is_transformed == md2.is_transformed - ) -end +The default implementation of `AbstractVarInfo`, storing variable values and accumulators. -########### -# VarInfo # -########### +The `Linked` type parameter is either `true` or `false` to mark that all variables in this +`VarInfo` are linked, or `nothing` to indicate that some variables may be linked and some +not, and a runtime check is needed. -""" - struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo - metadata::Tmeta - accs::Accs - end +`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, +and a tuple of accumulators. The only really noteworthy thing about it is that it stores +the values of variables vectorised as instances of `TransformedValue`. 
That is, it stores +each value as a vector and a transformation to be applied to that vector to get the actual +value. It also stores whether the transformation is such that it guarantees all real vectors +to be valid internal representations of the variable (i.e., whether the variable has been +linked), as well as the size of the actual post-transformation value. These are all fields +of [`TransformedValue`](@ref). -A light wrapper over some kind of metadata. +Note that `setindex!!` and `getindex` on `VarInfo` take and return values in the support of +the original distribution. To get access to the internal vectorised values, use +[`getindex_internal`](@ref), [`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). -The type of the metadata can be one of a number of options. It may either be a -`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps -symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers -to a Julia variable and may consist of one or more `VarName`s which appear on -the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both -have the same symbol `x`. +There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which sets a +variable's value with a transformation based on the statistical distribution this value is +a sample for. -Several type aliases are provided for these forms of VarInfos: -- `VarInfo{<:Metadata}` is `UntypedVarInfo` -- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` -- `VarInfo{<:NamedTuple}` is `NTVarInfo` +For more details on the internal storage, see documentation of [`TransformedValue`](@ref) and +[`VarNamedTuple`](@ref). -The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability -of model evaluation. However, the element type of NamedTuples are not contained -in its type itself: thus, there is no way to use the type system to determine -whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. 
+# Fields +$(TYPEDFIELDS) -Note that for NTVarInfo, it is the user's responsibility to ensure that each -symbol is visited at least once during model evaluation, regardless of any -stochastic branching. """ -struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo - metadata::Tmeta +struct VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + values::T accs::Accs -end -function VarInfo(meta=Metadata()) - return VarInfo(meta, default_accumulators()) -end - -""" - VarInfo( - [rng::Random.AbstractRNG], - model, - [init_strategy::AbstractInitStrategy] - ) - -Generate a `VarInfo` object for the given `model`, by initialising it with the -given `rng` and `init_strategy`. - -!!! warning - - This function currently returns a `VarInfo` with its metadata field set to - a `NamedTuple` of `Metadata`. This is an implementation detail. In general, - this function may return any kind of object that satisfies the - `AbstractVarInfo` interface. If you require precise control over the type - of `VarInfo` returned, use the internal functions `untyped_varinfo`, - `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` - instead. -""" -function VarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_varinfo(rng, model, init_strategy) -end -function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VarInfo(Random.default_rng(), model, init_strategy) -end - -const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} -const UntypedVarInfo = VarInfo{<:Metadata} -# TODO: NTVarInfo carries no information about the type of the actual metadata -# i.e. the elements of the NamedTuple. It could be Metadata or it could be -# VarNamedVector. -# Resolving this ambiguity would likely require us to replace NamedTuple with -# something which carried both its keys as well as its values' types as type -# parameters. 
-const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} - -function Base.:(==)(vi1::VarInfo, vi2::VarInfo) - return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) -end -# NOTE: This is kind of weird, but it effectively preserves the "old" -# behavior where we're allowed to call `link!` on the same `VarInfo` -# multiple times. -transformation(::VarInfo) = DynamicTransformation() - -# No-op if we're already working with a `VarNamedVector`. -metadata_to_varnamedvector(vnv::VarNamedVector) = vnv -function metadata_to_varnamedvector(md::Metadata) - idcs = copy(md.idcs) - vns = copy(md.vns) - ranges = copy(md.ranges) - vals = copy(md.vals) - is_trans = map(Base.Fix1(is_transformed, md), md.vns) - transforms = map(md.dists, is_trans) do dist, trans - if trans - return from_linked_vec_transform(dist) - else - return from_vec_transform(dist) - end + function VarInfo{Linked}( + values::T, accs::Accs + ) where {Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} + return new{Linked,T,Accs}(values, accs) end - - return VarNamedVector( - OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms, is_trans - ) end -function has_varnamedvector(vi::VarInfo) - return vi.metadata isa VarNamedVector || - (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) -end +# TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was +# using a Vector as the internal storage in all cases. We should revisit this, and allow +# values to be stored "raw", since VarNamedTuple supports it. -######################## -# VarInfo constructors # -######################## +# TODO(mhauru) Related to the above, I think we should reconsider whether we should store +# transformations at all. We rarely use them, since they may be dynamic in a model. +# tilde_assume!! 
rather gets the transformation from the current distribution encountered +# during model execution. However, this would change the interface quite a lot, so I want to +# finish implementing VarInfo using VNT (mostly) respecting the old interface first. +# TODO(mhauru) We are considering removing `transform` completely, and forcing people to use +# ValuesAsInModelAcc instead. If that is done, we may want to move the Linked type parameter +# to just be a bool field. It's currently a type parameter to make the type of `transform` +# easier to type infer, but if `transform` no longer exists, it might start to cause +# unnecessary type inconcreteness in the elements of PartialArray. """ - untyped_varinfo([rng, ]model[, init_strategy]) + TransformedValue{Linked,ValType,TransformType,SizeType} -Construct a VarInfo object for the given `model`, which has just a single -`Metadata` as its metadata field. +A struct for storing a variable's value in its internal (vectorised) form. -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. +The type parameter `Linked` is a `Bool` indicating whether the variable is linked, i.e. +whether the transformation maps all real vectors to valid values. +# Fields +$(TYPEDFIELDS) """ -function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) -end -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) +struct TransformedValue{Linked,ValType,TransformType,SizeType} + "The internal (vectorised) value." + val::ValType + """The transformation from internal (vectorised) to actual value. 
In other words, the + actual value of the variable being stored is `transform(val)`.""" + transform::TransformType + """The size of the actual value after transformation. This is needed when a + `TransformedValue` is stored as a block in an array.""" + size::SizeType + + function TransformedValue{Linked}( + val::ValType, transform::TransformType, size::SizeType + ) where {Linked,ValType,TransformType,SizeType} + return new{Linked,ValType,TransformType,SizeType}(val, transform, size) + end end -""" - typed_varinfo(vi::UntypedVarInfo) +is_transformed(::TransformedValue{Linked}) where {Linked} = Linked -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function typed_varinfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New is_transformed - sym_is_transformed = meta.is_transformed[inds] - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) 
+VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_is_transformed - ), - ) +VarInfo() = VarInfo{false}(VarNamedTuple(), default_accumulators()) + +function VarInfo(values::Union{NamedTuple,AbstractDict}) + vi = VarInfo() + for (k, v) in pairs(values) + vn = k isa Symbol ? VarName{k}() : k + vi = setindex!!(vi, v, vn) end - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, copy(vi.accs)) -end -function typed_varinfo(vi::NTVarInfo) - # This function preserves the behaviour of typed_varinfo(vi) where vi is - # already a NTVarInfo - has_varnamedvector(vi) && error( - "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", - ) return vi end -""" - typed_varinfo([rng, ]model[, init_strategy]) - -Return a VarInfo object for the given `model`, which has a NamedTuple of -`Metadata` structs as its metadata field. -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. -""" -function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) -end -function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return typed_varinfo(Random.default_rng(), model, init_strategy) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end -""" - untyped_vector_varinfo([rng, ]model[, init_strategy]) - -Return a VarInfo object for the given `model`, which has just a single -`VarNamedVector` as its metadata field. 
-
-# Arguments
-- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
-- `model::Model`: The model for which to create the varinfo object
-- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
-"""
-function untyped_vector_varinfo(vi::UntypedVarInfo)
-    md = metadata_to_varnamedvector(vi.metadata)
-    return VarInfo(md, copy(vi.accs))
-end
-function untyped_vector_varinfo(
+function VarInfo(
     rng::Random.AbstractRNG,
     model::Model,
     init_strategy::AbstractInitStrategy=InitFromPrior(),
 )
-    return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy))
-end
-function untyped_vector_varinfo(
-    model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
-)
-    return untyped_vector_varinfo(Random.default_rng(), model, init_strategy)
+    return last(init!!(rng, model, VarInfo(), init_strategy))
 end
-"""
-    typed_vector_varinfo([rng, ]model[, init_strategy])
-
-Return a VarInfo object for the given `model`, which has a NamedTuple of
-`VarNamedVector`s as its metadata field.
-
-# Arguments
-- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
-- `model::Model`: The model for which to create the varinfo object
-- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
-"""
-function typed_vector_varinfo(vi::NTVarInfo)
-    md = map(metadata_to_varnamedvector, vi.metadata)
-    return VarInfo(md, copy(vi.accs))
-end
-function typed_vector_varinfo(vi::UntypedVectorVarInfo)
-    new_metas = group_by_symbol(vi.metadata)
-    nt = NamedTuple(new_metas)
-    return VarInfo(nt, copy(vi.accs))
-end
-function typed_vector_varinfo(
-    rng::Random.AbstractRNG,
-    model::Model,
-    init_strategy::AbstractInitStrategy=InitFromPrior(),
-)
-    return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy))
-end
-function typed_vector_varinfo(
-    model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
-)
-    return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
+getaccs(vi::VarInfo) = vi.accs
+function setaccs!!(vi::VarInfo{Linked}, accs::AccumulatorTuple) where {Linked}
+    return VarInfo{Linked}(vi.values, accs)
 end
-"""
-    vector_length(varinfo::VarInfo)
-
-Return the length of the vector representation of `varinfo`.
-"""
-vector_length(varinfo::VarInfo) = length(varinfo.metadata)
-vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata)
-vector_length(md::Metadata) = sum(length, md.ranges)
+transformation(::VarInfo) = DynamicTransformation()
 
-function unflatten(vi::VarInfo, x::AbstractVector)
-    md = unflatten_metadata(vi.metadata, x)
-    return VarInfo(md, vi.accs)
+function Base.copy(vi::VarInfo{Linked}) where {Linked}
+    return VarInfo{Linked}(copy(vi.values), copy(getaccs(vi)))
 end
+Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn)
+Base.length(vi::VarInfo) = length(vi.values)
+Base.keys(vi::VarInfo) = keys(vi.values)
+Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[])
 
-# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in
-# utils.jl.
-@generated function unflatten_metadata(
-    metadata::NamedTuple{names}, x::AbstractVector
-) where {names}
-    exprs = []
-    offset = :(0)
-    for f in names
-        mdf = :(metadata.$f)
-        len = :(sum(length, $mdf.ranges))
-        push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)])))
-        offset = :($offset + $len)
-    end
-    length(exprs) == 0 && return :(NamedTuple())
-    return :($(exprs...),)
+function Base.getindex(vi::VarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.transform(tv.val)
 end
-function unflatten_metadata(md::Metadata, x::AbstractVector)
-    return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.is_transformed)
+function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName})
+    return [getindex(vi, vn) for vn in vns]
 end
-unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
-
-####
-#### Internal functions
-####
-
-"""
-    Metadata()
-
-Construct an empty type unstable instance of `Metadata`.
-"""
-function Metadata()
-    vals = Vector{Real}()
-    is_transformed = BitVector()
-
-    return Metadata(
-        Dict{VarName,Int}(),
-        Vector{VarName}(),
-        Vector{UnitRange{Int}}(),
-        vals,
-        Vector{Distribution}(),
-        is_transformed,
-    )
-end
+Base.isempty(vi::VarInfo) = isempty(vi.values)
+Base.empty(vi::VarInfo) = VarInfo{false}(empty(vi.values), map(reset, vi.accs))
+BangBang.empty!!(vi::VarInfo) = VarInfo{false}(empty!!(vi.values), map(reset, vi.accs))
 
 """
-    empty!(meta::Metadata)
+    setindex_internal!!(vi::VarInfo, val, vn::VarName)
 
-Empty the fields of `meta`.
+Set the internal (vectorised) value of variable `vn` in `vi` to `val`.
 
-This is useful when using a sampling algorithm that assumes an empty `meta`, e.g. `SMC`.
+This does not change the transformation or linked status of the variable.
 """
-function empty!(meta::Metadata)
-    empty!(meta.idcs)
-    empty!(meta.vns)
-    empty!(meta.ranges)
-    empty!(meta.vals)
-    empty!(meta.dists)
-    empty!(meta.is_transformed)
-    return meta
+function setindex_internal!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked}
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue{is_transformed(old_tv)}(val, old_tv.transform, old_tv.size)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VarInfo{Linked}(new_values, vi.accs)
 end
 
-# Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so this is well-defined.
-if VERSION < v"1.1"
-    _tail(nt::NamedTuple{names}) where {names} = NamedTuple{Base.tail(names)}(nt)
-else
-    _tail(nt::NamedTuple) = Base.tail(nt)
-end
+# TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However,
+# we need `dist` to determine the linking transformation (or even just the vectorisation
+# transformation in the case of ProductNamedTupleDistributions), and if we leave the work
+# of doing the transformation to the caller (tilde_assume!!), it'll be done even when e.g.
+# using OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be
+# removed once VAIMAcc is the only way to get values out of an evaluation.
+"""
+    setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName)
 
-function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
-    metadata = subset(varinfo.metadata, vns)
-    return VarInfo(metadata, map(copy, getaccs(varinfo)))
-end
+Set the value of `vn` in `vi` to `val`, applying a transformation based on `dist`.
 
-function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
-    vns_syms = Set(unique(map(getsym, vns)))
-    syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
-    metadatas = map(syms) do sym
-        subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
-    end
-    return NamedTuple{syms}(metadatas)
-end
+`val` is taken to be the actual value of the variable, and is transformed into the internal
+(vectorised) representation using a transformation based on `dist`. If the variable is
+currently linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are
+linked, the linking transformation is used; otherwise, the standard vector transformation is
+used.
 
-# The above method is type unstable since we don't know which symbols are in `vns`.
-# In the below special case, when all `vns` have the same symbol, we can write a type stable
-# version.
-
-@generated function subset(
-    metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}}
-) where {names,sym}
-    return if (sym in names)
-        # TODO(mhauru) Note that this could still generate an empty metadata object if none
-        # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
-        # emptiness would make this type unstable again.
-        :((; $sym=subset(metadata.$sym, vns)))
+Returns the modified `vi` together with the log absolute determinant of the Jacobian of the
+transformation applied.
+"""
+function setindex_with_dist!!(
+    vi::VarInfo{Linked}, val, dist::Distribution, vn::VarName
+) where {Linked}
+    link = if Linked === nothing
+        haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi)
     else
-        :(NamedTuple{}())
+        Linked
     end
-end
-
-function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
-    # TODO: Should we error if `vns` contains a variable that is not in `metadata`?
-    # Find all the vns in metadata that are subsumed by one of the given vns.
-    vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns)
-    indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
-    indices = if isempty(vns)
-        Dict{VarName,Int}()
+    transform = if link
+        from_linked_vec_transform(dist)
     else
-        Dict(vn => i for (i, vn) in enumerate(vns))
-    end
-    # Construct new `vals` and `ranges`.
-    vals_original = metadata.vals
-    ranges_original = metadata.ranges
-    # Allocate the new `vals`. and `ranges`.
-    vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
-    ranges = similar(ranges_original, length(vns))
-    # The new range `r` for `vns[i]` is offset by `offset` and
-    # has the same length as the original range `r_original`.
-    # The new `indices` (from above) ensures ordering according to `vns`.
-    # NOTE: This means that the order of the variables in `vns` defines the order
-    # in the resulting `varinfo`! This can have performance implications, e.g.
-    # if in the model we have something like
-    #
-    # for i = 1:N
-    #     x[i] ~ Normal()
-    # end
-    #
-    # and we then we do
-    #
-    # subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))])
-    #
-    # the resulting `varinfo` will have `vals` ordered differently from the
-    # original `varinfo`, which can have performance implications.
-    offset = 0
-    for (idx, idx_original) in enumerate(indices_for_vns)
-        r_original = ranges_original[idx_original]
-        r = (offset + 1):(offset + length(r_original))
-        vals[r] = vals_original[r_original]
-        ranges[idx] = r
-        offset = r[end]
-    end
-
-    dists = metadata.dists[indices_for_vns]
-    is_transformed = metadata.is_transformed[indices_for_vns]
-    return Metadata(indices, vns, ranges, vals, dists, is_transformed)
-end
-
-function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
-    return _merge(varinfo_left, varinfo_right)
-end
-
-function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
-    metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
-    accs = map(copy, getaccs(varinfo_right))
-    return VarInfo(metadata, accs)
-end
-
-function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
-    return merge(vnv_left, vnv_right)
-end
-
-@generated function merge_metadata(
-    metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
-) where {names_left,names_right}
-    names = Expr(:tuple)
-    vals = Expr(:tuple)
-    # Loop over `names_left` first because we want to preserve the order of the variables.
-    for sym in names_left
-        push!(names.args, QuoteNode(sym))
-        if sym in names_right
-            push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym)))
-        else
-            push!(vals.args, :(metadata_left.$sym))
-        end
-    end
-    # Loop over remaining variables in `names_right`.
-    names_right_only = filter(∉(names_left), names_right)
-    for sym in names_right_only
-        push!(names.args, QuoteNode(sym))
-        push!(vals.args, :(metadata_right.$sym))
-    end
-
-    return :(NamedTuple{$names}($vals))
-end
-
-function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
-    # Extract the varnames.
-    vns_left = metadata_left.vns
-    vns_right = metadata_right.vns
-    vns_both = union(vns_left, vns_right)
-
-    # Determine `eltype` of `vals`.
-    T_left = eltype(metadata_left.vals)
-    T_right = eltype(metadata_right.vals)
-    T = promote_type(T_left, T_right)
-    # TODO: Is this necessary?
-    if !(T <: Real)
-        T = Real
+        from_vec_transform(dist)
     end
-
-    # Determine `eltype` of `dists`.
-    D_left = eltype(metadata_left.dists)
-    D_right = eltype(metadata_right.dists)
-    D = promote_type(D_left, D_right)
-    # TODO: Is this necessary?
-    if !(D <: Distribution)
-        D = Distribution
-    end
-
-    # Initialize required fields for `metadata`.
-    vns = eltype(vns_both)[]
-    idcs = Dict{eltype(vns_both),Int}()
-    ranges = Vector{UnitRange{Int}}()
-    vals = T[]
-    dists = D[]
-    transformed = BitVector()
-
-    # Range offset.
-    offset = 0
-
-    for (idx, vn) in enumerate(vns_both)
-        idcs[vn] = idx
-        push!(vns, vn)
-        metadata_for_vn = vn in vns_right ? metadata_right : metadata_left
-
-        val = getindex_internal(metadata_for_vn, vn)
-        append!(vals, val)
-        r = (offset + 1):(offset + length(val))
-        push!(ranges, r)
-        offset = r[end]
-        dist = getdist(metadata_for_vn, vn)
-        push!(dists, dist)
-        push!(transformed, is_transformed(metadata_for_vn, vn))
-    end
-
-    return Metadata(idcs, vns, ranges, vals, dists, transformed)
+    transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val)
+    # All values for which `size` is not defined are assumed to be scalars.
+    val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : ()
+    tv = TransformedValue{link}(transformed_val, transform, val_size)
+    new_linked = Linked == link ? Linked : nothing
+    vi = VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs)
+    return vi, logjac
 end
 
-const VarView = Union{Int,UnitRange,Vector{Int}}
-
-"""
-    setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})
-
-Set the value of `vi.vals[vview]` to `val`.
-"""
-setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val
-
-"""
-    getmetadata(vi::VarInfo, vn::VarName)
-
-Return the metadata in `vi` that belongs to `vn`.
-"""
-getmetadata(vi::VarInfo, vn::VarName) = vi.metadata
-getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn))
-
-"""
-    getidx(vi::VarInfo, vn::VarName)
-
-Return the index of `vn` in the metadata of `vi` corresponding to `vn`.
-"""
-getidx(vi::VarInfo, vn::VarName) = getidx(getmetadata(vi, vn), vn)
-getidx(md::Metadata, vn::VarName) = md.idcs[vn]
-
-"""
-    getrange(vi::VarInfo, vn::VarName)
-
-Return the index range of `vn` in the metadata of `vi`.
-"""
-getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
-getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)]
-
-"""
-    setrange!(vi::VarInfo, vn::VarName, range)
-
-Set the index range of `vn` in the metadata of `vi` to `range`.
+# TODO(mhauru) The below is somewhat unsafe or incomplete: For instance, from_vec_transform
+# isn't defined for NamedTuples. However, this is needed in some places where values in a
+# VarInfo are set outside the context of a `tilde_assume!!` and no distribution is
+# available. Hopefully we'll get rid of this eventually.
 """
-setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range)
-setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range
+    setindex!!(vi::VarInfo, val, vn::VarName)
 
-"""
-    getdist(vi::VarInfo, vn::VarName)
+Set the value of `vn` in `vi` to `val`.
 
-Return the distribution from which `vn` was sampled in `vi`.
+The transformation for `vn` is reset to be the standard vector transformation for values of
+the type of `val`, and the linking status is set to false.
 """
-getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
-getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
-# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone.
-function getdist(::VarNamedVector, ::VarName)
-    throw(ErrorException("getdist does not exist for VarNamedVector"))
-end
-
-getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn)
-# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though,
-# since then we might be returning a `SubArray` rather than an `Array`, which is typically
-# what a bijector would result in, even if the input is a view (`SubArray`).
-# TODO(torfjelde): An alternative is to implement `view` directly instead.
-getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))
-function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
-    return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
-end
-getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
-# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
-# See for example https://github.com/JuliaLang/julia/pull/46381.
-function getindex_internal(vi::NTVarInfo, ::Colon)
-    return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
-end
-function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon)
-    return float(Real)[]
-end
-function getindex_internal(md::Metadata, ::Colon)
-    return mapreduce(
-        Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
-    )
+function BangBang.setindex!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked}
+    new_linked = Linked == false ? false : nothing
+    transform = from_vec_transform(val)
+    transformed_val = inverse(transform)(val)
+    tv = TransformedValue{false}(transformed_val, transform, size(val))
+    return VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs)
 end
 
 """
-    setval!(vi::VarInfo, val, vn::VarName)
+    set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName)
 
-Set the value(s) of `vn` in the metadata of `vi` to `val`.
+Set the linked status of variable `vn` in `vi` to `linked`.
 
-The values may or may not be transformed to Euclidean space.
+This does not change the value or transformation of the variable.
 """
-setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
-function setval!(md::Metadata, val::AbstractVector, vn::VarName)
-    return md.vals[getrange(md, vn)] = val
-end
-function setval!(md::Metadata, val, vn::VarName)
-    return md.vals[getrange(md, vn)] = tovec(val)
-end
-
-function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName)
-    md = set_transformed!!(getmetadata(vi, vn), val, vn)
-    return Accessors.@set vi.metadata[getsym(vn)] = md
-end
-
-function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName)
-    md = set_transformed!!(getmetadata(vi, vn), val, vn)
-    return VarInfo(md, vi.accs)
+function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where {Linked}
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    new_linked = Linked == linked ? Linked : nothing
+    return VarInfo{new_linked}(new_values, vi.accs)
 end
 
-function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName)
-    metadata.is_transformed[getidx(metadata, vn)] = val
-    return metadata
+# VarInfo does not care whether the transformation was Static or Dynamic, it just tracks
+# whether one was applied at all.
+function set_transformed!!(vi::VarInfo, ::AbstractTransformation, vn::VarName)
+    return set_transformed!!(vi, true, vn)
 end
-function set_transformed!!(vi::VarInfo, val::Bool)
-    for vn in keys(vi)
-        vi = set_transformed!!(vi, val, vn)
-    end
+set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true)
 
-    return vi
+function set_transformed!!(vi::VarInfo, ::NoTransformation, vn::VarName)
+    return set_transformed!!(vi, false, vn)
 end
 set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false)
 
-# HACK: This is necessary to make something like `link!!(transformation, vi, model)`
-# work properly, which will transform the variables according to `transformation`
-# and then call `set_transformed!!(vi, transformation)`. An alternative would be to add
-# the `transformation` to the `VarInfo` object, but at the moment doesn't seem
-# worth it as `VarInfo` has its own way of handling transformations.
-set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true)
-
-"""
-    syms(vi::VarInfo)
-
-Returns a tuple of the unique symbols of random variables in `vi`.
-"""
-syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
-syms(vi::NTVarInfo) = keys(vi.metadata)
-
-_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs)
-_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata)
-
-@generated function _getidcs(metadata::NamedTuple{names}) where {names}
-    exprs = []
-    for f in names
-        push!(exprs, :($f = findinds(metadata.$f)))
+function set_transformed!!(vi::VarInfo, linked::Bool)
+    new_values = map_values!!(vi.values) do tv
+        TransformedValue{linked}(tv.val, tv.transform, tv.size)
     end
-    length(exprs) == 0 && return :(NamedTuple())
-    return :($(exprs...),)
+    return VarInfo{linked}(new_values, vi.accs)
 end
 
-@inline findinds(f_meta::Metadata) = eachindex(f_meta.vns)
-findinds(vnv::VarNamedVector) = 1:length(vnv.varnames)
-
 """
-    all_varnames_grouped_by_symbol(vi::NTVarInfo)
+    getindex_internal(vi::VarInfo, vn::VarName)
 
-Return a `NamedTuple` of the variables in `vi` grouped by symbol.
+Get the internal (vectorised) value of variable `vn` in `vi`.
 """
-all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata)
-
-@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names}
-    expr = Expr(:tuple)
-    for f in names
-        push!(expr.args, :($f = keys(md.$f)))
-    end
-    return expr
-end
+getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val
+# TODO(mhauru) The below should be removed together with unflatten!!.
+getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector)
 
-####
-#### APIs for typed and untyped VarInfo
-####
-
-function BangBang.empty!!(vi::VarInfo)
-    _empty!(vi.metadata)
-    vi = resetaccs!!(vi)
-    return vi
-end
-
-_empty!(metadata) = empty!(metadata)
-@generated function _empty!(metadata::NamedTuple{names}) where {names}
-    expr = Expr(:block)
-    for f in names
-        push!(expr.args, :(empty!(metadata.$f)))
-    end
-    return expr
-end
-
-# `keys`
-Base.keys(md::Metadata) = md.vns
-Base.keys(vi::VarInfo) = Base.keys(vi.metadata)
-
-# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly
-# on other methods in the codebase which requires `Vector{<:VarName}`.
-Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[]
-@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names}
-    expr = Expr(:call)
-    push!(expr.args, :vcat)
-
-    for n in names
-        push!(expr.args, :(keys(vi.metadata.$n)))
+function is_transformed(vi::VarInfo{Linked}, vn::VarName) where {Linked}
+    return if Linked === nothing
+        is_transformed(getindex(vi.values, vn))
+    else
+        Linked
     end
-
-    return expr
-end
-
-is_transformed(vi::VarInfo, vn::VarName) = is_transformed(getmetadata(vi, vn), vn)
-is_transformed(md::Metadata, vn::VarName) = md.is_transformed[getidx(md, vn)]
-
-getaccs(vi::VarInfo) = vi.accs
-setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
-
-# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
-isempty(vi::VarInfo) = _isempty(vi.metadata)
-_isempty(metadata::Metadata) = isempty(metadata.idcs)
-_isempty(vnv::VarNamedVector) = isempty(vnv)
-@generated function _isempty(metadata::NamedTuple{names}) where {names}
-    return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...)
-end
-
-function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
-    vns = all_varnames_grouped_by_symbol(vi)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _link(model, vi, vns)
-    vi = _link!!(vi, vns)
-    return vi
-end
-
-function link!!(::DynamicTransformation, vi::VarInfo, model::Model)
-    vns = keys(vi)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _link(model, vi, vns)
-    vi = _link!!(vi, vns)
-    return vi
 end
 
-function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`,
-    # and so we need to specialize to avoid this.
-    return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model)
+function from_internal_transform(::VarInfo, ::VarName, dist::Distribution)
+    return from_vec_transform(dist)
 end
 
-function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _link(model, vi, vns)
-    vi = _link!!(vi, vns)
-    return vi
+function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution)
+    return from_linked_vec_transform(dist)
 end
 
-function link!!(
-    t::DynamicTransformation,
-    vi::ThreadSafeVarInfo{<:VarInfo},
-    vns::VarNameTuple,
-    model::Model,
-)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`,
-    # and so we need to specialize to avoid this.
-    return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
+function from_internal_transform(vi::VarInfo, vn::VarName)
+    return getindex(vi.values, vn).transform
 end
 
-function _link!!(vi::UntypedVarInfo, vns)
-    # TODO: Change to a lazy iterator over `vns`
-    if ~is_transformed(vi, vns[1])
-        for vn in vns
-            f = internal_to_linked_internal_transform(vi, vn)
-            vi = _inner_transform!(vi, vn, f)
-            vi = set_transformed!!(vi, true, vn)
-        end
-        return vi
-    else
-        @warn("[DynamicPPL] attempt to link a linked vi")
+function from_linked_internal_transform(vi::VarInfo, vn::VarName)
+    if !is_transformed(vi, vn)
+        error("Variable $vn is not linked; cannot get linked transformation.")
     end
-end
-
-# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a
-# NamedTuple that matches the structure of the NTVarInfo.
-function _link!!(vi::NTVarInfo, vns::VarNameTuple)
-    return _link!!(vi, group_varnames_by_symbol(vns))
-end
-
-function _link!!(vi::NTVarInfo, vns::NamedTuple)
-    return _link!!(vi.metadata, vi, vns)
+    return getindex(vi.values, vn).transform
 end
 
 """
-    filter_subsumed(filter_vns, filtered_vns)
+    _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link isa Bool}
 
-Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`.
-"""
-function filter_subsumed(filter_vns, filtered_vns)
-    return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)
-end
+The internal function that implements both link!! and invlink!!.
 
-@generated function _link!!(
-    ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names}
-) where {metadata_names,vns_names}
-    expr = Expr(:block)
-    for f in metadata_names
-        if !(f in vns_names)
-            continue
+The last argument controls whether linking (true) or invlinking (false) is performed. If
+`vns` is `nothing`, all variables in `vi` are transformed; otherwise, only the variables
+in `vns` are transformed. Existing variables already in the desired state are left
+unchanged.
+"""
+function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link}
+    @assert link isa Bool
+    # Note that extract_priors causes a model execution. In the past with the Metadata-based
+    # VarInfo we rather derived the transformations from the distributions stored in the
+    # VarInfo itself. However, that is not fail-safe with dynamic models, and would require
+    # storing the distributions in TransformedValue (which we could start doing). Instead we
+    # use extract_priors to get the current, correct transformations. This logic is very
+    # similar to what DynamicTransformation used to do, and we might replace this with a
+    # context that transforms each variable in turn during the execution.
+    dists = extract_priors(model, vi)
+    cumulative_logjac = zero(LogProbType)
+    new_values = map_pairs!!(vi.values) do pair
+        vn, tv = pair
+        if vns !== nothing && !any(x -> subsumes(x, vn), vns)
+            # Not one of the target variables.
+            return tv
         end
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(varnames.$f, f_vns)
-                if !isempty(f_vns)
-                    if !is_transformed(vi, f_vns[1])
-                        # Iterate over all `f_vns` and transform
-                        for vn in f_vns
-                            f = internal_to_linked_internal_transform(vi, vn)
-                            vi = _inner_transform!(vi, vn, f)
-                            vi = set_transformed!!(vi, true, vn)
-                        end
-                    else
-                        @warn("[DynamicPPL] attempt to link a linked vi")
-                    end
-                end
-            end,
-        )
-    end
-    push!(expr.args, :(return vi))
-    return expr
-end
-
-function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
-    vns = all_varnames_grouped_by_symbol(vi)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    vi = _invlink!!(vi, vns)
-    return vi
-end
-
-function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
-    vns = keys(vi)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    vi = _invlink!!(vi, vns)
-    return vi
-end
-
-function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`,
-    # and so we need to specialize to avoid this.
-    return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model)
-end
-
-function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
-    # If we're working with a `VarNamedVector`, we always use immutable.
-    has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    vi = _invlink!!(vi, vns)
-    return vi
-end
-
-function invlink!!(
-    ::DynamicTransformation,
-    vi::ThreadSafeVarInfo{<:VarInfo},
-    vns::VarNameTuple,
-    model::Model,
-)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
-    # we need to specialize to avoid this.
-    return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model)
-end
-
-function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
-    # Because `VarInfo` does not contain any information about what the transformation
-    # other than whether or not it has actually been transformed, the best we can do
-    # is just assume that `default_transformation` is the correct one if
-    # `is_transformed(vi)`.
-    t = is_transformed(vi) ? default_transformation(model, vi) : NoTransformation()
-    return maybe_invlink_before_eval!!(t, vi, model)
-end
-
-function _invlink!!(vi::UntypedVarInfo, vns)
-    if is_transformed(vi, vns[1])
-        for vn in vns
-            f = linked_internal_to_internal_transform(vi, vn)
-            vi = _inner_transform!(vi, vn, f)
-            vi = set_transformed!!(vi, false, vn)
+        if is_transformed(tv) == link
+            # Already in the desired state.
+            return tv
         end
-        return vi
-    else
-        @warn("[DynamicPPL] attempt to invlink an invlinked vi")
-    end
-end
-
-# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a
-# NamedTuple that matches the structure of the NTVarInfo.
-function _invlink!!(vi::NTVarInfo, vns::VarNameTuple)
-    return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns))
-end
-
-function _invlink!!(vi::NTVarInfo, vns::NamedTuple)
-    return _invlink!!(vi.metadata, vi, vns)
-end
-
-@generated function _invlink!!(
-    ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
-) where {metadata_names,vns_names}
-    expr = Expr(:block)
-    for f in metadata_names
-        if !(f in vns_names)
-            continue
+        dist = getindex(dists, vn)
+        vec_transform = from_vec_transform(dist)
+        link_transform = from_linked_vec_transform(dist)
+        current_transform, new_transform = if link
+            (vec_transform, link_transform)
+        else
+            (link_transform, vec_transform)
         end
-
-        push!(
-            expr.args,
-            quote
-                f_vns = vi.metadata.$f.vns
-                f_vns = filter_subsumed(vns.$f, f_vns)
-                if is_transformed(vi, f_vns[1])
-                    # Iterate over all `f_vns` and transform
-                    for vn in f_vns
-                        f = linked_internal_to_internal_transform(vi, vn)
-                        vi = _inner_transform!(vi, vn, f)
-                        vi = set_transformed!!(vi, false, vn)
-                    end
-                else
-                    @warn("[DynamicPPL] attempt to invlink an invlinked vi")
-                end
-            end,
+        val_untransformed, logjac1 = with_logabsdet_jacobian(current_transform, tv.val)
+        val_new, logjac2 = with_logabsdet_jacobian(
+            inverse(new_transform), val_untransformed
         )
+        # !is_transformed(tv) is the same as `link`, but might be easier for type inference.
+        new_tv = TransformedValue{!is_transformed(tv)}(val_new, new_transform, tv.size)
+        cumulative_logjac += logjac1 + logjac2
+        return new_tv
     end
-    push!(expr.args, :(return vi))
-    return expr
-end
-
-function _inner_transform!(vi::VarInfo, vn::VarName, f)
-    return _inner_transform!(getmetadata(vi, vn), vi, vn, f)
-end
-
-function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
-    # TODO: Use inplace versions to avoid allocations
-    yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn))
-    # Determine the new range.
-    start = first(getrange(md, vn))
-    # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
-    setrange!(md, vn, start:(start + length(yvec) - 1))
-    # Set the new value.
-    setval!(md, yvec, vn)
+    vi_linked = if vns === nothing
+        link
+    else
+        nothing
+    end
+    vi = VarInfo{vi_linked}(new_values, vi.accs)
     if hasacc(vi, Val(:LogJacobian))
-        vi = acclogjac!!(vi, logjac)
+        vi = acclogjac!!(vi, cumulative_logjac)
     end
     return vi
 end
 
-function link(::DynamicTransformation, vi::NTVarInfo, model::Model)
-    return _link(model, vi, all_varnames_grouped_by_symbol(vi))
-end
-
-function link(::DynamicTransformation, varinfo::VarInfo, model::Model)
-    return _link(model, varinfo, keys(varinfo))
-end
-
-function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
-    # we need to specialize to avoid this.
-    return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model)
+function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model)
+    return _link_or_invlink!!(vi, vns, model, Val(true))
 end
-
-function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model)
-    return _link(model, varinfo, vns)
-end
-
-function link(
-    ::DynamicTransformation,
-    varinfo::ThreadSafeVarInfo{<:VarInfo},
-    vns::VarNameTuple,
-    model::Model,
-)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`,
-    # and so we need to specialize to avoid this.
-    return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model)
-end
-
-function _link(model::Model, varinfo::VarInfo, vns)
-    varinfo = deepcopy(varinfo)
-    md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
-    new_varinfo = VarInfo(md, varinfo.accs)
-    if hasacc(new_varinfo, Val(:LogJacobian))
-        new_varinfo = acclogjac!!(new_varinfo, logjac)
-    end
-    return new_varinfo
-end
-
-# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a
-# NamedTuple that matches the structure of the NTVarInfo.
-function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple)
-    return _link(model, varinfo, group_varnames_by_symbol(vns))
-end
-
-function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
-    varinfo = deepcopy(varinfo)
-    md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
-    new_varinfo = VarInfo(md, varinfo.accs)
-    if hasacc(new_varinfo, Val(:LogJacobian))
-        new_varinfo = acclogjac!!(new_varinfo, logjac)
-    end
-    return new_varinfo
+function link!!(::DynamicTransformation, vi::VarInfo, model::Model)
+    return _link_or_invlink!!(vi, nothing, model, Val(true))
 end
-
-@generated function _link_metadata!(
-    model::Model,
-    varinfo::VarInfo,
-    metadata::NamedTuple{metadata_names},
-    vns::NamedTuple{vns_names},
-) where {metadata_names,vns_names}
-    expr = quote
-        cumulative_logjac = zero(LogProbType)
-    end
-    mds = Expr(:tuple)
-    for f in metadata_names
-        if f in vns_names
-            push!(
-                mds.args,
-                quote
-                    begin
-                        md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f)
-                        cumulative_logjac += logjac
-                        md
-                    end
-                end,
-            )
-        else
-            push!(mds.args, :(metadata.$f))
-        end
-    end
-
-    push!(
-        expr.args,
-        quote
-            NamedTuple{$metadata_names}($mds), cumulative_logjac
-        end,
-    )
-    return expr
+function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model)
+    return _link_or_invlink!!(vi, vns, model, Val(false))
 end
-
-function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns)
-    vns = metadata.vns
-    cumulative_logjac = zero(LogProbType)
-
-    # Construct the new transformed values, and keep track of their lengths.
-    vals_new = map(vns) do vn
-        # Return early if we're already in unconstrained space.
-        # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
-        if is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns)
-            return metadata.vals[getrange(metadata, vn)]
-        end
-
-        # Transform to constrained space.
-        x = getindex_internal(metadata, vn)
-        dist = getdist(metadata, vn)
-        f = internal_to_linked_internal_transform(varinfo, vn, dist)
-        y, logjac = with_logabsdet_jacobian(f, x)
-        # Vectorize value.
-        yvec = tovec(y)
-        # Accumulate the log-abs-det jacobian correction.
-        cumulative_logjac += logjac
-        # Mark as transformed.
-        set_transformed!!(varinfo, true, vn)
-        # Return the vectorized transformed value.
-        return yvec
-    end
-
-    # Determine new ranges.
-    ranges_new = similar(metadata.ranges)
-    offset = 0
-    for (i, v) in enumerate(vals_new)
-        r_start, r_end = offset + 1, length(v) + offset
-        offset = r_end
-        ranges_new[i] = r_start:r_end
-    end
-
-    # Now we just create a new metadata with the new `vals` and `ranges`.
-    return Metadata(
-        metadata.idcs,
-        metadata.vns,
-        ranges_new,
-        reduce(vcat, vals_new),
-        metadata.dists,
-        metadata.is_transformed,
-    ),
-    cumulative_logjac
-end
-
-function _link_metadata!!(
-    model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns
-)
-    vns = target_vns === nothing ? keys(metadata) : target_vns
-    dists = extract_priors(model, varinfo)
-    cumulative_logjac = zero(LogProbType)
-    for vn in vns
-        # First transform from however the variable is stored in vnv to the model
-        # representation.
-        transform_to_orig = gettransform(metadata, vn)
-        val_old = getindex_internal(metadata, vn)
-        val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old)
-        # Then transform from the model representation to the linked representation.
-        transform_from_linked = from_linked_vec_transform(dists[vn])
-        transform_to_linked = inverse(transform_from_linked)
-        val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig)
-        # TODO(mhauru) We are calling a !! function but ignoring the return value.
-        # Fix this when attending to issue #653.
-        cumulative_logjac += logjac1 + logjac2
-        metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
-        set_transformed!(metadata, true, vn)
+function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
+    return _link_or_invlink!!(vi, nothing, model, Val(false))
+end
+
+function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model)
+    # TODO(mhauru) This assumes that the user has defined the bijector using the same
+    # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user
+    # interface.
+    b = inverse(t.bijector)
+    x = vi[:]
+    y, logjac = with_logabsdet_jacobian(b, x)
+    # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant
+    # that getindex(vi, vn) would apply the default link transform of the distribution. With
+    # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any
+    # link transform. Neither is correct; rather, the transform should be the inverse of b.
+    vi = unflatten!!(vi, y)
+    if hasacc(vi, Val(:LogJacobian))
+        vi = acclogjac!!(vi, logjac)
     end
-    # Linking can often change the sizes of variables, causing inactive elements. We don't
-    # want to keep them around, since typically linking is done once and then the VarInfo
-    # is evaluated multiple times. Hence we contiguify here.
-    metadata = contiguify!(metadata)
-    return metadata, cumulative_logjac
-end
-
-function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model)
-    return _invlink(model, vi, all_varnames_grouped_by_symbol(vi))
-end
-
-function invlink(::DynamicTransformation, vi::VarInfo, model::Model)
-    return _invlink(model, vi, keys(vi))
-end
-
-function invlink(
-    ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model
-)
-    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
+    # Linking can often change the sizes of variables, causing inactive elements. We don't
+    # want to keep them around, since typically linking is done once and then the VarInfo
+    # is evaluated multiple times. Hence we contiguify here.
+    metadata = contiguify!(metadata)
+    return metadata, cumulative_logjac
+end
+
+function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model)
+    return _invlink(model, vi, all_varnames_grouped_by_symbol(vi))
+end
+
+function invlink(::DynamicTransformation, vi::VarInfo, model::Model)
+    return _invlink(model, vi, keys(vi))
+end
+
+function invlink(
+    ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model
+)
+    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
- return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) -end - -function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) - return _invlink(model, varinfo, vns) + return set_transformed!!(vi, t) end -function invlink( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) -end +function invlink!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) + b = t.bijector + y = vi[:] + x, inv_logjac = with_logabsdet_jacobian(b, y) -function _invlink(model::Model, varinfo::VarInfo, vns) - varinfo = deepcopy(varinfo) - md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - new_varinfo = acclogjac!!(new_varinfo, inv_logjac) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten!!(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) end - return new_varinfo + return set_transformed!!(vi, NoTransformation()) end -# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the NTVarInfo. 
-function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) - return _invlink(model, varinfo, group_varnames_by_symbol(vns)) +# TODO(mhauru) I don't think this should return the internal values, but that's the current +# convention. +function values_as(vi::VarInfo, ::Type{Vector}) + return mapfoldl(pair -> tovec(pair.second.val), vcat, vi.values; init=Union{}[]) end -function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) - varinfo = deepcopy(varinfo) - md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - new_varinfo = acclogjac!!(new_varinfo, inv_logjac) - end - return new_varinfo +function values_as(vi::VarInfo, ::Type{T}) where {T<:AbstractDict} + return mapfoldl(identity, function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, vn) + end, vi.values; init=T()) end -@generated function _invlink_metadata!( - model::Model, - varinfo::VarInfo, - metadata::NamedTuple{metadata_names}, - vns::NamedTuple{vns_names}, -) where {metadata_names,vns_names} - expr = quote - cumulative_inv_logjac = zero(LogProbType) - end - mds = Expr(:tuple) - for f in metadata_names - if (f in vns_names) - push!( - mds.args, - quote - begin - md, inv_logjac = _invlink_metadata!!( - model, varinfo, metadata.$f, vns.$f - ) - cumulative_inv_logjac += inv_logjac - md - end - end, - ) - else - push!(mds.args, :(metadata.$f)) - end - end - - push!( - expr.args, - quote - (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) +# TODO(mhauru) I really dislike this sort of conversion to Symbols, but it's the current +# interface provided by rand(::Model). 
We should change that to return a VarNamedTuple +# instead, and then this method (and any other values_as methods for NamedTuple) could be +# removed. +function values_as(vi::VarInfo, ::Type{NamedTuple}) + return mapfoldl( + identity, + function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, Symbol(vn)) end, + vi.values; + init=NamedTuple(), ) - return expr -end - -function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) - vns = metadata.vns - cumulative_inv_logjac = zero(LogProbType) - - # Construct the new transformed values, and keep track of their lengths. - vals_new = map(vns) do vn - # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`. - # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if !is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) - return metadata.vals[getrange(metadata, vn)] - end - - # Transform to constrained space. - y = getindex_internal(varinfo, vn) - dist = getdist(varinfo, vn) - f = from_linked_internal_transform(varinfo, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - # Vectorize value. - xvec = tovec(x) - # Accumulate the log-abs-det jacobian correction. - cumulative_inv_logjac += inv_logjac - # Mark as no longer transformed. - set_transformed!!(varinfo, false, vn) - # Return the vectorized transformed value. - return xvec - end - - # Determine new ranges. - ranges_new = similar(metadata.ranges) - offset = 0 - for (i, v) in enumerate(vals_new) - r_start, r_end = offset + 1, length(v) + offset - offset = r_end - ranges_new[i] = r_start:r_end - end - - # Now we just create a new metadata with the new `vals` and `ranges`. 
- return Metadata( - metadata.idcs, - metadata.vns, - ranges_new, - reduce(vcat, vals_new), - metadata.dists, - metadata.is_transformed, - ), - cumulative_inv_logjac end -function _invlink_metadata!!( - ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns -) - vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_inv_logjac = zero(LogProbType) - for vn in vns - transform = gettransform(metadata, vn) - old_val = getindex_internal(metadata, vn) - new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) - # TODO(mhauru) We are calling a !! function but ignoring the return value. - cumulative_inv_logjac += inv_logjac - new_transform = from_vec_transform(new_val) - metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) - set_transformed!(metadata, false, vn) - end - # Linking can often change the sizes of variables, causing inactive elements. We don't - # want to keep them around, since typically linking is done once and then the VarInfo - # is evaluated multiple times. Hence we contiguify here. - metadata = contiguify!(metadata) - return metadata, cumulative_inv_logjac -end - -# TODO(mhauru) The treatment of the case when some variables are transformed and others are -# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` -# returned whether the first variable was linked. For NTVarInfo we did an OR over the first -# variables under each symbol. We now more consistently use OR, but I'm not convinced this -# is really the right thing to do. """ - is_transformed(vi::VarInfo) + VectorChunkIterator{T<:AbstractVector} -Check whether `vi` is in the transformed space. +A tiny struct for getting chunks of a vector sequentially. 
-Turing's Hamiltonian samplers use the `link` and `invlink` functions from -[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `is_transformed` checks if the number is in the constrained space or the real -space. - -If some but only some of the variables in `vi` are transformed, this function will return -`true`. This behavior will likely change in the future. +The only function provided is `get_next_chunk!`, which takes a length and returns +a view into the next chunk of that length, updating the internal index. """ -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) -end - -# The default getindex & setindex!() for get & set values -# NOTE: vi[vn] will always transform the variable to its original space and Julia type -function getindex(vi::VarInfo, vn::VarName) - return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) +mutable struct VectorChunkIterator{T<:AbstractVector} + vec::T + index::Int end -function getindex(vi::VarInfo, vn::VarName, dist::Distribution) - @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_internal(vi, vn) - return from_maybe_linked_internal(vi, vn, dist, val) +function get_next_chunk!(vci::VectorChunkIterator, len::Int) + i = vci.index + chunk = @view vci.vec[i:(i + len - 1)] + vci.index += len + return chunk end -function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals = map(vn -> getindex(vi, vn), vns) - - et = eltype(vals) - # This will catch type unstable cases, where vals has mixed types. 
- if !isconcretetype(et) - throw(ArgumentError("All variables must have the same type.")) - end - - if et <: Vector - all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) - if !all_of_equal_dimension - throw(ArgumentError("All variables must have the same dimension.")) +function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked} + # You may wonder, why have a whole struct for this, rather than just an index variable + # that the mapping function would close over. I wonder too. But for some reason type + # inference fails on such an index variable, turning it into a Core.Box. + vci = VectorChunkIterator(vec, 1) + new_values = map_values!!(vi.values) do tv + old_val = tv.val + if !(old_val isa AbstractVector) + error( + "Can't unflatten!! a VarInfo for which existing values are not vectors:" * + " Got value of type $(typeof(old_val)).", + ) end + len = length(old_val) + new_val = get_next_chunk!(vci, len) + return TransformedValue{is_transformed(tv)}(new_val, tv.transform, tv.size) end - - # TODO(mhauru) I'm not very pleased with the return type varying like this, even though - # this should be type stable. - vec_vals = reduce(vcat, vals) - if et <: Vector - # The individual variables are multivariate, and thus we return the values as a - # matrix. - return reshape(vec_vals, (:, length(vns))) - else - # The individual variables are univariate, and thus we return a vector of scalars. 
- return vec_vals - end -end - -function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) - @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -# Recursively builds a tuple of the `vals` of all the symbols -@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) - end - return expr + return VarInfo{Linked}(new_values, vi.accs) end -# TODO(mhauru) I think the below implementation of setindex! is a mistake. It should be -# called setindex_internal! since it directly writes to the `vals` field of the metadata. """ - setindex!(vi::VarInfo, val, vn::VarName) + subset(varinfo::VarInfo, vns) -Set the current value(s) of the random variable `vn` in `vi` to `val`. +Create a new `VarInfo` containing only the variables in `vns`. -The value(s) may or may not be transformed to Euclidean space. +`vns` can be almost any collection of `VarName`s, e.g. a `Set`, `Vector`, or `Tuple`. """ -setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - setindex!(vi, val, vn) - return vi -end - -@inline function findvns(vi, f_vns) - if length(f_vns) == 0 - throw("Unidentified error, please report this error in an issue.") - end - return map(vn -> vi[vn], f_vns) +function subset(varinfo::VarInfo{Linked}, vns) where {Linked} + new_values = subset(varinfo.values, vns) + return VarInfo{Linked}(new_values, map(copy, getaccs(varinfo))) end -Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) - -""" - haskey(vi::VarInfo, vn::VarName) - -Check whether `vn` has a value in `vi`. 
""" -Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::NTVarInfo, vn::VarName) - md_haskey = map(vi.metadata) do metadata - haskey(metadata, vn) - end - return any(md_haskey) -end - -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - lines = Tuple{String,Any}[ - ("VarNames", vi.metadata.vns), - ("Range", vi.metadata.ranges), - ("Vals", vi.metadata.vals), - ] - for accname in acckeys(vi) - push!(lines, (string(accname), getacc(vi, Val(accname)))) - end - push!(lines, ("is_transformed", vi.metadata.is_transformed)) - max_name_length = maximum(map(length ∘ first, lines)) - fmt = Printf.Format("%-$(max_name_length)s") - vi_str = ( - """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - """ * - prod( - map(lines) do (name, value) - """ - | $(Printf.format(fmt, name)) : $(value) - """ - end, - ) * - """ - \\======================================================================= - """ - ) - return print(io, vi_str) -end - -const _MAX_VARS_SHOWN = 4 + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) -function _show_varnames(io::IO, vi) - md = vi.metadata - vns = keys(md) +Merge two `VarInfo`s into a new `VarInfo` containing all variables from both. - vns_by_name = Dict{Symbol,Vector{VarName}}() - for vn in vns - group = get!(() -> Vector{VarName}(), vns_by_name, getsym(vn)) - push!(group, vn) - end - - L = length(vns_by_name) - if L == 0 - print(io, "0 variables, dimension 0") - else - (L == 1) ? 
print(io, "1 variable (") : print(io, L, " variables (") - join(io, Iterators.take(keys(vns_by_name), _MAX_VARS_SHOWN), ", ") - (L > _MAX_VARS_SHOWN) && print(io, ", ...") - print(io, "), dimension ", length(md.vals)) - end -end - -function Base.show(io::IO, vi::UntypedVarInfo) - print(io, "VarInfo (") - _show_varnames(io, vi) - print(io, "; accumulators: ") - # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation - # of vi anyway. However, technically `show(io, x)` should give full details of x and - # preferably output valid Julia code. - show(io, MIME"text/plain"(), getaccs(vi)) - return print(io, ")") -end +The accumulators are taken exclusively from `varinfo_right`. +If a variable exists in both `varinfo_left` and `varinfo_right`, the value from +`varinfo_right` is used. """ - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - md = push!!(getmetadata(vi, vn), vn, val, dist) - return VarInfo(md, vi.accs) -end - -function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - sym = getsym(vn) - meta = if ~haskey(vi.metadata, sym) - # The NamedTuple doesn't have an entry for this variable, let's add one. - _new_submetadata(vi, vn, val, dist) - else - push!!(getmetadata(vi, vn), vn, val, dist) - end - vi = Accessors.@set vi.metadata[sym] = meta - return vi -end - -""" - _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} - -Create a new sub-metadata for an NTVarInfo. 
The type is chosen by the types of existing -SubMetas. -""" -@generated function _new_submetadata( - vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist -) where {Names,SubMetas} - has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) - return if has_vnv - :(return _new_vnv_submetadata(vn, r, dist)) +function Base.merge( + varinfo_left::VarInfo{LinkedLeft}, varinfo_right::VarInfo{LinkedRight} +) where {LinkedLeft,LinkedRight} + new_values = merge(varinfo_left.values, varinfo_right.values) + new_accs = map(copy, getaccs(varinfo_right)) + new_linked = if LinkedLeft == LinkedRight + LinkedLeft else - :(return _new_metadata_submetadata(vn, r, dist)) - end -end - -_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) - -function _new_metadata_submetadata(vn, r, dist) - val = tovec(r) - return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) -end - -function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) - vn, val = pair - return push!(vi, vn, val, args...) -end - -# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't -# exist in the NTVarInfo already. We could implement it in the cases where it it does -# exist, but that feels a bit pointless. I think we should rather rely on `push!!`. - -function Base.push!(meta::Metadata, vn, r, dist) - val = tovec(r) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals) - n = length(val) - push!(meta.ranges, (l + 1):(l + n)) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.is_transformed, false) - return meta -end - -function BangBang.push!!(meta::Metadata, vn, r, dist) - push!(meta, vn, r, dist) - return meta -end - -function Base.delete!(vi::VarInfo, vn::VarName) - delete!(getmetadata(vi, vn), vn) - return vi -end - -####################################### -# Rand & replaying method for VarInfo # -####################################### - -# TODO: Maybe rename or something? 
-""" - _apply!(kernel!, vi::VarInfo, values, keys) - -Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. -""" -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) - keys_strings = map(string, collect_maybe(keys)) - num_indices_seen = 0 - - for vn in Base.keys(vi) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. - # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi -end - -function _apply!(kernel!, vi::NTVarInfo, values, keys) - return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) -end - -@generated function _typed_apply!( - kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys -) where {names} - updates = map(names) do n - quote - for vn in Base.keys(metadata.$n) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - end - end - - return quote - keys_strings = map(string, keys) - num_indices_seen = 0 - - $(updates...) - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. - # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi + # TODO(mhauru) Consider doing something more clever here, e.g. 
checking whether + # either varinfo_left or varinfo_right is empty, or actually iterating over all the + # values to check their linked status. Needs to balance keeping the type parameter + # alive vs runtime costs. + nothing end -end - -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) - string_vns = map(string, collect_maybe(Base.keys(vi))) - # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. - missing_keys = filter(keys) do key - !any(Base.Fix2(subsumes_string, key), string_vns) - end - - return missing_keys -end - -values_as(vi::VarInfo) = vi.metadata -values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) -function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) - iter = values_from_metadata(vi.metadata) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end -function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) -end - -function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end - -function values_as( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {names,D<:AbstractDict} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return ConstructionBase.constructorof(D)(iter) -end - -values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) - -function values_from_metadata(md::Metadata) - return ( - # `copy` to avoid accidentally mutation of internal representation. 
- vn => copy( - from_internal_transform(md, vn, getdist(md, vn))(getindex_internal(md, vn)) - ) for vn in md.vns - ) -end - -values_from_metadata(md::VarNamedVector) = pairs(md) - -# Transforming from internal representation to distribution representation. -# Without `dist` argument: base on `dist` extracted from self. -function from_internal_transform(vi::VarInfo, vn::VarName) - return from_internal_transform(getmetadata(vi, vn), vn) -end -function from_internal_transform(md::Metadata, vn::VarName) - return from_internal_transform(md, vn, getdist(md, vn)) -end -function from_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_internal_transform(vi::VarInfo, vn::VarName, dist) - return from_internal_transform(getmetadata(vi, vn), vn, dist) -end -from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) -function from_internal_transform(::VarNamedVector, ::VarName, dist) - return from_vec_transform(dist) -end - -# Without `dist` argument: base on `dist` extracted from self. -function from_linked_internal_transform(vi::VarInfo, vn::VarName) - return from_linked_internal_transform(getmetadata(vi, vn), vn) -end -function from_linked_internal_transform(md::Metadata, vn::VarName) - return from_linked_internal_transform(md, vn, getdist(md, vn)) -end -function from_linked_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) - # Dispatch to metadata in case this alters the behavior. 
- return from_linked_internal_transform(getmetadata(vi, vn), vn, dist) -end -function from_linked_internal_transform(::Metadata, ::VarName, dist) - return from_linked_vec_transform(dist) -end -function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) - return from_linked_vec_transform(dist) + return VarInfo{new_linked}(new_values, new_accs) end diff --git a/src/varname.jl b/src/varname.jl index 3eb1f2460..037b5d35d 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,20 +1,3 @@ -""" - subsumes_string(u::String, v::String[, u_indexing]) - -Check whether stringified variable name `v` describes a sub-range of stringified variable `u`. - -This is a very restricted version `subumes(u::VarName, v::VarName)` only really supporting: -- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. - -## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` - for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, - and similarly to `v`. But this is slow. -""" -function subsumes_string(u::String, v::String, u_indexing=u * "[") - return u == v || startswith(v, u_indexing) -end - """ inargnames(varname::VarName, model::Model) @@ -41,3 +24,29 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL. +""" + varnamesize(vn::VarName) + +Return the size of the object referenced by this VarName. 
+ +```jldoctest +julia> varnamesize(@varname(a)) +() + +julia> varnamesize(@varname(b[1:3, 2])) +(3,) + +julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1])) +(4, 4) +``` +""" +function varnamesize(vn::VarName) + l = AbstractPPL.olast(vn.optic) + if l isa AbstractPPL.Index + isempty(l.kw) || error("keyword indices are not supported") + return reduce((x, y) -> tuple(x..., y...), map(size, l.ix)) + else + return () + end +end diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl new file mode 100644 index 000000000..818fd752f --- /dev/null +++ b/src/varnamedtuple.jl @@ -0,0 +1,117 @@ +# TODO(mhauru) This module should probably be moved to AbstractPPL. +module VarNamedTuples + +using AbstractPPL +using AbstractPPL: AbstractPPL +using Distributions: Distributions, Distribution +using BangBang +using DynamicPPL: DynamicPPL + +export VarNamedTuple, vnt_size, map_pairs!!, map_values!!, apply!! + +# Currently, keyword arguments are not supported in getindex/_setindex!!. That is because +# `PartialArray` under the hood is backed by `Base.Array`. Thus, if `kw` is not empty, we +# will just error here. However, in principle, this can be expanded by allowing PartialArray +# to wrap generic array types (the 'shadow array' mechanism); see +# https://github.com/TuringLang/DynamicPPL.jl/issues/1194. +function error_kw_indices() + throw(ArgumentError("Keyword indices in VarNames are not yet supported in DynamicPPL.")) +end + +include("varnamedtuple/partial_array.jl") +# The actual definition of the VarNamedTuple struct. Yeah, it needs a better name, I'll sort
+include("varnamedtuple/vnt.jl") +include("varnamedtuple/getset.jl") +include("varnamedtuple/map.jl") + +function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName) + return _haskey_optic(vnt, vn) +end + +function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName) + return _getindex_optic(vnt, vn) +end + +# TODO(mhauru) The following methods mimic the structure of those in +# AbstractPPLDistributionsExtension, and fall back on converting any PartialArrays to +# dictionaries, and calling the AbstractPPL methods. We should eventually make +# implementations of these directly for PartialArray, and maybe move these methods +# elsewhere. Better yet, once we no longer store VarName values in Dictionaries anywhere, +# and FlexiChains takes over from MCMCChains, this could hopefully all be removed. + +# The only case where the Distribution argument makes a difference is if the distribution +# is multivariate and the values are stored in a PartialArray. + +function AbstractPPL.hasvalue( + vnt::VarNamedTuple, vn::VarName, ::Distributions.UnivariateDistribution +) + return AbstractPPL.hasvalue(vnt, vn) +end + +function AbstractPPL.getvalue( + vnt::VarNamedTuple, vn::VarName, ::Distributions.UnivariateDistribution +) + return AbstractPPL.getvalue(vnt, vn) +end + +function AbstractPPL.hasvalue(vals::VarNamedTuple, vn::VarName, dist::Distribution) + @warn "`hasvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `hasvalue(vals, vn)`." + return AbstractPPL.hasvalue(vals, vn) +end + +function AbstractPPL.getvalue(vals::VarNamedTuple, vn::VarName, dist::Distribution) + @warn "`getvalue(vals, vn, dist)` is not implemented for $(typeof(dist)); falling back to `getvalue(vals, vn)`." 
+ return AbstractPPL.getvalue(vals, vn) +end + +const MV_DIST_TYPES = Union{ + Distributions.MultivariateDistribution, + Distributions.MatrixDistribution, + Distributions.LKJCholesky, +} + +function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYPES) + if !haskey(vnt, vn) + # Can't even find the parent VarName, there is no hope. + return false + end + # Note that _getindex_optic, rather than Base.getindex, skips the need to denseify + # PartialArrays. + val = _getindex_optic(vnt, vn) + if !(val isa VarNamedTuple || val isa PartialArray) + # There is _a_ value. Whether it's the right kind, we do not know, but returning + # true is no worse than `hasvalue` returning true for e.g. UnivariateDistributions + # whenever there is at least some value. + return true + end + # Convert to VarName-keyed Dict. + et = val isa VarNamedTuple ? Any : eltype(val) + dval = Dict{VarName,et}() + for k in keys(val) + # VarNamedTuples have VarNames as keys, PartialArrays have Index optics. + subvn = val isa VarNamedTuple ? prefix(k, vn) : AbstractPPL.append_optic(vn, k) + dval[subvn] = _getindex_optic(val, k) + end + return AbstractPPL.hasvalue(dval, vn, dist) +end + +function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYPES) + # Note that _getindex_optic, rather than Base.getindex, skips the need to denseify + # PartialArrays. + val = _getindex_optic(vnt, vn) + if !(val isa VarNamedTuple || val isa PartialArray) + return val + end + # Convert to VarName-keyed Dict. + et = val isa VarNamedTuple ? Any : eltype(val) + dval = Dict{VarName,et}() + for k in keys(val) + # VarNamedTuples have VarNames as keys, PartialArrays have Index optics. + subvn = val isa VarNamedTuple ? 
prefix(k, vn) : AbstractPPL.append_optic(vn, k) + dval[subvn] = _getindex_optic(val, k) + end + return AbstractPPL.getvalue(dval, vn, dist) +end + +end diff --git a/src/varnamedtuple/getset.jl b/src/varnamedtuple/getset.jl new file mode 100644 index 000000000..2f7a9da85 --- /dev/null +++ b/src/varnamedtuple/getset.jl @@ -0,0 +1,185 @@ +# We define our own getindex, setindex!!, and haskey functions, which we use to +# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be +# able to override their behaviour for some types exported from elsewhere without type +# piracy. This is needed because +# 1. We would want to index into things with lenses (from AbstractPPL.jl) using getindex and +# setindex!!, but AbstractPPL does not define these methods. +# 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays. + +const IndexWithoutChild = AbstractPPL.Index{<:Tuple,<:NamedTuple,AbstractPPL.Iden} + +""" + DynamicPPL._getindex_optic(collection, optic::AbstractPPL.Optic) + DynamicPPL._getindex_optic(collection, vn::VarName) + +Access the value in `collection` at the location specified by the given `optic`. If a `VarName` +is provided, it is first converted to an optic using `AbstractPPL.varname_to_optic`. + +Here, `collection` can be either a `VarNamedTuple` or a `PartialArray`, or a leaf value stored +within one of these. + +This is semantically similar to `AbstractPPL.getvalue` but is specialised for `VarNamedTuple` +and `PartialArray`, and skips a number of checks that are unnecessary here. + +Note that it is only valid to index into a `VarNamedTuple` with a `Property` optic, and a +`PartialArray` with an `Index` optic. Other combinations are not valid. When we have reached +the leaf of the VNT i.e. a value, we could still handle pure `Index` optics if the value is +an `AbstractArray`, but otherwise the only valid optic is `Iden`. 
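+
+For example, a sketch of how an optic walks the structure (`_getindex_optic` is internal;
+the `DynamicPPL.VarNamedTuples` module path used in the import is an assumption):
+
+```julia
+using DynamicPPL: VarNamedTuple, setindex!!, @varname
+using DynamicPPL.VarNamedTuples: _getindex_optic
+
+vnt = setindex!!(VarNamedTuple(), [1, 2, 3], @varname(a))
+# @varname(a[2]) becomes a Property optic with an Index child: the Property selects
+# the field `a`, and the Index reads position 2 of the stored array.
+_getindex_optic(vnt, @varname(a[2]))  # == 2
+```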
+""" +function _getindex_optic(vnt::VarNamedTuple, vn::VarName) + return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn)) +end +@inline _getindex_optic(@nospecialize(x::Any), ::AbstractPPL.Iden) = x +function _getindex_optic(vnt::VarNamedTuple, optic::AbstractPPL.Property{S}) where {S} + return _getindex_optic(getindex(vnt.data, S), optic.child) +end +function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index) + return _getindex_optic(Base.getindex(pa, optic.ix...; optic.kw...), optic.child) +end +function _getindex_optic(arr::AbstractArray, optic::IndexWithoutChild) + return Base.getindex(arr, optic.ix...; optic.kw...) +end + +function _haskey_optic(vnt::VarNamedTuple, name::VarName) + return _haskey_optic(vnt, AbstractPPL.varname_to_optic(name)) +end +@inline _haskey_optic(@nospecialize(::Any), ::AbstractPPL.Iden) = true +@inline _haskey_optic(::VarNamedTuple, ::AbstractPPL.Index) = false +function _haskey_optic(vnt::VarNamedTuple, optic::AbstractPPL.Property{S}) where {S} + return Base.haskey(vnt.data, S) && _haskey_optic(getindex(vnt.data, S), optic.child) +end +function _haskey_optic(pa::PartialArray, optic::AbstractPPL.Index) + return Base.haskey(pa, optic.ix...; optic.kw...) && + _haskey_optic(Base.getindex(pa, optic.ix...; optic.kw...), optic.child) +end +function _haskey_optic(arr::AbstractArray, optic::IndexWithoutChild) + # Note that this call to `checkbounds` can error, although it is technically out of our + # hands: it depends on how the provider of the AbstractArray has implemented + # checkbounds. For example, DimArray can error here: + # https://github.com/rafaqz/DimensionalData.jl/issues/1156. But that is not our job to fix + # -- it should be done upstream -- hence we just forward the indices. + return checkbounds(Bool, arr, optic.ix...; optic.kw...) 
+end
+
+"""
+    _setindex_optic!!(collection, value, key; allow_new=Val(true))
+
+Like `setindex!!`, but special-cased for `VarNamedTuple` and `PartialArray` to recurse
+into nested structures.
+
+The `allow_new` keyword argument is a performance optimisation: If it is set to
+`Val(false)`, the function can assume that the key being set already exists in `collection`.
+This allows skipping some code paths, which may have a minor benefit at runtime, but more
+importantly, allows for better constant propagation and type stability at compile time.
+
+`allow_new` being set to `Val(false)` does _not_ guarantee that no new keys will be added.
+It only gives the implementation of `_setindex_optic!!` permission to assume that the key
+already exists. Setting it to `Val(false)` should be done only when the caller is sure that
+the key already exists; anything else is a bug in the caller.
+
+Most methods of `_setindex_optic!!` ignore the `allow_new` keyword argument, as they have
+no use for it. See the method for setting values in a `VarNamedTuple` with a
+`ComposedFunction` for when it is useful.
+"""
+function _setindex_optic!!(vnt::VarNamedTuple, value, name::VarName; allow_new=Val(true))
+    return _setindex_optic!!(
+        vnt, value, AbstractPPL.varname_to_optic(name); allow_new=allow_new
+    )
+end
+@inline function _setindex_optic!!(
+    @nospecialize(::Any), value, ::AbstractPPL.Iden; allow_new=Val(true)
+)
+    return value
+end
+function _setindex_optic!!(
+    arr::AbstractArray, value, optic::IndexWithoutChild; allow_new=Val(true)
+)
+    return BangBang.setindex!!(arr, value, optic.ix...; optic.kw...)
+end
+
+function throw_setindex_allow_new_error()
+    return error(
+        "Attempted to set a value at a key that does not exist, but" *
+        " `allow_new=Val(false)` was specified. If you did not attempt" *
+        " to call this function yourself, this likely indicates a bug in" *
+        " DynamicPPL. Please file an issue at" *
+        " https://github.com/TuringLang/DynamicPPL.jl/issues.",
+    )
+end
+
+function _setindex_optic!!(
+    pa::PartialArray, value, optic::AbstractPPL.Index; allow_new=Val(true)
+)
+    sub_value = if optic.child isa AbstractPPL.Iden
+        # Skip recursion
+        value
+    elseif Base.haskey(pa, optic.ix...; optic.kw...)
+        # Data already exists; we need to recurse into it
+        _setindex_optic!!(
+            Base.getindex(pa, optic.ix...; optic.kw...),
+            value,
+            optic.child;
+            allow_new=allow_new,
+        )
+    elseif allow_new isa Val{true}
+        # No existing data, but we are allowed to create it.
+        make_leaf(value, optic.child)
+    else
+        throw_setindex_allow_new_error()
+    end
+    return BangBang.setindex!!(pa, sub_value, optic.ix...; optic.kw...)
+end
+
+function _setindex_optic!!(
+    vnt::VarNamedTuple, value, optic::AbstractPPL.Property{S}; allow_new=Val(true)
+) where {S}
+    sub_value = if optic.child isa AbstractPPL.Iden
+        # Skip recursion
+        value
+    elseif Base.haskey(vnt.data, S)
+        # Data already exists; we need to recurse into it
+        _setindex_optic!!(vnt.data[S], value, optic.child; allow_new=allow_new)
+    elseif allow_new isa Val{true}
+        # No existing data, but we are allowed to create it.
+        make_leaf(value, optic.child)
+    else
+        # If this branch is ever reached, someone has used allow_new=Val(false)
+        # incorrectly.
+        throw_setindex_allow_new_error()
+    end
+    return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((sub_value,))))
+end
+
+"""
+    make_leaf(value, optic)
+
+Make a new leaf node for a VarNamedTuple.
+
+This is the function that sets any `optic` that is a `Property` to be stored as a
+`VarNamedTuple`, any `Index` to be stored as a `PartialArray`, and `Iden` optics to be
+stored as raw values. It is the link that joins `VarNamedTuple` and `PartialArray` together.
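+
+For example, a sketch (`make_leaf` and `varname_to_optic` are internal; the module paths
+used in the imports are assumptions):
+
+```julia
+using DynamicPPL.VarNamedTuples: make_leaf
+using AbstractPPL: varname_to_optic, @varname
+
+# @varname(b[2]) turns into a Property optic with an Index child, so the result
+# is a VarNamedTuple holding a one-dimensional PartialArray with index 2 set.
+make_leaf(5.0, varname_to_optic(@varname(b[2])))
+```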
+""" +@inline make_leaf(@nospecialize(value::Any), ::AbstractPPL.Iden) = value +function make_leaf(value, optic::AbstractPPL.Property{S}) where {S} + sub_value = make_leaf(value, optic.child) + return VarNamedTuple(NamedTuple{(S,)}((sub_value,))) +end +function make_leaf(value, optic::AbstractPPL.Index) + isempty(optic.kw) || error_kw_indices() + sub_value = make_leaf(value, optic.child) + inds = optic.ix + num_inds = length(inds) + # The element type of the PartialArray depends on whether we are setting a single value + # or a range of values. + et = if !_is_multiindex(inds) + typeof(sub_value) + elseif _needs_arraylikeblock(sub_value, inds...) + ArrayLikeBlock{typeof(sub_value),typeof(inds)} + else + eltype(sub_value) + end + pa = PartialArray{et,num_inds}() + return BangBang.setindex!!(pa, sub_value, optic.ix...; optic.kw...) +end diff --git a/src/varnamedtuple/map.jl b/src/varnamedtuple/map.jl new file mode 100644 index 000000000..b52e04d11 --- /dev/null +++ b/src/varnamedtuple/map.jl @@ -0,0 +1,312 @@ +Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) + +# This needs to be a generated function for type stability. +@generated function _merge_recursive( + vnt1::VarNamedTuple{names1}, vnt2::VarNamedTuple{names2} +) where {names1,names2} + all_names = union(names1, names2) + exs = Expr[] + push!(exs, :(data = (;))) + for name in all_names + val_expr = if name in names1 && name in names2 + :(_merge_recursive(vnt1.data.$name, vnt2.data.$name)) + elseif name in names1 + :(vnt1.data.$name) + else + :(vnt2.data.$name) + end + push!(exs, :(data = merge(data, NamedTuple{($(QuoteNode(name)),)}(($val_expr,))))) + end + push!(exs, :(return VarNamedTuple(data))) + return Expr(:block, exs...) +end + +""" + subset(vnt::VarNamedTuple, vns) + +Create a new `VarNamedTuple` containing only the variables subsumed by ones in `vns`. 
+""" +function DynamicPPL.subset(vnt::VarNamedTuple, vns) + # TODO(mhauru) This could be done more efficiently by generating the code directly, + # because we could short-circuit: For instance, if `vns` contains `a`, we could + # directly include the whole subtree under `a`, without checking each individual + # variable under it. + return mapfoldl( + identity, + function (init, pair) + name, value = pair + return if any(vn -> subsumes(vn, name), vns) + setindex!!(init, value, name) + else + init + end + end, + vnt; + init=VarNamedTuple(), + ) +end + +""" + apply!!(func, vnt::VarNamedTuple, name::VarName) + +Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. + +Like `map_values!!`, but only for a single `VarName`. + +```jldoctest +julia> using DynamicPPL: VarNamedTuple, setindex!! + +julia> using DynamicPPL.VarNamedTuples: apply!! + +julia> vnt = VarNamedTuple() +VarNamedTuple() + +julia> vnt = setindex!!(vnt, [1, 2, 3], @varname(a)) +VarNamedTuple(a = [1, 2, 3],) + +julia> apply!!(x -> x .+ 1, vnt, @varname(a)) +VarNamedTuple(a = [2, 3, 4],) +``` +""" +function apply!!(func, vnt::VarNamedTuple, name::VarName) + if !haskey(vnt, name) + throw(KeyError(repr(name))) + end + subdata = _getindex_optic(vnt, name) + new_subdata = func(subdata) + # The allow_new=Val(true) is a performance optimisation: Since we've already checked + # that the key exists, we know that no new fields will be created. + return _setindex_optic!!(vnt, new_subdata, name; allow_new=Val(false)) +end + +""" + _map_recursive!!(func, x, vn) + +Call `func` on `vn => x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which +case call `_map_recursive!!` recursively on all their elements, updating `vn` with the right +prefix. + +This is the internal implementation of `map_pairs!!`, but because it has a method defined +for literally every type in existence, we hide it behind the interface of the more +discriminating `map_pairs!!`. 
It makes the implementation a bit simpler, compared to +checking element types within `map_pairs!!` itself. +""" +_map_recursive!!(func, x, vn) = func(vn => x) + +# TODO(mhauru) The below is type unstable for some complex VarNames. My example case +# for which type stability fails is @varname(e.f[3].g.h[2].i). I don't understand this +# well, but I think it's just because constant propagation gives up at some point, and fails +# to go through the lines that figure out `new_et`. I could be wrong. I tried fixing this by +# lifting the first three lines of the function into a generated function, but that seems +# to run into trouble when trying to call Core.Compiler.return_type recursively on the same +# function. An earlier implementation of this function that only operated on the values, +# not on pairs of key => value, was type stable (presumably because it was a bit easier on +# constant propagation). +function _map_recursive!!(func, pa::PartialArray, vn) + # Ask the compiler to infer the return type of applying func recursively to eltype(pa). + et = eltype(pa) + index_type = AbstractPPL.Index{NTuple{ndims(pa),Int},@NamedTuple{},AbstractPPL.Iden} + new_vn_type = Core.Compiler.return_type( + AbstractPPL.append_optic, Tuple{typeof(vn),index_type} + ) + new_et = Core.Compiler.return_type( + Tuple{typeof(_map_recursive!!),typeof(func),et,new_vn_type} + ) + new_data = if new_et <: et + # We can reuse the existing data array. + pa.data + else + # We need to allocate a new data array. + similar(pa.data, new_et) + end + # Keep a dictionary of already-seen ArrayLikeBlocks to avoid redundant computations. + # This matters not only for performance, but also for correctness, because + # _map_recursive!! may mutate the value, and we don't want to mutate it multiple times. 
+ albs_seen = Dict{ArrayLikeBlock,ArrayLikeBlock}() + @inbounds for i in CartesianIndices(pa.mask) + if pa.mask[i] + val = pa.data[i] + is_alb = val isa ArrayLikeBlock + if is_alb + if val in keys(albs_seen) + new_data[i] = albs_seen[val] + continue + end + end + ind = is_alb ? val.inds : Tuple(i) + new_vn = AbstractPPL.append_optic(vn, AbstractPPL.Index(ind, (;))) + new_val = _map_recursive!!(func, pa.data[i], new_vn) + new_data[i] = new_val + if is_alb + albs_seen[val] = new_val + end + end + end + # The above type inference may be overly conservative, so we concretise the eltype. + return _concretise_eltype!!(PartialArray(new_data, pa.mask)) +end + +function _map_recursive!!(func, alb::ArrayLikeBlock, vn) + new_block = _map_recursive!!(func, alb.block, vn) + sz_new = vnt_size(new_block) + sz_old = vnt_size(alb.block) + if !(sz_new isa SkipSizeCheck) && !(sz_old isa SkipSizeCheck) && sz_new != sz_old + throw( + DimensionMismatch( + "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * + "from $(vnt_size(alb.block)) to $(vnt_size(new_block)).", + ), + ) + end + return ArrayLikeBlock(new_block, alb.inds) +end + +# As above but with a prefix VarName `vn`. +@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} + exs = Expr[] + for name in Names + push!( + exs, + :(_map_recursive!!( + func, vnt.data.$name, AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn) + )), + ) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +""" + map_pairs!!(func, vnt::VarNamedTuple) + +Apply `func` to all key => value pairs of `vnt`, in place if possible. + +`func` should accept a pair of `VarName` and value, and return the new value to be set. 
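+
+For example, doubling every stored value:
+
+```jldoctest
+julia> using DynamicPPL: VarNamedTuple, setindex!!, map_pairs!!
+
+julia> vnt = setindex!!(VarNamedTuple(), [1, 2, 3], @varname(a));
+
+julia> map_pairs!!(pair -> pair.second .* 2, vnt)
+VarNamedTuple(a = [2, 4, 6],)
+```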
+""" +@generated function map_pairs!!(func, vnt::VarNamedTuple{Names}) where {Names} + exs = Expr[] + for name in Names + push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +Base.foreach(func, vnt::VarNamedTuple) = map_pairs!!(p -> (func(p); p), vnt) + +""" + map_values!!(func, vnt::VarNamedTuple) + +Apply `func` to elements of `vnt`, in place if possible. +""" +map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), vnt) + +""" + mapreduce(f, op, vnt::VarNamedTuple; init) + +Apply `f` to all elements of `vnt`, and reduce the results using `op`, starting from `init`. + +The order is the same as in `mapfoldl`, i.e. left-associative with `init` as the +left-most value. + +`init` is a keyword argument to conform to the usual `mapreduce` interface in Base, but it +is not optional. + +`f` op` should accept pairs of `varname => value`. +""" +@generated function Base.mapreduce( + f, op, vnt::VarNamedTuple{Names}; init::InitType=nothing +) where {Names,InitType} + if InitType === Nothing + return quote + throw( + ArgumentError( + "mapreduce without init is not implemented for VarNamedTuple." + ), + ) + end + end + + exs = Expr[:(result = init)] + for name in Names + push!( + exs, + quote + result = _mapreduce_recursive( + f, op, vnt.data.$name, VarName{$(QuoteNode(name))}(), result + ) + end, + ) + end + push!(exs, :(return result)) + return Expr(:block, exs...) +end + +# Our mapreduce is always left-associative. +Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) + +_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) + +# As above but with a prefix VarName `vn`. 
+@generated function _mapreduce_recursive( + f, op, vnt::VarNamedTuple{Names}, vn, init +) where {Names} + exs = Expr[:(result = init)] + for name in Names + push!( + exs, + quote + result = _mapreduce_recursive( + f, + op, + vnt.data.$name, + AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn), + result, + ) + end, + ) + end + push!(exs, :(return result)) + return Expr(:block, exs...) +end + +function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) + result = init + et = eltype(pa) + + albs_seen = Set{ArrayLikeBlock}() + @inbounds for i in CartesianIndices(pa.mask) + if pa.mask[i] + val = pa.data[i] + is_alb = val isa ArrayLikeBlock + if is_alb + if val in albs_seen + continue + end + push!(albs_seen, val) + end + ind = is_alb ? val.inds : Tuple(i) + new_vn = AbstractPPL.append_optic(vn, AbstractPPL.Index(ind, (;))) + result = _mapreduce_recursive(f, op, pa.data[i], new_vn, result) + end + end + return result +end + +# TODO(mhauru) We could try to keep the return types of these more tight, rather than always +# return the same, abstract element type. Would that be better? It would be faster in some +# cases, but would be less consistent, and could result in a lot of allocations in the +# mapreduce, as the element type is gradually expanded. +Base.keys(vnt::VarNamedTuple) = mapreduce(first, push!, vnt; init=VarName[]) +Base.values(vnt::VarNamedTuple) = mapreduce(pair -> pair.second, push!, vnt; init=Any[]) + +function Base.length(vnt::VarNamedTuple) + len = 0 + for subdata in vnt.data + len += subdata isa VarNamedTuple || subdata isa PartialArray ? length(subdata) : 1 + end + return len +end diff --git a/src/varnamedtuple/partial_array.jl b/src/varnamedtuple/partial_array.jl new file mode 100644 index 000000000..ca5d0272d --- /dev/null +++ b/src/varnamedtuple/partial_array.jl @@ -0,0 +1,796 @@ +# Some utilities for checking what sort of indices we are dealing with. 
+# The non-generated implementations of these would be
+# function _has_colon_or_dynamicindex(::T) where {T<:Tuple}
+#     return any(x <: Colon || x <: AbstractPPL.DynamicIndex for x in T.parameters)
+# end
+# function _is_multiindex(::T) where {T<:Tuple}
+#     return any(x <: AbstractUnitRange || x <: Colon for x in T.parameters)
+# end
+# However, constant propagation sometimes fails if the index tuple is too big (e.g. length
+# 4), so we play it safe and use generated functions. Constant propagating these is
+# important, because many functions choose different paths based on their values, which
+# would lead to type instability if they were only evaluated at runtime.
+@generated function _has_colon_or_dynamicindex(::T) where {T<:Tuple}
+    for x in T.parameters
+        if x <: Colon || x <: AbstractPPL.DynamicIndex
+            return :(return true)
+        end
+    end
+    return :(return false)
+end
+@generated function _is_multiindex(::T) where {T<:Tuple}
+    for x in T.parameters
+        if x <: AbstractUnitRange || x <: Colon
+            return :(return true)
+        end
+    end
+    return :(return false)
+end
+
+"""
+    _merge_recursive(x1, x2)
+
+Recursively merge two values `x1` and `x2`.
+
+Unlike `Base.merge`, this function is defined for all types, and by default returns the
+second argument. It is overridden for `PartialArray` and `VarNamedTuple`, since they are
+nested containers, and calls itself recursively on all elements that are found in both
+`x1` and `x2`.
+
+In other words, if both `x` and `y` are collections with the key `a`, `Base.merge(x, y)[a]`
+is `y[a]`, whereas `_merge_recursive(x, y)[a]` will be `_merge_recursive(x[a], y[a])`,
+unless no specific method is defined for the type of `x` and `y`, in which case
+`_merge_recursive(x, y) === y`.
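+
+For example, a sketch showing that nested values merge rather than overwrite:
+
+```julia
+using DynamicPPL: VarNamedTuple, setindex!!, @varname
+
+x = setindex!!(VarNamedTuple(), 1.0, @varname(a.b))
+y = setindex!!(VarNamedTuple(), 2.0, @varname(a.c))
+
+# Base.merge on the underlying NamedTuples would keep only y's `a` entry;
+# the recursive merge keeps both a.b and a.c.
+merge(x, y)
+```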
+""" +_merge_recursive(_, x2) = x2 + +"""The factor by which we increase the dimensions of PartialArrays when resizing them.""" +const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 + +"""A convenience for defining method argument type bounds.""" +const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.DynamicIndex} + +""" + SkipSizeCheck() + +A special return value for `vnt_size` indicating that size checks should be skipped. +""" +struct SkipSizeCheck end + +""" + vnt_size(x) + +Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`. + +By default, this falls back onto `Base.size`, but can be overloaded for custom types. +This notion of type is used to determine whether a value can be set into a `PartialArray` +as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. + +A special return value of `SkipSizeCheck()` indicates that the size check should be skipped. +""" +vnt_size(x) = size(x) + +""" + ArrayLikeBlock{T,I} + +A wrapper for non-array blocks stored in `PartialArray`s. + +When setting a value in a `PartialArray` over a range of indices, if the value being set +is not itself an `AbstractArray`, but has a well-defined size, we wrap it in an +`ArrayLikeBlock`, which records both the value and the indices it was set with. + +When getting values from a `PartialArray`, if any of the requested indices correspond to +an `ArrayLikeBlock`, we check that the requested indices match the ones used to set the +value. If they do, we return the underlying block, otherwise we throw an error. 
+""" +struct ArrayLikeBlock{T,I} + block::T + inds::I + + function ArrayLikeBlock(block::T, inds::I) where {T,I} + if !_is_multiindex(inds) + throw(ArgumentError("ArrayLikeBlock must be constructed with a multi-index")) + end + return new{T,I}(block, inds) + end +end + +function Base.show(io::IO, alb::ArrayLikeBlock) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + print(io, "ArrayLikeBlock(") + show(io, alb.block) + print(io, ", ") + show(io, alb.inds) + print(io, ")") + return nothing +end + +_blocktype(::Type{ArrayLikeBlock{T}}) where {T} = T + +""" + PartialArray{ElType,numdims} + +An array-like like structure that may only have some of its elements defined. + +A `PartialArray` is like a `Base.Array,` except not all of its elements are necessarily +defined. That is to say, one can create an empty `PartialArray` `arr` and e.g. set +`arr[3,2] = 5`, but asking for `arr[1,1]` may throw a `BoundsError` if `[1, 1]` has not been +explicitly set yet. + +`PartialArray`s can be indexed with integer indices and ranges. Indexing is always 1-based. +Other types of indexing allowed by `Base.Array` are not supported. Some of these are simply +because we haven't seen a need and haven't bothered to implement them, namely boolean +indexing, linear indexing into multidimensional arrays, and indexing with arrays. However, +notably, indexing with colons (i.e. `:`) is not supported for more fundamental reasons. + +To understand this, note that a `PartialArray` has no well-defined size. For example, if one +creates an empty array and sets `arr[3,2]`, it is unclear if that should be taken to mean +that the array has size `(3,2)`: It could be larger, and saying that the size is `(3,2)` +would also misleadingly suggest that all elements within `1:3,1:2` are set. 
This is also why +colon indexing is ill-defined: If one would e.g. set `arr[2,:] = [1,2,3]`, we would have no +way of saying whether the right hand side is of an acceptable size or not. + +The fact that its size is ill-defined also means that `PartialArray` is not a subtype of +`AbstractArray`. + +All indexing into `PartialArray`s is done with `getindex` and `setindex!!`. `setindex!`, +`push!`, etc. are not defined. The element type of a `PartialArray` will change as needed +under `setindex!!` to accomoddate the new values. + +Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known element type +`ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly +`numdims` elements. + +One can set values in a `PartialArray` either element-by-element, or with ranges like +`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set +must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or +`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the +range. If the value is an `AbstractArray`, the elements are copied individually, but if it +is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but +is only a single object. Getting such a block-value must be done with the exact same range +of indices, otherwise an error is thrown. + +If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check +if, after the new value has been set, the element type can be made more concrete. If so, +a new `PartialArray` with a more concrete element type is returned. Thus the element type +of any `PartialArray` should always be as concrete as is allowed by the elements in it. + +The internal implementation of an `PartialArray` consists of two arrays: one holding the +data and the other one being a boolean mask indicating which elements are defined. 
These +internal arrays may need resizing when new elements are set that have index ranges larger +than the current internal arrays. To avoid resizing too often, the internal arrays are +resized in exponentially increasing steps. This means that most `setindex!!` calls are very +fast, but some may incur substantial overhead due to resizing and copying data. It also +means that the largest index set so far determines the memory usage of the `PartialArray`. +`PartialArray`s are thus well-suited when most values in it will eventually be set. If only +a few scattered values are set, a structure like `SparseArray` may be more appropriate. +""" +struct PartialArray{ElType,num_dims} + # TODO(mhauru) Consider trying FixedSizeArrays instead, see how it would change + # performance. We reallocate new Arrays every time when resizing anyway, except for + # Vectors, which can be extended in place. + data::Array{ElType,num_dims} + mask::Array{Bool,num_dims} + + function PartialArray( + data::Array{ElType,num_dims}, mask::Array{Bool,num_dims} + ) where {ElType,num_dims} + if size(data) != size(mask) + throw(ArgumentError("Data and mask arrays must have the same size")) + end + return new{ElType,num_dims}(data, mask) + end +end + +""" + PartialArray{ElType,num_dims}(args::Vararg{Pair}; min_size=nothing) + +Create a new `PartialArray`. + +The element type and number of dimensions have to be specified explicitly as type +parameters. The positional arguments can be `Pair`s of indices and values. For example, +```jldoctest +julia> using DynamicPPL.VarNamedTuples: PartialArray + +julia> pa = PartialArray{Int,2}((1,2) => 5, (3,4) => 10) +PartialArray{Int64,2}((1, 2) => 5, (3, 4) => 10) +``` + +The optional keyword argument `min_size` can be used to specify the minimum initial size. +This is purely a performance optimisation, to avoid resizing if the eventual size is known +ahead of time. 
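+
+For example, pre-sizing for at least 100 elements (only the entries that have been set are
+shown):
+
+```jldoctest
+julia> using DynamicPPL.VarNamedTuples: PartialArray
+
+julia> PartialArray{Float64,1}((2,) => 0.5; min_size=(100,))
+PartialArray{Float64,1}((2,) => 0.5)
+```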
+""" +function PartialArray{ElType,num_dims}( + args::Vararg{Pair}; min_size::Union{Tuple,Nothing}=nothing +) where {ElType,num_dims} + dims = if min_size === nothing + ntuple(_ -> PARTIAL_ARRAY_DIM_GROWTH_FACTOR, num_dims) + else + map(_partial_array_dim_size, min_size) + end + data = Array{ElType,num_dims}(undef, dims) + mask = fill(false, dims) + pa = PartialArray(data, mask) + + for (inds, value) in args + pa = setindex!!(pa, convert(ElType, value), inds...) + end + return pa +end + +Base.ndims(::PartialArray{ElType,num_dims}) where {ElType,num_dims} = num_dims +Base.eltype(::PartialArray{ElType}) where {ElType} = ElType + +function Base.show(io::IO, pa::PartialArray) + print(io, "PartialArray{", eltype(pa), ",", ndims(pa), "}(") + is_first = true + for inds in CartesianIndices(pa.mask) + if @inbounds(!pa.mask[inds]) + continue + end + if !is_first + print(io, ", ") + else + is_first = false + end + val = @inbounds(pa.data[inds]) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. + show(io, Tuple(inds)) + print(io, " => ") + show(io, val) + end + print(io, ")") + return nothing +end + +# We deliberately don't define Base.size for PartialArray, because it is ill-defined. +# The size of the .data field is an implementation detail. +_internal_size(pa::PartialArray, args...) = size(pa.data, args...) + +# Even though a PartialArray has no well-defined size, we still allow it to be used as an +# ArrayLikeBlock. This enables setting values for keys like @varname(x[1:3][1]), which will +# be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. +# Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a +# valid key. 
+vnt_size(::PartialArray) = SkipSizeCheck()
+
+function Base.copy(pa::PartialArray)
+    # Make a shallow copy of pa, except for any VarNamedTuple or PartialArray elements,
+    # which we recursively copy.
+    pa_copy = PartialArray(copy(pa.data), copy(pa.mask))
+    et = eltype(pa)
+    if (
+        VarNamedTuple <: et ||
+        et <: VarNamedTuple ||
+        PartialArray <: et ||
+        et <: PartialArray
+    )
+        @inbounds for i in eachindex(pa.mask)
+            if pa.mask[i]
+                val = @inbounds pa_copy.data[i]
+                if val isa VarNamedTuple || val isa PartialArray
+                    pa_copy.data[i] = copy(val)
+                end
+            end
+        end
+    end
+    return pa_copy
+end
+
+function Base.:(==)(pa1::PartialArray, pa2::PartialArray)
+    if ndims(pa1) != ndims(pa2)
+        return false
+    end
+    size1 = _internal_size(pa1)
+    size2 = _internal_size(pa2)
+    # TODO(mhauru) This could be optimised by not calling checkbounds on all elements
+    # outside the size of an array, but not sure it's worth it.
+    merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1))
+    result = true
+    for i in CartesianIndices(merge_size)
+        m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false
+        m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false
+        if m1 != m2
+            return false
+        end
+        if m1
+            elements_equal = pa1.data[i] == pa2.data[i]
+            if elements_equal === false
+                return false
+            elseif elements_equal === missing
+                # This branch can't short-circuit and just return missing, because some
+                # later values may be straight-up not equal.
+                result = missing
+            end
+        end
+    end
+    return result
+end
+
+# Exactly as == above, except the comparison of the data elements uses isequal.
+function Base.isequal(pa1::PartialArray, pa2::PartialArray)
+    if ndims(pa1) != ndims(pa2)
+        return false
+    end
+    size1 = _internal_size(pa1)
+    size2 = _internal_size(pa2)
+    # TODO(mhauru) This could be optimised by not calling checkbounds on all elements
+    # outside the size of an array, but not sure it's worth it.
+    merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1))
+    for i in CartesianIndices(merge_size)
+        m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false
+        m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false
+        if m1 != m2
+            return false
+        end
+        if m1 && !isequal(pa1.data[i], pa2.data[i])
+            return false
+        end
+    end
+    return true
+end
+
+function Base.hash(pa::PartialArray, h::UInt)
+    h = hash(ndims(pa), h)
+    for i in eachindex(pa.mask)
+        @inbounds if pa.mask[i]
+            h = hash(i, h)
+            h = hash(pa.data[i], h)
+        end
+    end
+    return h
+end
+
+Base.isempty(pa::PartialArray) = !any(pa.mask)
+Base.empty(pa::PartialArray) = PartialArray(similar(pa.data), fill(false, size(pa.mask)))
+function BangBang.empty!!(pa::PartialArray)
+    fill!(pa.mask, false)
+    return pa
+end
+
+# This is a tad hacky: We use _mapreduce_recursive, which requires a prefix VarName. We give
+# it the nonsense prefix @varname(_), and then strip it away with the mapping function,
+# returning only the optic.
+function Base.keys(pa::PartialArray)
+    return _mapreduce_recursive(pair -> first(pair).optic, push!, pa, @varname(_), Any[])
+end
+
+# Length could be defined as a special case of mapreduce, but it's harder to keep it type
+# stable that way: If the element type is abstract, we end up calling _mapreduce_recursive
+# on an abstract type, which makes the type of the accumulator Any.
+function Base.length(pa::PartialArray)
+    len = 0
+    @inbounds for i in eachindex(pa.mask)
+        if !pa.mask[i]
+            continue
+        end
+        val = pa.data[i]
+        len += val isa VarNamedTuple || val isa PartialArray ? length(val) : 1
+    end
+    return len
+end
+
+"""
+    _concretise_eltype!!(pa::PartialArray)
+
+Concretise the element type of a `PartialArray`.
+
+Returns a new `PartialArray` with the same data and mask as `pa`, but with its element type
+concretised to the most specific type that can hold all currently defined elements.
+
+Note that this function is fundamentally type unstable if the current element type of `pa`
+is not already concrete.
+
+The name has a `!!` not because it mutates its argument, but because the return value
+aliases memory with the argument, and is thus not independent of it.
+"""
+function _concretise_eltype!!(pa::PartialArray)
+    if isconcretetype(eltype(pa))
+        return pa
+    end
+    # We could use promote_type here, instead of typejoin. However, that would e.g.
+    # cause Ints to be converted to Float64s, since
+    # promote_type(Int, Float64) == Float64, which can cause problems. See
+    # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
+    # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing
+    # and Missing, rather than falling back on Any. However, it's not exported.
+    new_et = typejoin((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...)
+    # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)?
+    # In other words, does it help to be more concrete, even if we aren't fully concrete?
+    if new_et === eltype(pa)
+        # The types of the elements do not allow for concretisation.
+        return pa
+    end
+    new_data = Array{new_et,ndims(pa)}(undef, _internal_size(pa))
+    @inbounds for i in eachindex(pa.mask)
+        if pa.mask[i]
+            new_data[i] = pa.data[i]
+        end
+    end
+    return PartialArray(new_data, pa.mask)
+end
+
+"""Return the length needed in a dimension given an index."""
+_length_needed(i::Integer) = i
+_length_needed(r::AbstractUnitRange) = last(r)
+
+"""Take the minimum size that a dimension of a PartialArray needs to be, and return the size
+we choose it to be. This will be the smallest power of PARTIAL_ARRAY_DIM_GROWTH_FACTOR that
+is at least the given minimum. Growing PartialArrays in big jumps like this helps reduce
+data copying, as resizes aren't needed as often.
+""" +function _partial_array_dim_size(min_dim) + factor = PARTIAL_ARRAY_DIM_GROWTH_FACTOR + return factor^(Int(ceil(log(factor, min_dim)))) +end + +"""Return the minimum internal size needed for a `PartialArray` to be able set the value +at inds. +""" +function _min_size(pa::PartialArray, inds) + return ntuple(i -> max(_internal_size(pa, i), _length_needed(inds[i])), length(inds)) +end + +"""Resize a PartialArray to be able to accommodate the index inds. This operates in place +for vectors, but makes a copy for higher-dimensional arrays, unless no resizing is +necessary, in which case this is a no-op.""" +function _resize_partialarray!!(pa::PartialArray, inds) + min_size = _min_size(pa, inds) + new_size = map(_partial_array_dim_size, min_size) + if new_size == _internal_size(pa) + return pa + end + # Generic multidimensional Arrays can not be resized, so we need to make a new one. + # See https://github.com/JuliaLang/julia/issues/37900 + new_data = Array{eltype(pa),ndims(pa)}(undef, new_size) + new_mask = fill(false, new_size) + # Note that we have to use CartesianIndices instead of eachindex, because the latter + # may use a linear index that does not match between the old and the new arrays. + @inbounds for i in CartesianIndices(pa.data) + mask_val = pa.mask[i] + if mask_val + new_mask[i] = mask_val + new_data[i] = pa.data[i] + end + end + return PartialArray(new_data, new_mask) +end + +# The below implements the same functionality as above, but more performantly for 1D arrays. +function _resize_partialarray!!(pa::PartialArray{Eltype,1}, (ind,)) where {Eltype} + # Resize arrays to accommodate new indices. 
+    old_size = _internal_size(pa, 1)
+    min_size = max(old_size, _length_needed(ind))
+    new_size = _partial_array_dim_size(min_size)
+    if new_size == old_size
+        return pa
+    end
+    resize!(pa.data, new_size)
+    resize!(pa.mask, new_size)
+    @inbounds pa.mask[(old_size + 1):new_size] .= false
+    return pa
+end
+
+"""Throw an appropriate error if the given indices are invalid for `pa`."""
+function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N}
+    if length(inds) != ndims(pa)
+        throw(BoundsError(pa, inds))
+    end
+    if _has_colon_or_dynamicindex(inds)
+        msg = """
+            Indexing PartialArrays with Colon or AbstractPPL.DynamicIndex is not supported.
+            You may need to concretise the `VarName` first."""
+        throw(ArgumentError(msg))
+    end
+    return nothing
+end
+
+"""
+    Base.getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}; kw...)
+
+Obtain the value at the given indices from the `PartialArray`. This needs to be smarter than
+just calling Base.getindex on the internal data array, because we need to check if the
+requested indices correspond to an ArrayLikeBlock.
+"""
+function Base.getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}; kw...)
+    isempty(kw) || error_kw_indices()
+    # The unmodified inds is needed later for ArrayLikeBlock checks.
+    orig_inds = inds
+    _check_index_validity(pa, inds)
+    if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))))
+        throw(BoundsError(pa, inds))
+    end
+    val = getindex(pa.data, inds...)
+
+    # If not for ArrayLikeBlocks, at this point we could just return val directly. However,
+    # we need to check if val contains any ArrayLikeBlocks, and if so, make sure that
+    # we are retrieving exactly that block and nothing else.
+
+    # The error we'll throw if the retrieval is invalid.
+    err = ArgumentError("""
+        A non-Array value set with a range of indices must be retrieved with the same
+        range of indices.
+        """)
+    if val isa ArrayLikeBlock
+        # Tried to get a single value, but it's an ArrayLikeBlock.
+        throw(err)
+    elseif val isa Array && (eltype(val) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(val))
+        # Tried to get a range of values, and at least some of them may be ArrayLikeBlocks.
+        # The below isempty check is deliberately kept separate from the outer elseif,
+        # because the outer one can be resolved at compile time.
+        if isempty(val)
+            # We need to return an empty array, but for type stability, we want to unwrap
+            # any ArrayLikeBlock types in the element type. Note that Array{T,N}() is only
+            # defined for N == 0, so we construct the empty array from size(val) instead.
+            return if eltype(val) <: ArrayLikeBlock
+                Array{_blocktype(eltype(val)),ndims(val)}(undef, size(val))
+            else
+                val
+            end
+        end
+        first_elem = first(val)
+        if !(first_elem isa ArrayLikeBlock)
+            throw(err)
+        end
+        if orig_inds != first_elem.inds
+            # The requested indices do not match the ones used to set the value.
+            throw(err)
+        end
+        # If _setindex!! works correctly, we should only be able to reach this point if all
+        # the elements in `val` are identical to first_elem. Thus we just return that one.
+        return first_elem.block
+    else
+        return val
+    end
+end
+
+function Base.haskey(pa::PartialArray, inds::Vararg{INDEX_TYPES}; kw...)
+    isempty(kw) || error_kw_indices()
+    _check_index_validity(pa, inds)
+    hasall =
+        checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))
+
+    # If not for ArrayLikeBlocks, we could just return hasall directly. However, we need to
+    # check that if any ArrayLikeBlocks are included, they are fully included.
+    et = eltype(pa)
+    if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et)
+        # pa can't possibly hold any ArrayLikeBlocks, so nothing to do.
+        return hasall
+    end
+
+    if !hasall
+        return false
+    end
+    # From this point on we can assume that all the requested elements are set, and the only
+    # thing to check is that we are not partially indexing into any ArrayLikeBlocks.
+    # We've already checked checkbounds at the top of the function, and returned if it
+    # wasn't true, so @inbounds is safe.
+    subdata = @inbounds getindex(pa.data, inds...)
+    if !_is_multiindex(inds)
+        return !(subdata isa ArrayLikeBlock)
+    end
+    return !any(elem -> elem isa ArrayLikeBlock && elem.inds != inds, subdata)
+end
+
+function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
+    _check_index_validity(pa, inds)
+    if _is_multiindex(inds)
+        pa.mask[inds...] .= false
+    else
+        pa.mask[inds...] = false
+    end
+    return pa
+end
+
+_ensure_range(r::AbstractUnitRange) = r
+_ensure_range(i::Integer) = i:i
+
+"""
+    _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
+
+Remove any ArrayLikeBlocks that overlap with the given indices from the PartialArray.
+
+Note that a block that only partially overlaps `inds` is removed in its entirety,
+including the parts that lie outside `inds`, to avoid partially indexing into
+ArrayLikeBlocks.
+"""
+function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
+    et = eltype(pa)
+    if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et)
+        # pa can't possibly hold any ArrayLikeBlocks, so nothing to do.
+        return pa
+    end
+
+    for i in CartesianIndices(map(_ensure_range, inds))
+        if pa.mask[i]
+            val = @inbounds pa.data[i]
+            if val isa ArrayLikeBlock
+                pa = delete!!(pa, val.inds...)
+            end
+        end
+    end
+    return pa
+end
+
+"""
+    _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
+
+Check if the given value needs to be wrapped in an `ArrayLikeBlock` when being set at inds.
+
+The result depends only on the types of the arguments, and should be constant propagated.
+"""
+function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES})
+    return _is_multiindex(inds) &&
+           !isa(value, AbstractArray) &&
+           hasmethod(vnt_size, Tuple{typeof(value)})
+end
+
+function BangBang.setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}; kw...)
+    isempty(kw) || error_kw_indices()
+    orig_inds = inds
+    _check_index_validity(pa, inds)
+    pa = if checkbounds(Bool, pa.mask, inds...)
+        pa
+    else
+        _resize_partialarray!!(pa, inds)
+    end
+    pa = _remove_partial_blocks!!(pa, inds...)
+
+    new_data = pa.data
+    if _needs_arraylikeblock(value, inds...)
+        inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds))
+        if !(vnt_size(value) isa SkipSizeCheck) && vnt_size(value) != inds_size
+            throw(
+                DimensionMismatch(
+                    "Assigned value has size $(vnt_size(value)), which does not match " *
+                    "the size $(inds_size) implied by the indices.",
+                ),
+            )
+        end
+        # At this point we know we have a value that is not an AbstractArray, but it has
+        # some notion of size, and that size matches the indices that are being set. In this
+        # case we wrap the value in an ArrayLikeBlock, and set all the individual indices
+        # to point to that.
+        alb = ArrayLikeBlock(value, orig_inds)
+        new_data = setindex!!(new_data, fill(alb, inds_size...), inds...)
+    else
+        new_data = setindex!!(new_data, value, inds...)
+    end
+
+    if _is_multiindex(inds)
+        pa.mask[inds...] .= true
+    else
+        pa.mask[inds...] = true
+    end
+    return _concretise_eltype!!(PartialArray(new_data, pa.mask))
+end
+
+Base.merge(x1::PartialArray, x2::PartialArray) = _merge_recursive(x1, x2)
+
+function _merge_element_recursive(x1::PartialArray, x2::PartialArray, ind::CartesianIndex)
+    m1 = x1.mask[ind]
+    m2 = x2.mask[ind]
+    return if m1 && m2
+        _merge_recursive(x1.data[ind], x2.data[ind])
+    elseif m2
+        x2.data[ind]
+    else
+        x1.data[ind]
+    end
+end
+
+# TODO(mhauru) Would this benefit from a specialised method for 1D PartialArrays?
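+
+# An illustrative sketch of the merge semantics implemented below (hypothetical usage,
+# not a doctest): merging is right-biased, so where both operands have an element set,
+# the value from the second operand wins, and nested containers are merged recursively.
+# For example, assuming `pa1` has indices 1 and 3 set and `pa2` has index 1 set:
+#
+#     merged = merge(pa1, pa2)
+#     merged[1] == pa2[1]   # overlapping index taken from the right operand
+#     merged[3] == pa1[3]   # non-overlapping indices are preserved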
+function _merge_recursive(pa1::PartialArray, pa2::PartialArray) + if ndims(pa1) != ndims(pa2) + throw( + ArgumentError("Cannot merge PartialArrays with different numbers of dimensions") + ) + end + num_dims = ndims(pa1) + merge_size = ntuple(i -> max(_internal_size(pa1, i), _internal_size(pa2, i)), num_dims) + return if merge_size == _internal_size(pa2) + # Either pa2 is strictly bigger than pa1 or they are equal in size. + result = copy(pa2) + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... + ) + end + end + result + else + if merge_size == _internal_size(pa1) + # pa1 is bigger than pa2 + result = copy(pa1) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result = setindex!!( + result, _merge_element_recursive(result, pa2, i), Tuple(i)... + ) + end + end + result + else + # Neither is strictly bigger than the other. + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of + # Nothing and Missing, rather than falling back on Any. However, it's not + # exported. + et = typejoin(eltype(pa1), eltype(pa2)) + new_data = Array{et,num_dims}(undef, merge_size) + new_mask = fill(false, merge_size) + result = PartialArray(new_data, new_mask) + for i in CartesianIndices(pa2.data) + @inbounds if pa2.mask[i] + result.mask[i] = true + result.data[i] = pa2.data[i] + end + end + for i in CartesianIndices(pa1.data) + @inbounds if pa1.mask[i] + result = setindex!!( + result, _merge_element_recursive(pa1, result, i), Tuple(i)... 
+ ) + end + end + result + end + end +end + +""" + _dense_array(pa::PartialArray) + +Return a `Base.Array` of the elements of the `PartialArray`. + +If the `PartialArray` has any missing elements that are within the block of set elements, +this will error. For instance, if `pa` is two-dimensional and (2,2) is set, but one of +(1,1), (1,2), or (2,1) is not. + +Likewise, if `pa` includes any blocks set as `ArrayLikeBlocks`, this will error. +""" +function _dense_array(pa::PartialArray) + # Find the size of the dense array, by checking what are the largest indices set in pa. + num_dims = ndims(pa) + size_needed = fill(0, num_dims) + # TODO(mhauru) This could be optimised by not looping over the whole array: If e.g. + # (3,3) is set, we have no need to check any indices within the block (3,3). + for ind in CartesianIndices(pa.mask) + @inbounds if !pa.mask[ind] + continue + end + for d in 1:num_dims + size_needed[d] = max(size_needed[d], ind[d]) + end + end + + # Check that all indices within size_needed are set. + slice = ntuple(d -> 1:size_needed[d], num_dims) + if !all(pa.mask[slice...]) + throw( + ArgumentError( + "Cannot convert PartialArray to dense Array when some elements within " * + "the dense block are not set.", + ), + ) + end + + retval = pa.data[slice...] + if eltype(pa) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(pa) + for ind in CartesianIndices(retval) + @inbounds if retval[ind] isa ArrayLikeBlock + throw( + ArgumentError( + "Cannot convert PartialArray to dense Array when some elements " * + "are set as ArrayLikeBlocks.", + ), + ) + end + end + end + return retval +end diff --git a/src/varnamedtuple/vnt.jl b/src/varnamedtuple/vnt.jl new file mode 100644 index 000000000..a7e4a6eb5 --- /dev/null +++ b/src/varnamedtuple/vnt.jl @@ -0,0 +1,203 @@ +""" + VarNamedTuple{names,Values} + +A `NamedTuple`-like structure with `VarName` keys. 
+
+`VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an
+efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and
+`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods
+are `merge` and `subset`.
+
+`VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the same keys
+and values but in different orders are considered different for equality and hashing.
+Iterations such as `keys` and `values` respect this ordering. The ordering depends on
+the order in which elements were inserted into the `VarNamedTuple`, though it isn't always
+equal to it. More specifically:
+
+* Any new key that has a common parent `VarName` with an existing key is inserted after
+  that key. For instance, if one first inserts, in order, `@varname(a.x)`, `@varname(b)`,
+  and `@varname(a.y)`, the resulting order will be
+  `(@varname(a.x), @varname(a.y), @varname(b))`.
+* `Index` keys, like `@varname(a[3])` or `@varname(b[2,3,4:5])`, are always iterated
+  in the same order an `Array` with the same indices would be iterated. For instance,
+  if one first inserts, in order, `@varname(a[2])`, `@varname(b)`, and `@varname(a[1])`,
+  the resulting order will be `(@varname(a[1]), @varname(a[2]), @varname(b))`.
+
+Otherwise insertion order is respected.
+
+There are two major limitations to indexing `VarNamedTuple`s:
+
+* `VarName`s with `Colon`s (e.g. `a[:]`) are not supported. This is because the meaning of
+  `a[:]` is ambiguous if only some elements of `a`, say `a[1]` and `a[3]`, are defined.
+  However, _concretised_ `VarName`s with `Colon`s are supported.
+* Any `VarName`s with `Index` lenses must have a consistent number of indices. That is, one
+  cannot set `a[1]` and `a[1,2]` in the same `VarNamedTuple`.
+
+`setindex!!` and `getindex` on `VarNamedTuple` are type stable as long as one does not store
+heterogeneous data under different indices of the same symbol. That is, if either
+
+* one sets `a[1]` and `a[2]` to be of different types, or
+* if `a[1]` and `a[2]` both exist, one sets `a[1].b` without setting `a[2].b`,
+
+then getting values for `a[1]` or `a[2]` will not be type stable.
+
+`VarNamedTuple` is intrinsically linked to `PartialArray`, which it uses to store data
+related to `VarName`s with `Index` components.
+"""
+struct VarNamedTuple{Names,Values}
+    data::NamedTuple{Names,Values}
+
+    function VarNamedTuple(data::NamedTuple{Names,Values}) where {Names,Values}
+        return new{Names,Values}(data)
+    end
+end
+
+VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...))
+
+"""
+    VarNamedTuple(d)
+    VarNamedTuple(nt::NamedTuple)
+
+Create a `VarNamedTuple` from a collection or a `NamedTuple`.
+
+Any collection `d` is assumed to be an iterable of key-value pairs, where the keys are
+`VarName`s. This could be an `AbstractDict`, a vector of `Pair`s or `Tuple`s, etc. The
+only exception is `NamedTuple`s, for which the `Symbol` keys are converted to `VarName`s.
+
+Note that `VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the
+same keys and values but in different orders are considered different. If `d` does not
+guarantee an iteration order, then the order of the elements in the resulting
+`VarNamedTuple` is undefined.
+""" +function VarNamedTuple(d) + vnt = VarNamedTuple() + for (k, v) in d + vnt = setindex!!(vnt, v, k) + end + return vnt +end + +Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data +Base.isequal(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = isequal(vnt1.data, vnt2.data) +Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) + +function Base.show(io::IO, vnt::VarNamedTuple) + if isempty(vnt.data) + return print(io, "VarNamedTuple()") + end + print(io, "VarNamedTuple") + show(io, vnt.data) + return nothing +end + +function Base.copy(vnt::VarNamedTuple{names}) where {names} + # Make a shallow copy of vnt, except for any VarNamedTuple or PartialArray elements, + # which we recursively copy. + return VarNamedTuple( + NamedTuple{names}( + map( + x -> x isa Union{VarNamedTuple,PartialArray} ? copy(x) : x, values(vnt.data) + ), + ), + ) +end + +# PartialArrays are an implementation detail of VarNamedTuple, and should never be the +# return value of getindex. Thus, we automatically convert them to dense arrays if needed. +# TODO(mhauru) The below doesn't handle nested PartialArrays. Is that a problem? +_dense_array_if_needed(pa::PartialArray) = _dense_array(pa) +_dense_array_if_needed(x) = x +function Base.getindex(vnt::VarNamedTuple, vn::VarName) + return _dense_array_if_needed(_getindex_optic(vnt, vn)) +end + +Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey_optic(vnt, vn) + +function BangBang.setindex!!(vnt::VarNamedTuple, value, vn::VarName) + return _setindex_optic!!(vnt, value, vn) +end + +""" + _has_partial_array(::Type{VarNamedTuple{Names,Values}}) where {Names,Values} + +Check if any of the types in the `Values` tuple is or contains a `PartialArray`. + +Recurses into any sub-`VarNamedTuple`s. 
+""" +@generated function _has_partial_array( + ::Type{VarNamedTuple{Names,Values}} +) where {Names,Values} + for T in Values.parameters + if _has_partial_array(T) + return :(return true) + end + end + return :(return false) +end + +_has_partial_array(::Type{T}) where {T} = false +_has_partial_array(::Type{<:PartialArray}) = true + +Base.empty(::VarNamedTuple) = VarNamedTuple() + +""" + empty!!(vnt::VarNamedTuple) + +Create an empty version of `vnt` in place. + +This differs from `Base.empty` in that any `PartialArray`s contained within `vnt` are kept +but have their contents deleted, rather than being removed entirely. This means that + +1) The result has a "memory" of how many dimensions different variables had, and you cannot, + for example, set `a[1,2]` after emptying a `VarNamedTuple` that had only `a[1]` defined. +2) Memory allocations may be reduced when reusing `VarNamedTuple`s, since the internal + `PartialArray`s do not need to be reallocated from scratch. +""" +@generated function BangBang.empty!!(vnt::VarNamedTuple{Names,Values}) where {Names,Values} + if !_has_partial_array(VarNamedTuple{Names,Values}) + return :(return VarNamedTuple()) + end + # Check all the fields of the NamedTuple, and keep the ones that contain PartialArrays, + # calling empty!! on them recursively. 
+ new_names = () + new_values = () + for (name, ValType) in zip(Names, Values.parameters) + if _has_partial_array(ValType) + new_values = (new_values..., :(BangBang.empty!!(vnt.data.$name))) + new_names = (new_names..., name) + end + end + return quote + return VarNamedTuple(NamedTuple{$new_names}(($(new_values...),))) + end +end + +@generated function Base.isempty(vnt::VarNamedTuple{Names,Values}) where {Names,Values} + if isempty(Names) + return :(return true) + end + if !_has_partial_array(VarNamedTuple{Names,Values}) + return :(return false) + end + exs = Expr[] + for (name, ValType) in zip(Names, Values.parameters) + if !_has_partial_array(ValType) + return :(return false) + end + push!( + exs, + quote + val = vnt.data.$name + if val isa VarNamedTuple || val isa PartialArray + if !Base.isempty(val) + return false + end + else + return false + end + end, + ) + end + push!(exs, :(return true)) + return Expr(:block, exs...) +end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl deleted file mode 100644 index 17b851d1d..000000000 --- a/src/varnamedvector.jl +++ /dev/null @@ -1,1690 +0,0 @@ -const CHECK_CONSISTENCY_DEFAULT = true - -""" - VarNamedVector - -A container that stores values in a vectorised form, but indexable by variable names. - -A `VarNamedVector` can be thought of as an ordered mapping from `VarName`s to pairs of -`(internal_value, transform)`. Here `internal_value` is a vectorised value for the variable -and `transform` is a function such that `transform(internal_value)` is the "original" value -of the variable, the one that the user sees. For instance, if the variable has a matrix -value, `internal_value` could bea flattened `Vector` of its elements, and `transform` would -be a `reshape` call. - -`transform` may implement simply vectorisation, but it may do more. Most importantly, it may -implement linking, where the internal storage of a random variable is in a form where all -values in Euclidean space are valid. 
This is useful for sampling, because the sampler can -make changes to `internal_value` without worrying about constraints on the space of -the random variable. - -The way to access this storage format directly is through the functions `getindex_internal` -and `setindex_internal`. The `transform` argument for `setindex_internal` is optional, by -default it is either the identity, or the existing transform if a value already exists for -this `VarName`. - -`VarNamedVector` also provides a `Dict`-like interface that hides away the internal -vectorisation. This can be accessed with `getindex` and `setindex!`. `setindex!` only takes -the value, the transform is automatically set to be a simple vectorisation. The only notable -deviation from the behavior of a `Dict` is that `setindex!` will throw an error if one tries -to set a new value for a variable that lives in a different "space" than the old one (e.g. -is of a different type or size). This is because `setindex!` does not change the transform -of a variable, e.g. preserve linking, and thus the new value must be compatible with the old -transform. - -For now, a third value is in fact stored for each `VarName`: a boolean indicating whether -the variable has been transformed to unconstrained Euclidean space or not. This is only in -place temporarily due to the needs of our old Gibbs sampler. - -Internally, `VarNamedVector` stores the values of all variables in a single contiguous -vector. This makes some operations more efficient, and means that one can access the entire -contents of the internal storage quickly with `getindex_internal(vnv, :)`. The other fields -of `VarNamedVector` are mostly used to keep track of which part of the internal storage -belongs to which `VarName`. - -All constructors accept a keyword argument `check_consistency::Bool=true` that controls -whether to run checks like the number of values matching the number of variables. 
Some of -these checks can be expensive, so if you are confident in the input, you may want to turn -`check_consistency` off for performance. - -# Fields - -$(FIELDS) - -# Extended help - -The values for different variables are internally all stored in a single vector. For -instance, -```jldoctest varnamedvector-struct -julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal - -julia> vnv = VarNamedVector(); - -julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); - -julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); - -julia> vnv.vals -10-element Vector{Real}: - 0.0 - 0.0 - 0.0 - 0.0 - 1 - 2 - 3 - 4 - 5 - 6 -``` - -The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to -which variable. The `transforms` field stores the transformations that needed to transform -the vectorised internal storage back to its original form: - -```jldoctest varnamedvector-struct -julia> vnv.transforms[vnv.varname_to_index[@varname(y)]] == DynamicPPL.ReshapeTransform((6,), (2,3)) -true -``` - -If a variable is updated with a new value that is of a smaller dimension than the old -value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. - -```jldoctest varnamedvector-struct -julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); - -julia> vnv.vals -10-element Vector{Real}: - 46.0 - 48.0 - 0.0 - 0.0 - 1 - 2 - 3 - 4 - 5 - 6 - -julia> println(vnv.num_inactive); -Dict(1 => 2) -``` - -This helps avoid unnecessary memory allocations for values that repeatedly change dimension. -The user does not have to worry about the inactive entries as long as they use functions -like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`. 
- -```jldoctest varnamedvector-struct -julia> vnv[@varname(x)] -2-element Vector{Real}: - 46.0 - 48.0 - -julia> getindex_internal(vnv, :) -8-element Vector{Real}: - 46.0 - 48.0 - 1 - 2 - 3 - 4 - 5 - 6 -``` -""" -struct VarNamedVector{ - K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} -} - """ - mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` - """ - varname_to_index::Dict{K,Int} - - """ - vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` - """ - varnames::KVec - - """ - vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has - a single index or a set of contiguous indices, such that the values of `vn` can be found - at `vals[ranges[varname_to_index[vn]]]` - """ - ranges::Vector{UnitRange{Int}} - - """ - vector of values of all variables; the value(s) of `vn` is/are - `vals[ranges[varname_to_index[vn]]]` - """ - vals::VVec - - """ - vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable - that transforms the value of `vn` back to its original space, undoing any linking and - vectorisation - """ - transforms::TVec - - """ - vector of booleans indicating whether a variable has been explicitly transformed to - unconstrained Euclidean space, i.e. whether its domain is all of `ℝ^ⁿ`. If - `is_unconstrained[varname_to_index[vn]]` is true, it guarantees that the variable - `vn` is not constrained. However, the converse does not hold: if `is_unconstrained` - is false, the variable `vn` may still happen to be unconstrained, e.g. if its - original distribution is itself unconstrained (like a normal distribution). - """ - is_unconstrained::BitVector - - """ - mapping from a variable index to the number of inactive entries for that variable. - Inactive entries are elements in `vals` that are not part of the value of any variable. 
- They arise when a variable is set to a new value with a different dimension, in-place. - Inactive entries always come after the last active entry for the given variable. - See the extended help with `??VarNamedVector` for more details. - """ - num_inactive::Dict{Int,Int} - - function VarNamedVector( - varname_to_index, - varnames::KVec, - ranges, - vals::VVec, - transforms::TVec, - is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=Dict{Int,Int}(); - check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, - ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} - if check_consistency - if length(varnames) != length(ranges) || - length(varnames) != length(transforms) || - length(varnames) != length(is_unconstrained) || - length(varnames) != length(varname_to_index) - msg = ( - "Inputs to VarNamedVector have inconsistent lengths. " * - "Got lengths varnames: $(length(varnames)), " * - "ranges: $(length(ranges)), " * - "transforms: $(length(transforms)), " * - "is_unconstrained: $(length(is_unconstrained)), " * - "varname_to_index: $(length(varname_to_index))." - ) - throw(ArgumentError(msg)) - end - - num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) - if num_vals != length(vals) - msg = ( - "The total number of elements in `vals` ($(length(vals))) does not " * - "match the sum of the lengths of the ranges and the number of " * - "inactive entries ($(num_vals))." - ) - throw(ArgumentError(msg)) - end - - if Set(values(varname_to_index)) != Set(axes(varnames, 1)) - msg = ( - "The set of values of `varname_to_index` is not the set of valid " * - "indices for `varnames`." - ) - throw(ArgumentError(msg)) - end - - if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) - msg = ( - "The keys of `num_inactive` are not a subset of the values of " * - "`varname_to_index`." - ) - throw(ArgumentError(msg)) - end - - # Check that the varnames don't overlap. 
The time cost is quadratic in number of - # variables. If this ever becomes an issue, we should be able to go down to at - # least N log N by sorting based on subsumes-order. - for vn1 in keys(varname_to_index) - for vn2 in keys(varname_to_index) - vn1 === vn2 && continue - if subsumes(vn1, vn2) - msg = ( - "Variables in a VarNamedVector should not subsume each " * - "other, but $vn1 subsumes $vn2." - ) - throw(ArgumentError(msg)) - end - end - end - - # We could also have a test to check that the ranges don't overlap, but that - # sounds unlikely to occur, and implementing it in linear time would require a - # tiny bit of thought. - end - - return new{K,V,T,KVec,VVec,TVec}( - varname_to_index, - varnames, - ranges, - vals, - transforms, - is_unconstrained, - num_inactive, - ) - end -end - -function VarNamedVector{K,V,T}() where {K,V,T} - return VarNamedVector( - Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false - ) -end - -VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() -function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) -end -function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(keys(x), values(x); check_consistency=check_consistency) -end -function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector( - collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency - ) -end -function VarNamedVector( - varnames::AbstractVector, - orig_vals::AbstractVector, - transforms=fill(identity, length(varnames)); - check_consistency=CHECK_CONSISTENCY_DEFAULT, -) - if isempty(varnames) && isempty(orig_vals) && isempty(transforms) - return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() - end - # Convert `vals` into a vector of vectors. 
- vals_vecs = map(tovec, orig_vals) - transforms = map( - (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals - ) - # Make `varnames` have as concrete an element type as possible. - varnames = [v for v in varnames] - varname_to_index = Dict{eltype(varnames),Int}( - vn => i for (i, vn) in enumerate(varnames) - ) - vals = reduce(vcat, vals_vecs) - # Make the ranges. - ranges = Vector{UnitRange{Int}}() - offset = 0 - for x in vals_vecs - r = (offset + 1):(offset + length(x)) - push!(ranges, r) - offset = r[end] - end - - # Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a - # lightweight check of the arguments of this function, and rely on the correctness - # of what this function does? However, the expensive check is whether any variable - # subsumes another, and that's the same regardless of where it's done, so the - # optimisation would be quite pointless. - return VarNamedVector( - varname_to_index, - varnames, - ranges, - vals, - transforms; - check_consistency=check_consistency, - ) -end - -function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) - return vnv_left.varname_to_index == vnv_right.varname_to_index && - vnv_left.varnames == vnv_right.varnames && - vnv_left.ranges == vnv_right.ranges && - vnv_left.vals == vnv_right.vals && - vnv_left.transforms == vnv_right.transforms && - vnv_left.is_unconstrained == vnv_right.is_unconstrained && - vnv_left.num_inactive == vnv_right.num_inactive -end - -function is_tightly_typed(vnv::VarNamedVector) - k = eltype(vnv.varnames) - v = eltype(vnv.vals) - t = eltype(vnv.transforms) - return (isconcretetype(k) || k === Union{}) && - (isconcretetype(v) || v === Union{}) && - (isconcretetype(t) || t === Union{}) -end - -getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] - -getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] -getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) - 
-gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] -gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) - -# TODO(mhauru) Eventually I would like to rename the is_transformed function to -# is_unconstrained, but that's significantly breaking. -""" - is_transformed(vnv::VarNamedVector, vn::VarName) - -Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain -is all of Euclidean space. -""" -is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] - -""" - set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - -Set the value for whether `vn` is guaranteed to have been transformed so that all of -Euclidean space is its domain. -""" -function set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val -end - -function set_transformed!!(vnv::VarNamedVector, val::Bool, vn::VarName) - set_transformed!(vnv, val, vn) - return vnv -end - -""" - has_inactive(vnv::VarNamedVector) - -Returns `true` if `vnv` has inactive entries. - -See also: [`num_inactive`](@ref) -""" -has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive) - -""" - num_inactive(vnv::VarNamedVector) - -Return the number of inactive entries in `vnv`. - -See also: [`has_inactive`](@ref), [`num_allocated`](@ref) -""" -num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive)) - -""" - num_inactive(vnv::VarNamedVector, vn::VarName) - -Returns the number of inactive entries for `vn` in `vnv`. -""" -num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) -num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0) - -""" - num_allocated(vnv::VarNamedVector) - num_allocated(vnv::VarNamedVector[, vn::VarName]) - num_allocated(vnv::VarNamedVector[, idx::Int]) - -Return the number of allocated entries in `vnv`, both active and inactive. 
- -If either a `VarName` or an `Int` index is specified, only count entries allocated for that -variable. - -Allocated entries take up memory in `vnv.vals`, but, if inactive, may not currently hold any -meaningful data. One can remove them with [`contiguify!`](@ref), but doing so may cause more -memory allocations in the future if variables change dimension. -""" -num_allocated(vnv::VarNamedVector) = length(vnv.vals) -num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) -function num_allocated(vnv::VarNamedVector, idx::Int) - return length(getrange(vnv, idx)) + num_inactive(vnv, idx) -end - -# Dictionary interface. -Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames) -Base.length(vnv::VarNamedVector) = length(vnv.varnames) -Base.keys(vnv::VarNamedVector) = vnv.varnames -Base.values(vnv::VarNamedVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) -Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) -Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) - -# Vector-like interface. -Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals) - -""" - length_internal(vnv::VarNamedVector) - -Return the length of the internal storage vector of `vnv`, ignoring inactive entries. -""" -function length_internal(vnv::VarNamedVector) - if !has_inactive(vnv) - return length(vnv.vals) - else - return sum(length, vnv.ranges) - end -end - -# Getting and setting values - -function Base.getindex(vnv::VarNamedVector, vn::VarName) - x = getindex_internal(vnv, vn) - f = gettransform(vnv, vn) - return f(x) -end - -""" - find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - -Find the first range in `ranges` that contains `x`. - -Throw an `ArgumentError` if `x` is not in any of the ranges. -""" -function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` - # for a more efficient approach. 
- range_idx = findfirst(Base.Fix1(∈, x), ranges) - - # If we're out of bounds, we raise an error. - if range_idx === nothing - throw(ArgumentError("Value $x is not in any of the ranges.")) - end - - return range_idx -end - -""" - adjusted_ranges(vnv::VarNamedVector) - -Return what `vnv.ranges` would be if there were no inactive entries. -""" -function adjusted_ranges(vnv::VarNamedVector) - # Every range following inactive entries needs to be shifted. - offset = 0 - ranges_adj = similar(vnv.ranges) - for (idx, r) in enumerate(vnv.ranges) - # Remove the `offset` in `r` due to inactive entries. - ranges_adj[idx] = r .- offset - # Update `offset`. - offset += get(vnv.num_inactive, idx, 0) - end - - return ranges_adj -end - -""" - index_to_vals_index(vnv::VarNamedVector, i::Int) - -Convert an integer index that ignores inactive entries to an index that accounts for them. - -This is needed when the user wants to index `vnv` like a vector, but shouldn't have to care -about inactive entries in `vnv.vals`. -""" -function index_to_vals_index(vnv::VarNamedVector, i::Int) - # If we don't have any inactive entries, there's nothing to do. - has_inactive(vnv) || return i - - # Get the adjusted ranges. - ranges_adj = adjusted_ranges(vnv) - # Determine the adjusted range that the index corresponds to. - r_idx = find_containing_range(ranges_adj, i) - r = vnv.ranges[r_idx] - # Determine how much of the index `i` is used to get to this range. - i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) - # Use remainder to index into `r`. - i_remainder = i - i_used - return r[i_remainder] -end - -""" - getindex_internal(vnv::VarNamedVector, vn::VarName) - -Like `getindex`, but returns the values as they are stored in `vnv`, without transforming. -""" -getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] - -""" - getindex_internal(vnv::VarNamedVector, i::Int) - -Gets the `i`th element of the internal storage vector, ignoring inactive entries. 
-""" -getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] - -function getindex_internal(vnv::VarNamedVector, ::Colon) - return if has_inactive(vnv) - mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) - else - vnv.vals - end -end - -function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!(vnv, val, vn) - else - return insert!(vnv, val, vn) - end -end - -""" - reset!(vnv::VarNamedVector, val, vn::VarName) - -Reset the value of `vn` in `vnv` to `val`. - -This differs from `setindex!` in that it will always change the transform of the variable -to be the default vectorisation transform. This undoes any possible linking. - -# Examples - -```jldoctest varnamedvector-reset -julia> using DynamicPPL: VarNamedVector, @varname, reset! - -julia> vnv = VarNamedVector{VarName,Any,Any}(); - -julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); - -julia> setindex!(vnv, 2.0, @varname(x)) -ERROR: An error occurred while assigning the value 2.0 to variable x. If you are changing the type or size of a variable you'll need to call reset! -[...] - -julia> reset!(vnv, 2.0, @varname(x)); - -julia> vnv[@varname(x)] -2.0 -``` -""" -function reset!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - retval = setindex_internal!(vnv, tovec(val), vn, f) - set_transformed!(vnv, false, vn) - return retval -end - -""" - update!(vnv::VarNamedVector, val, vn::VarName) - -Update the value of `vn` in `vnv` to `val`. - -Like `setindex!`, but errors if the key `vn` doesn't exist. -""" -function update!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. 
" * - "If you are changing the type or size of a variable you'll need to call " * - "reset!", - ) - end - return setindex_internal!(vnv, internal_val, vn) -end - -""" - insert!(vnv::VarNamedVector, val, vn::VarName) - -Add a variable with given value to `vnv`. - -Like `setindex!`, but errors if the key `vn` already exists. -""" -function Base.insert!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!(vnv, val, vn) -end - -""" - push!(vnv::VarNamedVector, pair::Pair) - -Add a variable with given value to `vnv`. Pair should be a `VarName` and a value. -""" -function Base.push!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - # TODO(mhauru) Or should this rather call `reset!`? It would be more inline with what - # Dict does, but could also cause confusion. - return setindex!(vnv, val, vn) -end - -""" - setindex_internal!(vnv::VarNamedVector, val, i::Int) - -Sets the `i`th element of the internal storage vector, ignoring inactive entries. -""" -function setindex_internal!(vnv::VarNamedVector, val, i::Int) - return vnv.vals[index_to_vals_index(vnv, i)] = val -end - -""" - setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform]) - -Like `setindex!`, but sets the values as they are stored internally in `vnv`. - -Optionally can set the transformation, such that `transform(val)` is the original value of -the variable. By default, the transform is the identity if creating a new entry in `vnv`, or -the existing transform if updating an existing entry. -""" -function setindex_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!(vnv, val, vn, transform) - else - return insert_internal!(vnv, val, vn, transform) - end -end - -""" - insert_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform]) - -Add a variable with given value to `vnv`. 
- -Like `setindex_internal!`, but errors if the key `vn` already exists. - -`transform` should be a function that converts `val` to the original representation. By -default it's `identity`. -""" -function insert_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) - # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying - # storage. - r_new = nextrange(vnv, val) - vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 - push!(vnv.varnames, vn) - push!(vnv.ranges, r_new) - append!(vnv.vals, val) - push!(vnv.transforms, transform) - push!(vnv.is_unconstrained, false) - return nothing -end - -""" - update_internal!(vnv::VarNamedVector, vn::VarName, val::AbstractVector[, transform]) - -Update an existing entry for `vn` in `vnv` with the value `val`. - -Like `setindex_internal!`, but errors if the key `vn` doesn't exist. - -`transform` should be a function that converts `val` to the original representation. By -default it's the same as the old transform for `vn`. -""" -function update_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - # Here we update an existing entry. - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - idx = getidx(vnv, vn) - # Extract the old range. - r_old = getrange(vnv, idx) - start_old, end_old = first(r_old), last(r_old) - n_old = length(r_old) - # Compute the new range. - n_new = length(val) - start_new = start_old - end_new = start_old + n_new - 1 - r_new = start_new:end_new - - #= - Suppose we currently have the following: - - | x | x | o | o | o | y | y | y | <- Current entries - - where 'O' denotes an inactive entry, and we're going to - update the variable `x` to be of size `k` instead of 2. - - We then have a few different scenarios: - 1. 
`k > 5`: All inactive entries become active + need to shift `y` to the right. - E.g. if `k = 7`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | x | x | y | y | y | <- New entries - - 2. `k = 5`: All inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | y | y | y | <- New entries - - 3. `k < 5`: Some inactive entries become active, some remain inactive. - E.g. if `k = 3`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | o | o | y | y | y | <- New entries - - 4. `k = 2`: No inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | o | o | o | y | y | y | <- New entries - - 5. `k < 2`: More entries become inactive. - E.g. if `k = 1`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | o | o | o | o | y | y | y | <- New entries - =# - - # Compute the allocated space for `vn`. - had_inactive = haskey(vnv.num_inactive, idx) - n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old - - if n_new > n_allocated - # Then we need to grow the underlying vector. - n_extra = n_new - n_allocated - # Allocate. - resize!(vnv.vals, length(vnv.vals) + n_extra) - # Shift current values. - shift_right!(vnv.vals, end_old + 1, n_extra) - # No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - # Update the ranges for all variables after this one. - shift_subsequent_ranges_by!(vnv, idx, n_extra) - elseif n_new == n_allocated - # => No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - else - # `n_new < n_allocated` - # => Need to update the number of inactive entries. - vnv.num_inactive[idx] = n_allocated - n_new - end - - # Update the range for this variable. - vnv.ranges[idx] = r_new - # Update the value. - vnv.vals[r_new] = val - if transform !== nothing - # Update the transform. 
- vnv.transforms[idx] = transform - end - - # TODO: Should we maybe sweep over inactive ranges and re-contiguify - # if the total number of inactive elements is "large" in some sense? - - return nothing -end - -function Base.push!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!(vnv, tovec(val), vn, f) -end - -function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!!(vnv, tovec(val), vn, f) -end - -# BangBang versions of the above functions. -# The only difference is that update_internal!! and insert_internal!! check whether the -# container types of the VarNamedVector vector need to be expanded to accommodate the new -# values. If so, they create a new instance, otherwise they mutate in place. All the others -# functions, e.g. setindex!!, setindex_internal!!, etc., are carbon copies of the ! versions -# with every ! call replaced with a !! call. - -""" - loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) - -Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. - -If `KNew` is a subtype of `K` and `TransNew` is a subtype of the element type of the -`TTrans` then this is a no-op and `vnv` is returned as is. Otherwise a new `VarNamedVector` -is returned with the same data but more abstract types, so that variables of type `KNew` and -transformations of type `TransNew` can be pushed to it. Some of the underlying storage is -shared between `vnv` and the return value, and thus mutating one may affect the other. - -# See also -[`tighten_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-loosen-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
- -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> y_trans(x) = reshape(x, (2, 2)); - -julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) -ERROR: MethodError: Cannot `convert` an object of type -[...] - -julia> vnv_loose = DynamicPPL.loosen_types!!( - vnv, typeof(@varname(y)), Float64, typeof(y_trans) - ); - -julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) - -julia> vnv_loose[@varname(y)] -2×2 Matrix{Float64}: - 1.0 3.0 - 2.0 4.0 -``` -""" -function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} -) where {KNew,VNew,TNew} - K = eltype(vnv.varnames) - V = eltype(vnv.vals) - T = eltype(vnv.transforms) - if KNew <: K && VNew <: V && TNew <: T - return vnv - else - # We could use promote_type here, instead of typejoin. However, that would e.g. - # cause Ints to be converted to Float64s, since - # promote_type(Int, Float64) == Float64, which can cause problems. See - # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. - # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing - # and Missing, rather than falling back on Any. However, it's not exported. - vn_type = typejoin(K, KNew) - val_type = typejoin(V, VNew) - transform_type = typejoin(T, TNew) - # This function would work the same way if the first if statement a few lines above - # was skipped, and we only checked for the below condition. However, the first one - # is constant propagated away at compile time (at least on Julia v1.11.7), whereas - # this one isn't. Hence we keep both for performance. 
- return if vn_type == K && val_type == V && transform_type == T - vnv - elseif isempty(vnv) - VarNamedVector( - Dict{vn_type,Int}(), - Vector{vn_type}(), - UnitRange{Int}[], - Vector{val_type}(), - Vector{transform_type}(), - BitVector(), - Dict{Int,Int}(); - check_consistency=false, - ) - else - # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but - # then here always revert to Vector. - VarNamedVector( - Dict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - Vector{val_type}(vnv.vals), - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end - end -end - -""" - tighten_types!!(vnv::VarNamedVector) - -Return a `VarNamedVector` like `vnv` with the most concrete types possible. - -This function either returns `vnv` itself or new `VarNamedVector` with the same values in -it, but with the element types of various containers made as concrete as possible. - -For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the -transforms are actually identity transformations, this function will return a new -`VarNamedVector` with the transforms vector having eltype `typeof(identity)`. - -This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the -return value may share some of its underlying storage with `vnv`, and thus mutating one may -affect the other. - -# See also -[`loosen_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-tighten-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
- -julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); - -julia> vnv = delete!(vnv, @varname(y)); - -julia> eltype(vnv) -Real - -julia> vnv.transforms -1-element Vector{Any}: - identity (generic function with 1 method) - -julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); - -julia> eltype(vnv_tight) == Int -true - -julia> vnv_tight.transforms -1-element Vector{typeof(identity)}: - identity (generic function with 1 method) -``` -""" -function tighten_types!!(vnv::VarNamedVector) - return if is_tightly_typed(vnv) - # There can not be anything to tighten, so short-circuit. - vnv - elseif isempty(vnv) - VarNamedVector() - else - VarNamedVector( - Dict(vnv.varname_to_index...), - [x for x in vnv.varnames], - vnv.ranges, - [x for x in vnv.vals], - [x for x in vnv.transforms], - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end -end - -function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!!(vnv, val, vn) - else - return insert!!(vnv, val, vn) - end -end - -function reset!!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - vnv = setindex_internal!!(vnv, tovec(val), vn, f) - vnv = set_transformed!!(vnv, false, vn) - return vnv -end - -function update!!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. 
" * - "If you are changing the type or size of a variable you'll need to either " * - "`delete!` it first or use `setindex_internal!`", - ) - end - return setindex_internal!!(vnv, internal_val, vn) -end - -function insert!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!!(vnv, val, vn) -end - -function setindex_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!!(vnv, val, vn, transform) - else - return insert_internal!!(vnv, val, vn, transform) - end -end - -function insert_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) - insert_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function update_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) - update_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function BangBang.push!!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - return setindex!!(vnv, val, vn) -end - -function Base.empty!(vnv::VarNamedVector) - # TODO: Or should the semantics be different, e.g. keeping `varnames`? - empty!(vnv.varname_to_index) - empty!(vnv.varnames) - empty!(vnv.ranges) - empty!(vnv.vals) - empty!(vnv.transforms) - empty!(vnv.is_unconstrained) - empty!(vnv.num_inactive) - return nothing -end -BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv) - -""" - replace_raw_storage(vnv::VarNamedVector, vals::AbstractVector) - -Replace the values in `vnv` with `vals`, as they are stored internally. 
- -This is useful when we want to update the entire underlying vector of values in one go or if -we want to change the how the values are stored, e.g. alter the `eltype`. - -!!! warning - This replaces the raw underlying values, and so care should be taken when using this - function. For example, if `vnv` has any inactive entries, then the provided `vals` - should also contain the inactive entries to avoid unexpected behavior. - -# Examples - -```jldoctest varnamedvector-replace-raw-storage -julia> using DynamicPPL: VarNamedVector, replace_raw_storage - -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> replace_raw_storage(vnv, [2.0])[@varname(x)] == [2.0] -true -``` - -This is also useful when we want to differentiate wrt. the values using automatic -differentiation, e.g. ForwardDiff.jl. - -```jldoctest varnamedvector-replace-raw-storage -julia> using ForwardDiff: ForwardDiff - -julia> f(x) = sum(abs2, replace_raw_storage(vnv, x)[@varname(x)]) -f (generic function with 1 method) - -julia> ForwardDiff.gradient(f, [1.0]) -1-element Vector{Float64}: - 2.0 -``` -""" -replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals - -vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) - -""" - unflatten(vnv::VarNamedVector, vals::AbstractVector) - -Return a new instance of `vnv` with the values of `vals` assigned to the variables. - -This assumes that `vals` have been transformed by the same transformations that that the -values in `vnv` have been transformed by. However, unlike [`replace_raw_storage`](@ref), -`unflatten` does account for inactive entries in `vnv`, so that the user does not have to -care about them. - -This is in a sense the reverse operation of `vnv[:]`. - -The return value may share memory with the input `vnv`, and thus one can not be mutated -safely without affecting the other. - -Unflatten recontiguifies the internal storage, getting rid of any inactive entries. 
-
-# Examples
-
-```jldoctest varnamedvector-unflatten
-julia> using DynamicPPL: VarNamedVector, unflatten
-
-julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]);
-
-julia> unflatten(vnv, vnv[:]) == vnv
-true
-```
-"""
-function unflatten(vnv::VarNamedVector, vals::AbstractVector)
-    if length(vals) != vector_length(vnv)
-        throw(
-            ArgumentError(
-                "Length of `vals` ($(length(vals))) does not match the length of " *
-                "`vnv` ($(vector_length(vnv))).",
-            ),
-        )
-    end
-    new_ranges = vnv.ranges
-    num_inactive = vnv.num_inactive
-    if has_inactive(vnv)
-        new_ranges = recontiguify_ranges!(new_ranges)
-        num_inactive = Dict{Int,Int}()
-    end
-    return VarNamedVector(
-        vnv.varname_to_index,
-        vnv.varnames,
-        new_ranges,
-        vals,
-        vnv.transforms,
-        vnv.is_unconstrained,
-        num_inactive;
-        check_consistency=false,
-    )
-end
-
-function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
-    # Return early if possible.
-    isempty(left_vnv) && return deepcopy(right_vnv)
-    isempty(right_vnv) && return deepcopy(left_vnv)
-
-    # Determine varnames.
-    vns_left = left_vnv.varnames
-    vns_right = right_vnv.varnames
-    vns_both = union(vns_left, vns_right)
-
-    # Check that varnames do not subsume each other.
-    for vn_left in vns_left
-        for vn_right in vns_right
-            vn_left == vn_right && continue
-            # TODO(mhauru) Subsumation doesn't actually need to be a showstopper. For
-            # instance, if right has a value for `x` and left has a value for `x[1]`, then
-            # right will take precedence anyway, and we could merge. However, that requires
-            # some extra logic that hasn't been done yet.
-            if subsumes(vn_left, vn_right)
-                throw(
-                    ArgumentError(
-                        "Cannot merge VarNamedVectors: variable name $vn_left " *
-                        "subsumes $vn_right.",
-                    ),
-                )
-            elseif subsumes(vn_right, vn_left)
-                throw(
-                    ArgumentError(
-                        "Cannot merge VarNamedVectors: variable name $vn_right " *
-                        "subsumes $vn_left.",
-                    ),
-                )
-            end
-        end
-    end
-
-    # Determine `eltype` of `vals`.
- T_left = eltype(left_vnv.vals) - T_right = eltype(right_vnv.vals) - T = typejoin(T_left, T_right) - - # Determine `eltype` of `varnames`. - V_left = eltype(left_vnv.varnames) - V_right = eltype(right_vnv.varnames) - V = typejoin(V_left, V_right) - if !(V <: VarName) - V = VarName - end - - # Determine `eltype` of `transforms`. - F_left = eltype(left_vnv.transforms) - F_right = eltype(right_vnv.transforms) - F = typejoin(F_left, F_right) - - # Allocate. - varname_to_index = Dict{V,Int}() - ranges = UnitRange{Int}[] - vals = T[] - transforms = F[] - is_unconstrained = BitVector(undef, length(vns_both)) - - # Range offset. - offset = 0 - - for (idx, vn) in enumerate(vns_both) - varname_to_index[vn] = idx - # Extract the necessary information from `left` or `right`. - if vn in vns_left && !(vn in vns_right) - # `vn` is only in `left`. - val = getindex_internal(left_vnv, vn) - f = gettransform(left_vnv, vn) - is_unconstrained[idx] = is_transformed(left_vnv, vn) - else - # `vn` is either in both or just `right`. - # Note that in a `merge` the right value has precedence. - val = getindex_internal(right_vnv, vn) - f = gettransform(right_vnv, vn) - is_unconstrained[idx] = is_transformed(right_vnv, vn) - end - n = length(val) - r = (offset + 1):(offset + n) - # Update. - append!(vals, val) - push!(ranges, r) - push!(transforms, f) - # Increment `offset`. - offset += n - end - - return VarNamedVector( - varname_to_index, - vns_both, - ranges, - vals, - transforms, - is_unconstrained; - check_consistency=false, - ) -end - -""" - subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName}) - -Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`. - -Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning -that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. - -Preserves the order of variables in `vnv`. 
-
-# Examples
-
-```jldoctest varnamedvector-subset
-julia> using DynamicPPL: VarNamedVector, @varname, subset
-
-julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]);
-
-julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0])
-true
-
-julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0])
-true
-```
-"""
-function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName})
-    vnv_new = similar(vnv)
-    # Return early if possible.
-    isempty(vnv) && return vnv_new
-
-    for vn in vnv.varnames
-        if any(subsumes(vn_given, vn) for vn_given in vns_given)
-            insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn))
-            set_transformed!(vnv_new, is_transformed(vnv, vn), vn)
-        end
-    end
-
-    return tighten_types!!(vnv_new)
-end
-
-"""
-    similar(vnv::VarNamedVector)
-
-Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values.
-
-In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will
-be entirely empty, rather than have `undef` values in it.
-
-# Examples
-
-```julia-doctest-varnamedvector-similar
-julia> using DynamicPPL: VarNamedVector, @varname, similar
-
-julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(x[3]) => [3.0]);
-
-julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}()
-true
-```
-"""
-function Base.similar(vnv::VarNamedVector)
-    # NOTE: Whether or not we should empty the underlying containers or not
-    # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will
-    # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`,
-    # will result in non-empty vectors but with entries as `undef`. But it's
-    # much easier to write the rest of the code assuming that `undef` is not
-    # present, and so for now we empty the underlying containers, thus differing
-    # from the behavior of `similar` for `AbstractArray`s.
- return VarNamedVector( - empty(vnv.varname_to_index), - similar(vnv.varnames, 0), - similar(vnv.ranges, 0), - similar(vnv.vals, 0), - similar(vnv.transforms, 0), - BitVector(), - empty(vnv.num_inactive); - check_consistency=false, - ) -end - -""" - is_contiguous(vnv::VarNamedVector) - -Returns `true` if the underlying data of `vnv` is stored in a contiguous array. - -This is equivalent to negating [`has_inactive(vnv)`](@ref). -""" -is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv) - -""" - nextrange(vnv::VarNamedVector, x) - -Return the range of `length(x)` from the end of current data in `vnv`. -""" -function nextrange(vnv::VarNamedVector, x) - offset = length(vnv.vals) - return (offset + 1):(offset + length(x)) -end - -# TODO(mhauru) Might add another specialisation to _compose_no_identity, where if -# ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only -# the latter one would be kept. -""" - _compose_no_identity(f, g) - -Like `f ∘ g`, but if `f` or `g` is `identity` it is omitted. - -This helps avoid trivial cases of `ComposedFunction` that would cause unnecessary type -conflicts. -""" -_compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity - -""" - shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - -Shifts the elements of `x` starting from index `start` by `n` to the right. -""" -function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - x[(start + n):end] = x[start:(end - n)] - return x -end - -""" - shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - -Shifts the ranges of variables in `vnv` starting from index `idx` by `n`. -""" -function shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - for i in (idx + 1):length(vnv.ranges) - vnv.ranges[i] = vnv.ranges[i] .+ n - end - return nothing -end - -# set!! 
is the function defined in utils.jl that tries to do fancy stuff with optics when -# setting the value of a generic container using a VarName. We can bypass all that because -# VarNamedVector handles VarNames natively. However, its semantics are slightly different -# from setindex!'s: It allows resetting variables that already have a value with values of -# a different type/size. -set!!(vnv::VarNamedVector, vn::VarName, val) = reset!!(vnv, val, vn) - -function setval!(vnv::VarNamedVector, val, vn::VarName) - return setindex_internal!(vnv, tovec(val), vn) -end - -function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) - offset = 0 - for i in 1:length(ranges) - r_old = ranges[i] - ranges[i] = (offset + 1):(offset + length(r_old)) - offset += length(r_old) - end - - return ranges -end - -""" - contiguify!(vnv::VarNamedVector) - -Re-contiguify the underlying vector and shrink if possible. - -# Examples - -```jldoctest varnamedvector-contiguify -julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]); - -julia> update!(vnv, [23.0, 24.0], @varname(x)); - -julia> has_inactive(vnv) -true - -julia> length(vnv.vals) -4 - -julia> contiguify!(vnv); - -julia> has_inactive(vnv) -false - -julia> length(vnv.vals) -3 - -julia> vnv[@varname(x)] # All the values are still there. -2-element Vector{Float64}: - 23.0 - 24.0 -``` -""" -function contiguify!(vnv::VarNamedVector) - if !has_inactive(vnv) - return vnv - end - # Extract the re-contiguified values. - # NOTE: We need to do this before we update the ranges. - old_vals = copy(vnv.vals) - old_ranges = copy(vnv.ranges) - # And then we re-contiguify the ranges. - recontiguify_ranges!(vnv.ranges) - # Clear the inactive ranges. - empty!(vnv.num_inactive) - # Now we update the values. 
- for (old_range, new_range) in zip(old_ranges, vnv.ranges) - vnv.vals[new_range] = old_vals[old_range] - end - # And (potentially) shrink the underlying vector. - resize!(vnv.vals, vnv.ranges[end][end]) - # The rest should be left as is. - return vnv -end - -""" - group_by_symbol(vnv::VarNamedVector) - -Return a dictionary mapping symbols to `VarNamedVector`s with varnames containing that -symbol. - -# Examples - -```jldoctest varnamedvector-group-by-symbol -julia> using DynamicPPL: VarNamedVector, @varname, group_by_symbol - -julia> vnv = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0], @varname(x[1]) => [3.0]); - -julia> d = group_by_symbol(vnv); - -julia> collect(keys(d)) -2-element Vector{Symbol}: - :x - :y - -julia> d[:x] == VarNamedVector(@varname(x) => [1.0], @varname(x[1]) => [3.0]) -true - -julia> d[:y] == VarNamedVector(@varname(y) => [2.0]) -true -``` -""" -function group_by_symbol(vnv::VarNamedVector) - symbols = unique(map(getsym, vnv.varnames)) - nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) - return OrderedDict(zip(symbols, nt_vals)) -end - -""" - shift_index_left!(vnv::VarNamedVector, idx::Int) - -Shift the index `idx` to the left by one and update the relevant fields. - -This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a -helper function for [`shift_subsequent_indices_left!`](@ref). - -!!! warning - This does not check if the index we're shifting to is already occupied. -""" -function shift_index_left!(vnv::VarNamedVector, idx::Int) - # Shift the index in the lookup table. - vn = vnv.varnames[idx] - vnv.varname_to_index[vn] = idx - 1 - # Shift the index in the inactive ranges. - if haskey(vnv.num_inactive, idx) - # Done in increasing order => don't need to worry about - # potentially shifting the same index twice. 
- -vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx) - end -end - -""" - shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) - -Shift the indices for all variables after `idx` to the left by one and update the relevant -fields. - -This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a -helper function for [`delete!`](@ref). -""" -function shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) - # Shift the indices for all variables after `idx`. - for idx_to_shift in (idx + 1):length(vnv.varnames) - shift_index_left!(vnv, idx_to_shift) - end -end - -function Base.delete!(vnv::VarNamedVector, vn::VarName) - # Error if we don't have the variable. - !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist")) - - # Get the index of the variable. - idx = getidx(vnv, vn) - - # Delete the values. - r_start = first(getrange(vnv, idx)) - n_allocated = num_allocated(vnv, idx) - # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that. - deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1)) - - # Delete `vn` from the lookup table. - delete!(vnv.varname_to_index, vn) - - # Delete any inactive ranges corresponding to `vn`. - haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx) - - # Re-adjust the indices for varnames occurring after `vn` so - # that they point to the correct indices after the deletions below. - shift_subsequent_indices_left!(vnv, idx) - - # Re-adjust the ranges for varnames occurring after `vn`. - shift_subsequent_ranges_by!(vnv, idx, -n_allocated) - - # Delete references from vector fields, thus shifting the indices of - # varnames occurring after `vn` by one to the left, as we adjusted for above. - deleteat!(vnv.varnames, idx) - deleteat!(vnv.ranges, idx) - deleteat!(vnv.transforms, idx) - - return vnv -end - -""" - delete!!(vnv::VarNamedVector, vn::VarName) - -Like `delete!`, but tightens the element types of the returned `VarNamedVector`. 
- -# See also: -[`tighten_types!!`](@ref) -""" -BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn)) - -""" - values_as(vnv::VarNamedVector[, T]) - -Return the values/realizations in `vnv` as type `T`, if implemented. - -If no type `T` is provided, return values as stored in `vnv`. - -# Examples - -```jldoctest -julia> using DynamicPPL: VarNamedVector - -julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]); - -julia> values_as(vnv) == [1.0, 2.0] -true - -julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) -true - -julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) -true - -julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) -true -``` -""" -values_as(vnv::VarNamedVector) = values_as(vnv, Vector) -values_as(vnv::VarNamedVector, ::Type{Vector}) = getindex_internal(vnv, :) -function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T} - return convert(Vector{T}, values_as(vnv, Vector)) -end -function values_as(vnv::VarNamedVector, ::Type{NamedTuple}) - return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) -end -function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(pairs(vnv)) -end - -# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how -# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl. - -# TODO(mhauru) This is tricky to implement in the general case, and the below implementation -# only covers some simple cases. It's probably sufficient in most situations though. -function hasvalue(vnv::VarNamedVector, vn::VarName) - haskey(vnv, vn) && return true - any(subsumes(vn, k) for k in keys(vnv)) && return true - # Handle the easy case where the right symbol isn't even present. 
- !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false - - optic = getoptic(vn) - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - # If vn is of the form @varname(somesymbol[someindex]), we check whether we store - # @varname(somesymbol) and can index into it with someindex. If instead we have a - # composed optic with the last part being an index lens, we do a similar check but - # stripping out the last index lens part. If these pass, the answer is definitely - # "yes". If not, we still don't know for sure. - # TODO(mhauru) What about cases where vnv stores both @varname(x) and - # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently - # aren't. - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) - else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) - if haskey(vnv, parent_varname) - valvec = getindex(vnv, parent_varname) - return canview(head, valvec) - end - end - throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)")) -end - -# TODO(mhauru) Like hasvalue, this is only partially implemented. -function getvalue(vnv::VarNamedVector, vn::VarName) - !hasvalue(vnv, vn) && throw(KeyError(vn)) - haskey(vnv, vn) && return getindex(vnv, vn) - - subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv)) - if length(subsumed_keys) > 0 - # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them? - return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys) - end - - optic = getoptic(vn) - # See hasvalue for some comments on the logic of this if block. - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) 
- else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) - valvec = getindex(vnv, parent_varname) - return head(valvec) - end - throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)")) -end - -Base.get(vnv::VarNamedVector, vn::VarName) = getvalue(vnv, vn) diff --git a/test/Project.toml b/test/Project.toml index 927954ba4..5b26b0aef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" @@ -34,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5.10" -AbstractPPL = "0.13" +AbstractPPL = "0.14" Accessors = "0.1" Aqua = "0.8" BangBang = "0.4" @@ -46,7 +45,6 @@ Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" ForwardDiff = "0.10.12, 1" -JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" MCMCChains = "7.2.1" MacroTools = "0.5.6" diff --git a/test/chains.jl b/test/chains.jl index 36c274b62..d69d2d4ca 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -68,8 +68,16 @@ end @testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - unlinked_vi = VarInfo(m) + if m.f === DynamicPPL.TestUtils.demo_static_transformation + # TODO(mhauru) These tests are broken for demo_static_transformation because + # vi[vn] doesn't know which transform it should apply to the internally stored + # value. This requires a rethink, either of StaticTransformation or of what the + # comparison in this test should be. 
+ @test false broken = true + continue + end @testset "$islinked" for islinked in (false, true) + unlinked_vi = VarInfo(m) vi = if islinked DynamicPPL.link!!(unlinked_vi, m) else @@ -82,7 +90,8 @@ end ps = ParamsWithStats(params, ldf) # Check that length of parameters is as expected - @test length(ps.params) == length(keys(vi)) + expected_length = sum(prod ∘ DynamicPPL.varnamesize, keys(vi)) + @test length(ps.params) == expected_length # Iterate over all variables to check that their values match for vn in keys(vi) diff --git a/test/compiler.jl b/test/compiler.jl index 62b6b9b2b..253b32990 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -236,9 +236,9 @@ module Issue537 end # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) # Non-array variables @model function testmodel_nonarray(x, y) @@ -341,21 +341,21 @@ module Issue537 end end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) - @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) + @model f2() = x ~ NamedDist(Normal(), @varname(y[2][5, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test first(Base.keys(vi1.metadata.y)) == @varname(y) - @test haskey(vi2.metadata, :y) - @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) - @test haskey(vi3.metadata, :y) - @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) + @test haskey(vi1, @varname(y)) + @test first(Base.keys(vi1)) == @varname(y) + @test haskey(vi2, @varname(y[2][5, 1])) + @test first(Base.keys(vi2)) == @varname(y[2][5, 1]) + @test haskey(vi3, @varname(y[1])) + @test first(Base.keys(vi3)) == @varname(y[1]) # Conditioning f1_c = f1() | (y=1,) - f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 
1])) => 1,)) + f2_c = f2() | NamedTuple((Symbol(@varname(y[2][5, 1])) => 1,)) f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,)) @test f1_c() == 1 # TODO(torfjelde): We need conditioning for `Dict`. @@ -604,9 +604,9 @@ module Issue537 end # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) - @test svi == SimpleVarInfo() - @test retval == svi + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) + @test vi == VarInfo() + @test retval == vi # We should not be altering return-values other than at top-level. @model function demo() @@ -615,11 +615,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index cdd32f379..5c16c8dd5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, Accessors +using Test, DynamicPPL using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, @@ -430,18 +430,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - empty_varinfos = [ - ("untyped+metadata", VarInfo()), - ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), - ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), - ( - "typed+VNV", - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), - ), - ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())), - ] - @model function test_init_model() x ~ Normal() y ~ MvNormal(fill(x, 2), I) @@ -454,19 
+442,17 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Check that init!! can generate values that weren't there # previously. model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - this_vi = deepcopy(empty_vi) - _, vi = DynamicPPL.init!!(model, this_vi, strategy) - @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) - x, y = vi[@varname(x)], vi[@varname(y)] - @test x isa Real - @test y isa AbstractVector{<:Real} - @test length(y) == 2 - (; logprior, loglikelihood) = getlogp(vi) - @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == - logprior - @test logpdf(Normal(), 1.0) == loglikelihood - end + empty_vi = VarInfo() + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == logprior + @test logpdf(Normal(), 1.0) == loglikelihood end end @@ -474,40 +460,40 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "replacing old values: $(typeof(strategy))" begin # Check that init!! can overwrite values that were already there. 
model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - # start by generating some rubbish values - vi = deepcopy(empty_vi) - old_x, old_y = 100000.00, [300000.00, 500000.00] - push!!(vi, @varname(x), old_x, Normal()) - push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) - # then overwrite it - _, new_vi = DynamicPPL.init!!(model, vi, strategy) - new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] - # check that the values are (presumably) different - @test old_x != new_x - @test old_y != new_y - end + empty_vi = VarInfo() + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + vi, _ = DynamicPPL.setindex_with_dist!!(vi, old_x, Normal(), @varname(x)) + vi, _ = DynamicPPL.setindex_with_dist!!( + vi, old_y, MvNormal(fill(old_x, 2), I), @varname(y) + ) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y end end function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi1 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi2 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi3 = DynamicPPL.init!!( - Xoshiro(469), model, deepcopy(empty_vi), strategy - ) - @test vi1[@varname(x)] == vi2[@varname(x)] - @test vi1[@varname(y)] == vi2[@varname(y)] - @test vi1[@varname(x)] != vi3[@varname(x)] - @test vi1[@varname(y)] != vi3[@varname(y)] - end + empty_vi = VarInfo() + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( 
+ Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] end end @@ -594,21 +580,20 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x, y=my_y) params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_nt = getlogp(vi) - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_dict = getlogp(vi) - @test logp_nt == logp_dict - end + empty_vi = VarInfo() + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict end @testset "given only partial parameters" begin @@ -616,56 +601,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x) params_dict = Dict(@varname(x) => my_x) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - @testset "with InitFromPrior fallback" begin - _, vi = DynamicPPL.init!!( - Xoshiro(468), - model, - deepcopy(empty_vi), - InitFromParams(params_nt, InitFromPrior()), - ) - @test vi[@varname(x)] == my_x - nt_y = vi[@varname(y)] - @test nt_y isa AbstractVector{<:Real} - @test length(nt_y) == 2 - _, vi = DynamicPPL.init!!( - Xoshiro(469), - model, - deepcopy(empty_vi), - 
InitFromParams(params_dict, InitFromPrior()), - ) - @test vi[@varname(x)] == my_x - dict_y = vi[@varname(y)] - @test dict_y isa AbstractVector{<:Real} - @test length(dict_y) == 2 - # the values should be different since we used different seeds - @test dict_y != nt_y - end + empty_vi = VarInfo() + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end - @testset "with no fallback" begin - # These just don't have an entry for `y`. - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) - ) - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) - ) - # We also explicitly test the case where `y = missing`. - params_nt_missing = (; x=my_x, y=missing) - params_dict_missing = Dict( - @varname(x) => my_x, @varname(y) => missing - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_nt_missing, nothing), - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_dict_missing, nothing), - ) - end + @testset "with no fallback" begin + # These just don't have an entry for `y`. 
+ @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. + params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict(@varname(x) => my_x, @varname(y) => missing) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) end end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..343282480 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -67,60 +67,6 @@ model = ModelOuterWorking2() @test check_model(model, VarInfo(model); error_on_failure=true) end - - @testset "subsumes (x then x[1])" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x ~ MvNormal(zeros(2), I) - x[1] ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x[1] then x)" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x[1] ~ Normal() - x ~ MvNormal(zeros(2), I) - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x.a then x)" begin - @model 
function buggy_subsumes_demo_model() - x = (a=nothing,) - x.a ~ Normal() - x ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - end @testset "NaN in data" begin diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl deleted file mode 100644 index e46c25113..000000000 --- a/test/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,113 +0,0 @@ -@testset "DynamicPPLJETExt.jl" begin - @testset "determine_suitable_varinfo" begin - @model function demo1() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end - model = demo1() - @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa - DynamicPPL.UntypedVarInfo - - @model demo2() = x ~ Normal() - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.NTVarInfo - - @model function demo3() - # Just making sure that nothing strange happens when type inference fails. - x = Vector(undef, 1) - x[1] ~ Bernoulli() - if x[1] - y ~ Normal() - else - z ~ Normal() - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa - DynamicPPL.UntypedVarInfo - - # Evaluation works (and it would even do so in practice), but sampling - # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. - @model function demo4() - x ~ Bernoulli() - if x - y ~ Normal() - else - y ~ Cauchy() # different distribution, but same transformation - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo - - # In this model, the type error occurs in the user code rather than in DynamicPPL. 
- @model function demo5() - x ~ Normal() - xs = Any[] - push!(xs, x) - # `sum(::Vector{Any})` can potentially error unless dynamic dispatch manages to resolve the - # correct `zero` method. As a result, this code will run, but JET will flag this as an issue. - return sum(xs) - end - # Should pass if we're only checking the tilde statements. - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.NTVarInfo - # Should fail if we're including errors in the model body. - @test DynamicPPL.Experimental.determine_suitable_varinfo( - demo5(); only_dppl=false - ) isa DynamicPPL.UntypedVarInfo - end - - @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS - if model.f === DynamicPPL.TestUtils.demo_lkjchol - # TODO(mhauru) - # The LKJCholesky model fails with JET. The problem is not with Turing but - # with Distributions, and ultimately in LinearAlgebra: - # julia> v = @view rand(2,2)[:,1]; - # - # julia> JET.@report_call norm(v) - # ═════ 2 possible errors found ═════ - # blahblah - # The below trivial call to @test is just marking that there's something - # broken here. - @test false broken = true - continue - end - # Use debug logging below. - varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.NTVarInfo - @test is_typed - # If the test failed, check what the type stability problem was for - # the typed varinfo. This is mostly useful for debugging from test - # logs. - if !is_typed - @info "Model `$(model.f)` is not type stable with typed varinfo." 
- typed_vi = DynamicPPL.typed_varinfo(model) - - @info "Evaluating with DefaultContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - - @info "Initialising with InitContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - end - end - end -end diff --git a/test/linking.jl b/test/linking.jl index 2047b9d11..bfd1285b1 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -89,7 +89,7 @@ end DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) # The non-internal logjoint should be the same since it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) + @test vi_linked[@varname(m)] == LowerTriangular(vi[@varname(m)]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @test length(vi_linked[:]) == length(y) @@ -100,7 +100,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == length(vi[:]) - @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + @test vi_invlinked[@varname(m)] ≈ LowerTriangular(vi[@varname(m)]) # The non-internal logjoint should still be the same, again since # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) @@ -121,7 +121,7 @@ end model, values_original, (@varname(x),) ) @testset "$(short_varinfo_name(vi))" for vi in vis - val = vi[@varname(x), dist] + val = vi[@varname(x)] # Ensure that `reconstruct` works as intended. 
         @test val isa Cholesky
         @test val.uplo == uplo
diff --git a/test/lkj.jl b/test/lkj.jl
index 5c5603aba..bab3ce185 100644
--- a/test/lkj.jl
+++ b/test/lkj.jl
@@ -37,7 +37,7 @@ end
         last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples
     ]
     corr_matrices = map(samples) do s
-        M = reshape(s.metadata.vals, (2, 2))
+        M = reshape(DynamicPPL.getindex_internal(s, @varname(x)), (2, 2))
        pd_from_triangular(M, uplo)
     end
     @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl
index ceec4d02a..8d5c464b9 100644
--- a/test/logdensityfunction.jl
+++ b/test/logdensityfunction.jl
@@ -17,35 +17,24 @@ using Mooncake: Mooncake
 
 @testset "LogDensityFunction: Correctness" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
-        @testset "$varinfo_func" for varinfo_func in [
-            DynamicPPL.untyped_varinfo,
-            DynamicPPL.typed_varinfo,
-            DynamicPPL.untyped_vector_varinfo,
-            DynamicPPL.typed_vector_varinfo,
-        ]
-            unlinked_vi = varinfo_func(m)
-            @testset "$islinked" for islinked in (false, true)
-                vi = if islinked
-                    DynamicPPL.link!!(unlinked_vi, m)
-                else
-                    unlinked_vi
-                end
-                nt_ranges, dict_ranges = DynamicPPL.get_ranges_and_linked(vi)
-                params = [x for x in vi[:]]
-                # Iterate over all variables
-                for vn in keys(vi)
-                    # Check that `getindex_internal` returns the same thing as using the ranges
-                    # directly
-                    range_with_linked = if AbstractPPL.getoptic(vn) === identity
-                        nt_ranges[AbstractPPL.getsym(vn)]
-                    else
-                        dict_ranges[vn]
-                    end
-                    @test params[range_with_linked.range] ==
-                        DynamicPPL.getindex_internal(vi, vn)
-                    # Check that the link status is correct
-                    @test range_with_linked.is_linked == islinked
-                end
+        @testset "$islinked" for islinked in (false, true)
+            unlinked_vi = DynamicPPL.VarInfo(m)
+            vi = if islinked
+                DynamicPPL.link!!(unlinked_vi, m)
+            else
+                unlinked_vi
+            end
+            ranges = DynamicPPL.get_ranges_and_linked(vi)
+            params = [x for x in vi[:]]
+            # Iterate over all variables
+            for vn in keys(vi)
+                # Check that `getindex_internal` returns the same thing as using the ranges
+                # directly
+                range_with_linked = ranges[vn]
+                @test params[range_with_linked.range] ==
+                    DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn))
+                # Check that the link status is correct
+                @test range_with_linked.is_linked == islinked
             end
         end
     end
@@ -159,8 +148,8 @@ end
 
 @testset "LogDensityFunction: Type stability" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
-        unlinked_vi = DynamicPPL.VarInfo(m)
         @testset "$islinked" for islinked in (false, true)
+            unlinked_vi = DynamicPPL.VarInfo(m)
             vi = if islinked
                 DynamicPPL.link!!(unlinked_vi, m)
             else
@@ -168,7 +157,12 @@ end
             end
             ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
             x = vi[:]
-            @inferred LogDensityProblems.logdensity(ldf, x)
+            # The below type inference fails on v1.10.
+            skip = (VERSION < v"1.11.0" && m.f === DynamicPPL.TestUtils.demo_nested_colons)
+            @test begin
+                @inferred LogDensityProblems.logdensity(ldf, x)
+                true
+            end skip = skip
         end
     end
 end
diff --git a/test/model.jl b/test/model.jl
index c878fd905..281eaaad4 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -25,10 +25,6 @@ function innermost_distribution_type(d::Distributions.Product)
     return dists[1]
 end
 
-is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false
-is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true
-is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
-
 const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
 
 @testset "model.jl" begin
@@ -58,6 +54,15 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
 
        #### logprior, logjoint, loglikelihood for MCMC chains ####
        @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
+            if model.f === DynamicPPL.TestUtils.demo_nested_colons
+                # TODO(mhauru) The below test fails on this model, due to the VarName
+                # s.params[1].subparams[:, 1, :], which AbstractPPL.varname_leaves splits
+                # into subvarnames like s.params[1].subparams[:, 1, :][1, 1], but the chain
+                # would know as s.params[1].subparams[1, 1, 1]. Unsure what the correct fix
+                # is, so leaving this for later.
+                @test false broken = true
+                continue
+            end
            N = 200
            chain = make_chain_from_prior(model, N)
            logpriors = logprior(model, chain)
@@ -213,7 +218,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval))
     end
 
-    @testset "Dynamic constraints, Metadata" begin
+    @testset "Dynamic constraints" begin
         model = DynamicPPL.TestUtils.demo_dynamic_constraint()
         vi = VarInfo(model)
         vi = link!!(vi, model)
@@ -221,24 +226,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
 
         for i in 1:10
             # Sample with large variations.
             r_raw = randn(length(vi[:])) * 10
-            vi = DynamicPPL.unflatten(vi, r_raw)
+            vi = DynamicPPL.unflatten!!(vi, r_raw)
             @test vi[@varname(m)] == r_raw[1]
             @test vi[@varname(x)] != r_raw[2]
             model(vi)
         end
     end
 
-    @testset "Dynamic constraints, VectorVarInfo" begin
-        model = DynamicPPL.TestUtils.demo_dynamic_constraint()
-        for i in 1:10
-            for vi_constructor in
-                [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo]
-                vi = vi_constructor(model)
-                @test vi[@varname(x)] >= vi[@varname(m)]
-            end
-        end
-    end
-
     @testset "rand" begin
         model = GDEMO_DEFAULT
@@ -314,7 +308,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         @test logjoint(model, x) !=
             DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
         # Ensure `varnames` is implemented.
-        vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}())))
+        vi = last(DynamicPPL.init!!(model, VarInfo()))
         @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model))
         # Ensure `posterior_mean` is implemented.
         @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
@@ -408,12 +402,17 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2)
         ]
         @testset "$(model.f)" for model in models_to_test
+            if model.f === DynamicPPL.TestUtils.demo_nested_colons && VERSION < v"1.11"
+                # On v1.10, the demo_nested_colons model, which uses a lot of
+                # NamedTuples, is badly type unstable. Not worth doing much about
+                # it, since it's fixed on later Julia versions, so just skipping
+                # these tests.
+                @test false skip = true
+                continue
+            end
             vns = DynamicPPL.TestUtils.varnames(model)
             example_values = DynamicPPL.TestUtils.rand_prior_true(model)
-            varinfos = filter(
-                is_type_stable_varinfo,
-                DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns),
-            )
+            varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
             @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
                 @test begin
                     @inferred(DynamicPPL.evaluate!!(model, varinfo))
@@ -433,6 +432,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
     @testset "values_as_in_model" begin
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
             vns = DynamicPPL.TestUtils.varnames(model)
+            vns_split = DynamicPPL.TestUtils.varnames_split(model)
             example_values = DynamicPPL.TestUtils.rand_prior_true(model)
             varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
             @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
@@ -442,7 +442,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
                 realizations = values_as_in_model(model, false, varinfo)
                 # Ensure that all variables are found.
                 vns_found = collect(keys(realizations))
-                @test vns ∩ vns_found == vns ∪ vns_found
+                @test vns_split ∩ vns_found == vns_split ∪ vns_found
                 # Ensure that the values are the same.
                 for vn in vns
                     @test realizations[vn] == varinfo[vn]
@@ -492,26 +492,17 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         end
 
         model = product_dirichlet()
-        varinfos = [
-            DynamicPPL.untyped_varinfo(model),
-            DynamicPPL.typed_varinfo(model),
-            DynamicPPL.typed_simple_varinfo(model),
-            DynamicPPL.untyped_simple_varinfo(model),
-        ]
-        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            logjoint = getlogjoint(varinfo) # unlinked space
-            varinfo_linked = DynamicPPL.link(varinfo, model)
-            varinfo_linked_result = last(
-                DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked))
-            )
-            # getlogjoint should return the same result as before it was linked
-            @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result)
-            @test getlogjoint(varinfo_linked) ≈ logjoint
-            # getlogjoint_internal shouldn't
-            @test getlogjoint_internal(varinfo_linked) ≈
-                getlogjoint_internal(varinfo_linked_result)
-            @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint)
-        end
+        varinfo = DynamicPPL.VarInfo(model)
+        logjoint = getlogjoint(varinfo) # unlinked space
+        varinfo_linked = DynamicPPL.link(varinfo, model)
+        varinfo_linked_result = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)))
+        # getlogjoint should return the same result as before it was linked
+        @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result)
+        @test getlogjoint(varinfo_linked) ≈ logjoint
+        # getlogjoint_internal shouldn't
+        @test getlogjoint_internal(varinfo_linked) ≈
+            getlogjoint_internal(varinfo_linked_result)
+        @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint)
     end
 
     @testset "predict" begin
diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl
index 780d45b46..fde807dda 100644
--- a/test/pointwise_logdensities.jl
+++ b/test/pointwise_logdensities.jl
@@ -5,7 +5,7 @@
         # Instantiate a `VarInfo` with the example values.
         vi = VarInfo(model)
         for vn in DynamicPPL.TestUtils.varnames(model)
-            vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn)
+            vi = DynamicPPL.setindex!!(vi, AbstractPPL.getvalue(example_values, vn), vn)
         end
 
         loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(
diff --git a/test/runtests.jl b/test/runtests.jl
index 9649aebbb..6521f1e4a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -28,8 +28,6 @@ using Test
 using Distributions
 using LinearAlgebra # Diagonal
 
-using JET: JET
-
 using Combinatorics: combinations
 using OrderedCollections: OrderedSet
@@ -53,9 +51,8 @@ include("test_util.jl")
         include("utils.jl")
         include("accumulators.jl")
         include("compiler.jl")
-        include("varnamedvector.jl")
+        include("varnamedtuple.jl")
         include("varinfo.jl")
-        include("simple_varinfo.jl")
         include("model.jl")
         include("distribution_wrappers.jl")
         include("linking.jl")
@@ -75,7 +72,6 @@ include("test_util.jl")
         include("logdensityfunction.jl")
         @testset "extensions" begin
             include("ext/DynamicPPLMCMCChainsExt.jl")
-            include("ext/DynamicPPLJETExt.jl")
             include("ext/DynamicPPLMarginalLogDensitiesExt.jl")
         end
         @testset "ad" begin
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
deleted file mode 100644
index 42e377440..000000000
--- a/test/simple_varinfo.jl
+++ /dev/null
@@ -1,337 +0,0 @@
-@testset "simple_varinfo.jl" begin
-    @testset "constructor & indexing" begin
-        @testset "NamedTuple" begin
-            svi = SimpleVarInfo(; m=1.0)
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(; m=[1.0])
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(; m=(a=[1.0],))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-
-            svi = SimpleVarInfo{Float32}(; m=1.0)
-            @test getlogjoint(svi) isa Float32
-
-            svi = SimpleVarInfo((m=1.0,))
-            svi = accloglikelihood!!(svi, 1.0)
-            @test getlogjoint(svi) == 1.0
-        end
-
-        @testset "Dict" begin
-            svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0]))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],)))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-
-            svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0]))
-            # Now we only have a variable `m.a` which is subsumed by `m`,
-            # but we can't guarantee that we have the "entire" `m`.
-            @test !haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-        end
-
-        @testset "VarNamedVector" begin
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test !haskey(svi, @varname(m[1]))
-
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0]))
-            @test getlogjoint(svi) == 0.0
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m[1]))
-            @test !haskey(svi, @varname(m[2]))
-            @test svi[@varname(m)][1] == svi[@varname(m[1])]
-
-            svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0]))
-            @test haskey(svi, @varname(m))
-            @test haskey(svi, @varname(m.a))
-            @test haskey(svi, @varname(m.a[1]))
-            @test !haskey(svi, @varname(m.a[2]))
-            @test !haskey(svi, @varname(m.a.b))
-            # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the
-            # next test is here to remind of us that.
-            svi = SimpleVarInfo(
-                push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0])
-            )
-            @test_broken !haskey(svi, @varname(m.a.b.c.d))
-        end
-    end
-
-    @testset "link!! & invlink!! on $(nameof(model))" for model in
-                                                          DynamicPPL.TestUtils.ALL_MODELS
-        values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
-        @testset "$name" for (name, vi) in (
-            ("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())),
-            ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)),
-            ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())),
-            ("TypedVarInfo", DynamicPPL.typed_varinfo(model)),
-        )
-            if name == "SVI{NamedTuple}" &&
-                model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints
-                # TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where
-                # a variable set with IndexLenses changes dimension under linking. This
-                # makes the link!! call crash. The below call to @test just marks the fact
-                # that there's something broken here.
-                @test false broken = true
-                continue
-            end
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn)
-            end
-            vi = last(DynamicPPL.evaluate!!(model, vi))
-
-            # Calculate ground truth
-            lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(
-                model, values_constrained...
-            )
-            _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                model, values_constrained...
-            )
-
-            # `link!!`
-            vi_linked = link!!(deepcopy(vi), model)
-            lp_unlinked = getlogjoint(vi_linked)
-            lp_linked = getlogjoint_internal(vi_linked)
-            @test lp_linked ≈ lp_linked_true
-            @test lp_unlinked ≈ lp_unlinked_true
-            @test logjoint(model, vi_linked) ≈ lp_unlinked
-
-            # `invlink!!`
-            vi_invlinked = invlink!!(deepcopy(vi_linked), model)
-            lp_unlinked = getlogjoint(vi_invlinked)
-            also_lp_unlinked = getlogjoint_internal(vi_invlinked)
-            @test lp_unlinked ≈ lp_unlinked_true
-            @test also_lp_unlinked ≈ lp_unlinked_true
-            @test logjoint(model, vi_invlinked) ≈ lp_unlinked
-
-            # Should result in same values.
-            @test all(
-                DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈
-                DynamicPPL.tovec(get(values_constrained, vn)) for
-                vn in DynamicPPL.TestUtils.varnames(model)
-            )
-        end
-    end
-
-    @testset "SimpleVarInfo on $(nameof(model))" for model in
-                                                     DynamicPPL.TestUtils.ALL_MODELS
-        # We might need to pre-allocate for the variable `m`, so we need
-        # to see whether this is the case.
-        svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
-        svi_dict = SimpleVarInfo(VarInfo(model), Dict)
-        vnv = DynamicPPL.VarNamedVector()
-        for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model))
-            vnv = push!!(vnv, VarName{k}() => v)
-        end
-        svi_vnv = SimpleVarInfo(vnv)
-
-        @testset "$name" for (name, svi) in (
-            ("NamedTuple", svi_nt),
-            ("Dict", svi_dict),
-            ("VarNamedVector", svi_vnv),
-            # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models.
-            # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true),
-            # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true),
-            # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true),
-        )
-            # Random seed is set in each `@testset`, so we need to sample
-            # a new realization for `m` here.
-            retval = model()
-
-            ### Sampling ###
-            # Sample a new varinfo!
-            _, svi_new = DynamicPPL.init!!(model, svi)
-
-            # Realization for `m` should be different wp. 1.
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                @test svi_new[vn] != get(retval, vn)
-            end
-
-            # Logjoint should be non-zero wp. 1.
-            @test getlogjoint(svi_new) != 0
-
-            ### Evaluation ###
-            values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
-            if DynamicPPL.is_transformed(svi)
-                _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian(
-                    model, values_eval_constrained...
-                )
-                values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                    model, values_eval_constrained...
-                )
-                # Make sure that these two computation paths provide the same
-                # transformed values.
-                @test values_eval == _values_prior
-            else
-                logpri_true = DynamicPPL.TestUtils.logprior_true(
-                    model, values_eval_constrained...
-                )
-                logπ_true = DynamicPPL.TestUtils.logjoint_true(
-                    model, values_eval_constrained...
-                )
-                values_eval = values_eval_constrained
-            end
-
-            # No logabsdet-jacobian correction needed for the likelihood.
-            loglik_true = DynamicPPL.TestUtils.loglikelihood_true(
-                model, values_eval_constrained...
-            )
-
-            # Update the realizations in `svi_new`.
-            svi_eval = svi_new
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn)
-            end
-
-            # Reset the logp accumulators.
-            svi_eval = DynamicPPL.resetaccs!!(svi_eval)
-
-            # Compute `logjoint` using the varinfo.
-            logπ = logjoint(model, svi_eval)
-            logpri = logprior(model, svi_eval)
-            loglik = loglikelihood(model, svi_eval)
-
-            # Values should not have changed.
-            for vn in DynamicPPL.TestUtils.varnames(model)
-                # TODO(mhauru) Workaround for
-                # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404
-                # Remove once the fix is all Julia versions we support.
-                val = get(values_eval, vn)
-                if val isa Cholesky
-                    @test svi_eval[vn].L == val.L
-                else
-                    @test svi_eval[vn] == val
-                end
-            end
-
-            # Compare log-probability computations.
-            @test logpri ≈ logpri_true
-            @test loglik ≈ loglik_true
-            @test logπ ≈ logπ_true
-        end
-    end
-
-    @testset "Dynamic constraints" begin
-        model = DynamicPPL.TestUtils.demo_dynamic_constraint()
-
-        # Initialize.
-        svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true)
-        svi_nt = last(DynamicPPL.init!!(model, svi_nt))
-        svi_vnv = DynamicPPL.set_transformed!!(
-            SimpleVarInfo(DynamicPPL.VarNamedVector()), true
-        )
-        svi_vnv = last(DynamicPPL.init!!(model, svi_vnv))
-
-        for svi in (svi_nt, svi_vnv)
-            # Sample with large variations in unconstrained space.
-            for i in 1:10
-                for vn in keys(svi)
-                    svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn)
-                end
-                retval, svi = DynamicPPL.evaluate!!(model, svi)
-                @test retval.m == svi[@varname(m)] # `m` is unconstrained
-                @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m`
-
-                retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                    model, retval.m, retval.x
-                )
-
-                # Realizations from model should all be equal to the unconstrained realization.
-                for vn in DynamicPPL.TestUtils.varnames(model)
-                    @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6
-                end
-
-                # `getlogp` should be equal to the logjoint with log-absdet-jac correction.
-                lp = getlogjoint_internal(svi)
-                # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375
-                @test lp ≈ lp_true atol = 1.2e-5
-            end
-        end
-    end
-
-    @testset "Static transformation" begin
-        model = DynamicPPL.TestUtils.demo_static_transformation()
-
-        varinfos = DynamicPPL.TestUtils.setup_varinfos(
-            model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)]
-        )
-        @testset "$(short_varinfo_name(vi))" for vi in varinfos
-            # Initialize varinfo and link.
-            vi_linked = DynamicPPL.link!!(vi, model)
-
-            # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`.
-            @test !DynamicPPL.is_transformed(
-                DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model)
-            )
-
-            # Resulting varinfo should no longer be transformed.
-            vi_result = last(DynamicPPL.init!!(model, deepcopy(vi)))
-            @test !DynamicPPL.is_transformed(vi_result)
-
-            # Set the values to something that is out of domain if we're in constrained space.
-            for vn in keys(vi)
-                vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn)
-            end
-
-            # NOTE: Evaluating a linked VarInfo, **specifically when the transformation
-            # is static**, will result in an invlinked VarInfo. This is because of
-            # `maybe_invlink_before_eval!`, which only invlinks if the transformation
-            # is static. (src/abstract_varinfo.jl)
-            retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked))
-
-            @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠
-                DynamicPPL.tovec(retval.s) # `s` is unconstrained in original
-            @test DynamicPPL.tovec(
-                DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s))
-            ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result
-
-            # `m` should not be transformed.
-            @test vi_linked[@varname(m)] == retval.m
-            @test vi_unlinked_again[@varname(m)] == retval.m
-
-            # Get ground truths
-            retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
-                model, retval.s, retval.m
-            )
-            lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m)
-
-            @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈
-                DynamicPPL.tovec(retval_unconstrained.s)
-            @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈
-                DynamicPPL.tovec(retval_unconstrained.m)
-
-            # The unlinked varinfo should hold the unlinked logp.
-            lp_unlinked = getlogjoint(vi_unlinked_again)
-            @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true
-        end
-    end
-end
diff --git a/test/test_util.jl b/test/test_util.jl
index 94fdbd744..8f402ad8f 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -16,28 +16,8 @@ Return string representing a short description of `vi`.
 function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo)
     return "threadsafe($(short_varinfo_name(vi.varinfo)))"
 end
-function short_varinfo_name(vi::DynamicPPL.NTVarInfo)
-    return if DynamicPPL.has_varnamedvector(vi)
-        "TypedVectorVarInfo"
-    else
-        "TypedVarInfo"
-    end
-end
-short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo"
-short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo"
-function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref})
-    return "SimpleVarInfo{<:NamedTuple,<:Ref}"
-end
-function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref})
-    return "SimpleVarInfo{<:OrderedDict,<:Ref}"
-end
-function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref})
-    return "SimpleVarInfo{<:VarNamedVector,<:Ref}"
-end
-short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
-short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
-function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
-    return "SimpleVarInfo{<:VarNamedVector}"
+function short_varinfo_name(::DynamicPPL.VarInfo)
+    return "VarInfo"
 end
 
 # convenient functions for testing model.jl
diff --git a/test/utils.jl b/test/utils.jl
index bef1c2ba8..bc01fc0ce 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -186,29 +186,6 @@ end
         t = (2.0, [3.0, 4.0])
         @test DynamicPPL.tovec(t) == [2.0, 3.0, 4.0]
     end
-
-    @testset "unique_syms" begin
-        vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]))
-        @inferred DynamicPPL.unique_syms(vns)
-        @inferred DynamicPPL.unique_syms(())
-        @test DynamicPPL.unique_syms(vns) == (:x, :y, :z)
-        @test DynamicPPL.unique_syms(()) == ()
-    end
-
-    @testset "group_varnames_by_symbol" begin
-        vns_tuple = (
-            @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])
-        )
-        vns_vec = collect(vns_tuple)
-        vns_nt = (;
-            x=[@varname(x), @varname(x.a)],
-            y=[@varname(y[1]), @varname(y[2])],
-            z=[@varname(z[15])],
-        )
-        vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])]
-        @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple)
-        @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt
-    end
 end
 end
diff --git a/test/varinfo.jl b/test/varinfo.jl
index bf5cfe561..323050165 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -1,23 +1,9 @@
 function check_varinfo_keys(varinfo, vns)
-    if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple}
-        # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`,
-        # since `keys(varinfo_merged)` only contains `VarName` with `identity`.
-        # So we just check that the original keys are present.
-        for vn in vns
-            # Should have all the original keys.
-            @test haskey(varinfo, vn)
-        end
-    else
-        vns_varinfo = keys(varinfo)
-        # Should be equivalent.
-        @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns)
-    end
+    vns_varinfo = keys(varinfo)
+    @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns)
 end
 
 function check_metadata_type_equal(v1::VarInfo, v2::VarInfo)
-    @test typeof(v1.metadata) == typeof(v2.metadata)
-end
-function check_metadata_type_equal(v1::SimpleVarInfo, v2::SimpleVarInfo)
     @test typeof(v1.values) == typeof(v2.values)
 end
 function check_metadata_type_equal(
@@ -27,124 +13,59 @@ function check_metadata_type_equal(
     return check_metadata_type_equal(v1.varinfo, v2.varinfo)
 end
 
-"""
-Return the value of `vn` in `vi`. If one doesn't exist, sample and set it.
-"""
-function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution)
-    if !haskey(vi, vn)
-        r = rand(dist)
-        push!!(vi, vn, r, dist)
-        r
-    else
-        vi[vn]
-    end
-end
-
 @testset "varinfo.jl" begin
-    @testset "VarInfo with NT of Metadata" begin
-        @model gdemo(x, y) = begin
-            s ~ InverseGamma(2, 3)
-            m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0)
-            x ~ Normal(m, sqrt(s))
-            y ~ Normal(m, sqrt(s))
-        end
-        model = gdemo(1.0, 2.0)
-
-        _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform())
-        tvi = DynamicPPL.typed_varinfo(vi)
-
-        meta = vi.metadata
-        for f in fieldnames(typeof(tvi.metadata))
-            fmeta = getfield(tvi.metadata, f)
-            for vn in fmeta.vns
-                @test tvi[vn] == vi[vn]
-                ind = meta.idcs[vn]
-                tind = fmeta.idcs[vn]
-                @test meta.dists[ind] == fmeta.dists[tind]
-                @test meta.is_transformed[ind] == fmeta.is_transformed[tind]
-                range = meta.ranges[ind]
-                trange = fmeta.ranges[tind]
-                @test all(meta.vals[range] .== fmeta.vals[trange])
-            end
-        end
-    end
-
     @testset "Base" begin
         # Test Base functions:
-        #   in, keys, haskey, isempty, push!!, empty!!,
+        #   in, keys, haskey, isempty, setindex!!, empty!!,
         #   getindex, setindex!, getproperty, setproperty!
-        function test_base(vi_original)
-            vi = deepcopy(vi_original)
-            @test getlogjoint(vi) == 0
-            @test isempty(vi[:])
-
-            vn = @varname x
-            dist = Normal(0, 1)
-            r = rand(dist)
-
-            @test isempty(vi)
-            @test !haskey(vi, vn)
-            @test !(vn in keys(vi))
-            vi = push!!(vi, vn, r, dist)
-            @test !isempty(vi)
-            @test haskey(vi, vn)
-            @test vn in keys(vi)
-
-            @test length(vi[vn]) == 1
-            @test vi[vn] == r
-            @test vi[:] == [r]
-            vi = DynamicPPL.setindex!!(vi, 2 * r, vn)
-            @test vi[vn] == 2 * r
-            @test vi[:] == [2 * r]
-
-            # TODO(mhauru) Implement these functions for other VarInfo types too.
-            if vi isa DynamicPPL.UntypedVectorVarInfo
-                delete!(vi, vn)
-                @test isempty(vi)
-                vi = push!!(vi, vn, r, dist)
-            end
-
-            vi = empty!!(vi)
-            @test isempty(vi)
-            vi = push!!(vi, vn, r, dist)
-            @test !isempty(vi)
-        end
-
-        test_base(VarInfo())
-        test_base(DynamicPPL.typed_varinfo(VarInfo()))
-        test_base(SimpleVarInfo())
-        test_base(SimpleVarInfo(OrderedDict{VarName,Any}()))
-        test_base(SimpleVarInfo(DynamicPPL.VarNamedVector()))
+        vi = VarInfo()
+        @test getlogjoint(vi) == 0
+        @test isempty(vi[:])
+
+        vn = @varname x
+        r = rand()
+
+        @test isempty(vi)
+        @test !haskey(vi, vn)
+        @test !(vn in keys(vi))
+        vi = setindex!!(vi, r, vn)
+        @test !isempty(vi)
+        @test haskey(vi, vn)
+        @test vn in keys(vi)
+
+        @test length(vi[vn]) == 1
+        @test vi[vn] == r
+        @test vi[:] == [r]
+        vi = DynamicPPL.setindex!!(vi, 2 * r, vn)
+        @test vi[vn] == 2 * r
+        @test vi[:] == [2 * r]
+
+        vi = empty!!(vi)
+        @test isempty(vi)
+        vi = setindex!!(vi, r, vn)
+        @test !isempty(vi)
     end
 
     @testset "get/set/acclogp" begin
-        function test_varinfo_logp!(vi)
-            @test DynamicPPL.getlogjoint(vi) === 0.0
-            vi = DynamicPPL.setlogprior!!(vi, 1.0)
-            @test DynamicPPL.getlogprior(vi) === 1.0
-            @test DynamicPPL.getloglikelihood(vi) === 0.0
-            @test DynamicPPL.getlogjoint(vi) === 1.0
-            vi = DynamicPPL.acclogprior!!(vi, 1.0)
-            @test DynamicPPL.getlogprior(vi) === 2.0
-            @test DynamicPPL.getloglikelihood(vi) === 0.0
-            @test DynamicPPL.getlogjoint(vi) === 2.0
-            vi = DynamicPPL.setloglikelihood!!(vi, 1.0)
-            @test DynamicPPL.getlogprior(vi) === 2.0
-            @test DynamicPPL.getloglikelihood(vi) === 1.0
-            @test DynamicPPL.getlogjoint(vi) === 3.0
-            vi = DynamicPPL.accloglikelihood!!(vi, 1.0)
-            @test DynamicPPL.getlogprior(vi) === 2.0
-            @test DynamicPPL.getloglikelihood(vi) === 2.0
-            @test DynamicPPL.getlogjoint(vi) === 4.0
-        end
         vi = VarInfo()
-        test_varinfo_logp!(vi)
-        test_varinfo_logp!(DynamicPPL.typed_varinfo(vi))
-        test_varinfo_logp!(SimpleVarInfo())
-        test_varinfo_logp!(SimpleVarInfo(OrderedDict()))
-        test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
+        @test DynamicPPL.getlogjoint(vi) === 0.0
+        vi = DynamicPPL.setlogprior!!(vi, 1.0)
+        @test DynamicPPL.getlogprior(vi) === 1.0
+        @test DynamicPPL.getloglikelihood(vi) === 0.0
+        @test DynamicPPL.getlogjoint(vi) === 1.0
+        vi = DynamicPPL.acclogprior!!(vi, 1.0)
+        @test DynamicPPL.getlogprior(vi) === 2.0
+        @test DynamicPPL.getloglikelihood(vi) === 0.0
+        @test DynamicPPL.getlogjoint(vi) === 2.0
+        vi = DynamicPPL.setloglikelihood!!(vi, 1.0)
+        @test DynamicPPL.getlogprior(vi) === 2.0
+        @test DynamicPPL.getloglikelihood(vi) === 1.0
+        @test DynamicPPL.getlogjoint(vi) === 3.0
+        vi = DynamicPPL.accloglikelihood!!(vi, 1.0)
+        @test DynamicPPL.getlogprior(vi) === 2.0
+        @test DynamicPPL.getloglikelihood(vi) === 2.0
+        @test DynamicPPL.getlogjoint(vi) === 4.0
    end
 
    @testset "logp accumulators" begin
@@ -163,7 +84,7 @@ end
         lp_d = logpdf(Normal(), values.d)
 
         m = demo() | (; c=values.c, d=values.d)
-        vi = DynamicPPL.unflatten(VarInfo(m), collect(values))
+        vi = DynamicPPL.unflatten!!(VarInfo(m), collect(values))
         vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi)))
 
         @test getlogprior(vi) == lp_a + lp_b
@@ -297,39 +218,23 @@ end
     end
 
     @testset "is_transformed flag" begin
-        # Test is_transformed and set_transformed!!
-        function test_varinfo!(vi)
-            vn_x = @varname x
-            dist = Normal(0, 1)
-            r = rand(dist)
-
-            push!!(vi, vn_x, r, dist)
+        vi = VarInfo()
+        vn_x = @varname x
+        r = rand()
 
-            # is_transformed is set by default
-            @test !is_transformed(vi, vn_x)
+        vi = setindex!!(vi, r, vn_x)
 
-            vi = set_transformed!!(vi, true, vn_x)
-            @test is_transformed(vi, vn_x)
+        # is_transformed is unset by default
+        @test !is_transformed(vi, vn_x)
 
-            vi = set_transformed!!(vi, false, vn_x)
-            @test !is_transformed(vi, vn_x)
-        end
-        vi = VarInfo()
-        test_varinfo!(vi)
-        test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi)))
-    end
+        vi = set_transformed!!(vi, true, vn_x)
+        @test is_transformed(vi, vn_x)
 
-    @testset "push!! to VarInfo with NT of Metadata" begin
-        vn_x = @varname x
-        vn_y = @varname y
-        untyped_vi = VarInfo()
-        untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1))
-        typed_vi = DynamicPPL.typed_varinfo(untyped_vi)
-        typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1))
-        @test typed_vi[vn_x] == 1.0
-        @test typed_vi[vn_y] == 2.0
+        vi = set_transformed!!(vi, false, vn_x)
+        @test !is_transformed(vi, vn_x)
     end
 
+    # TODO(mhauru) Move this to a different file.
     @testset "returned on MCMCChains.Chains" begin
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS
             chain = make_chain_from_prior(model, 10)
@@ -367,39 +272,25 @@ end
         # change the VarInfo object.
         # TODO(penelopeysm): Move this to InitFromUniform tests rather than here.
         vi = VarInfo()
-        meta = vi.metadata
         _, vi = DynamicPPL.init!!(model, vi, InitFromUniform())
-        @test all(x -> !is_transformed(vi, x), meta.vns)
+        vals = values(vi)
+
+        all_transformed(vi) =
+            mapreduce(p -> is_transformed(p.second), &, vi.values; init=true)
+        any_transformed(vi) =
+            mapreduce(p -> is_transformed(p.second), |, vi.values; init=false)
+
+        @test !any_transformed(vi)
 
         # Check that linking and invlinking set the `is_transformed` flag accordingly
-        v = copy(meta.vals)
         vi = link!!(vi, model)
-        @test all(x -> is_transformed(vi, x), meta.vns)
+        @test all_transformed(vi)
         vi = invlink!!(vi, model)
-        @test all(x -> !is_transformed(vi, x), meta.vns)
-        @test meta.vals ≈ v atol = 1e-10
-
-        # Check that linking and invlinking preserves the values
-        vi = DynamicPPL.typed_varinfo(vi)
-        meta = vi.metadata
-        v_s = copy(meta.s.vals)
-        v_m = copy(meta.m.vals)
-        v_x = copy(meta.x.vals)
-        v_y = copy(meta.y.vals)
-
-        @test all(x -> !is_transformed(vi, x), meta.s.vns)
-        @test all(x -> !is_transformed(vi, x), meta.m.vns)
-        vi = link!!(vi, model)
-        @test all(x -> is_transformed(vi, x), meta.s.vns)
-        @test all(x -> is_transformed(vi, x), meta.m.vns)
-        vi = invlink!!(vi, model)
-        @test all(x -> !is_transformed(vi, x), meta.s.vns)
-        @test all(x -> !is_transformed(vi, x), meta.m.vns)
-        @test meta.s.vals ≈ v_s atol = 1e-10
-        @test meta.m.vals ≈ v_m atol = 1e-10
+        @test !any_transformed(vi)
+        @test values(vi) ≈ vals atol = 1e-10
 
         # Transform only one variable
-        all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns)
+        all_vns = keys(vi)
         for vn in [
             @varname(s),
             @varname(m),
@@ -413,14 +304,11 @@ end
             @test !isempty(target_vns)
             @test !isempty(other_vns)
             vi = link!!(vi, (vn,), model)
-            @test all(x -> is_transformed(vi, x), target_vns)
-            @test all(x -> !is_transformed(vi, x), other_vns)
+            @test all_transformed(subset(vi, target_vns))
+            @test !any_transformed(subset(vi, other_vns))
             vi = invlink!!(vi, (vn,), model)
-            @test all(x -> !is_transformed(vi, x), all_vns)
-            @test meta.s.vals ≈ v_s atol = 1e-10
-            @test meta.m.vals ≈ v_m atol = 1e-10
-            @test meta.x.vals ≈ v_x atol = 1e-10
-            @test meta.y.vals ≈ v_y atol = 1e-10
+            @test !any_transformed(vi)
+            @test values(vi) ≈ vals atol = 1e-10
         end
     end
@@ -430,46 +318,17 @@ end
 
         vn = @varname(x)
         dist = truncated(Normal(); lower=0)
 
-        function test_linked_varinfo(model, vi)
-            # vn and dist are taken from the containing scope
-            vi = last(DynamicPPL.init!!(model, vi, InitFromPrior()))
-            f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
-            x = f(DynamicPPL.getindex_internal(vi, vn))
-            @test is_transformed(vi, vn)
-            @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
-            @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
-            @test getloglikelihood(vi) == 0.0
-            @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false)
-            @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false)
-        end
-
-        ### `VarInfo`
-        # Need to run once since we can't specify that we want to _sample_
-        # in the unconstrained space for `VarInfo` without having `vn`
-        # present in the `varinfo`.
-
-        ## `untyped_varinfo`
-        vi = DynamicPPL.untyped_varinfo(model)
-        vi = DynamicPPL.set_transformed!!(vi, true, vn)
-        test_linked_varinfo(model, vi)
-
-        ## `typed_varinfo`
-        vi = DynamicPPL.typed_varinfo(model)
+        vi = DynamicPPL.VarInfo(model)
         vi = DynamicPPL.set_transformed!!(vi, true, vn)
-        test_linked_varinfo(model, vi)
-
-        ### `SimpleVarInfo`
-        ## `SimpleVarInfo{<:NamedTuple}`
-        vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true)
-        test_linked_varinfo(model, vi)
-
-        ## `SimpleVarInfo{<:Dict}`
-        vi = DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)
-        test_linked_varinfo(model, vi)
-
-        ## `SimpleVarInfo{<:VarNamedVector}`
-        vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true)
-        test_linked_varinfo(model, vi)
+        vi = last(DynamicPPL.init!!(model, vi, InitFromPrior()))
+        f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
+        x = f(DynamicPPL.getindex_internal(vi, vn))
+        @test is_transformed(vi, vn)
+        @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getloglikelihood(vi) == 0.0
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false)
+        @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false)
     end
 
     @testset "values_as" begin
@@ -484,38 +343,22 @@ end
         @testset "$(short_varinfo_name(vi))" for vi in varinfos
             # Just making sure.
             DynamicPPL.TestUtils.test_values(vi, example_values, vns)
-
-            @testset "NamedTuple" begin
-                vals = values_as(vi, NamedTuple)
-                for vn in vns
-                    if haskey(vals, Symbol(vn))
-                        # Assumed to be of form `(var"m[1]" = 1.0, ...)`.
-                        @test getindex(vals, Symbol(vn)) == getindex(vi, vn)
-                    else
-                        # Assumed to be of form `(m = [1.0, ...], ...)`.
-                        @test get(vals, vn) == getindex(vi, vn)
-                    end
-                end
-            end
+            vals = values_as(vi, OrderedDict)
+            # All varnames in `vns` should be subsumed by one of `keys(vals)`.
+ @test all(vns) do vn + any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) end - - @testset "OrderedDict" begin - vals = values_as(vi, OrderedDict) - # All varnames in `vns` should be subsumed by one of `keys(vals)`. - @test all(vns) do vn - any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) - end - # Iterate over `keys(vals)` because we might have scenarios such as - # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is - # the varname present in `vns`, not `@varname(m)`. - for vn in keys(vals) - @test getindex(vals, vn) == getindex(vi, vn) - end + # Iterate over `keys(vals)` because we might have scenarios such as + # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is + # the varname present in `vns`, not `@varname(m)`. + for vn in keys(vals) + @test getindex(vals, vn) == getindex(vi, vn) end end end end - @testset "unflatten + linking" begin + @testset "unflatten!! + linking" begin @testset "Model: $(model.f)" for model in [ DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), DynamicPPL.TestUtils.demo_lkjchol(), @@ -527,26 +370,6 @@ end model, value_true, varnames; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - - if DynamicPPL.has_varnamedvector(varinfo) && mutating - # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. 
- @test_broken false - continue - end - # Evaluate the model once to update the logp of the varinfo. varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) @@ -559,8 +382,8 @@ end @test DynamicPPL.is_transformed(varinfo_linked, vn) end @test length(varinfo[:]) > length(varinfo_linked[:]) - varinfo_linked_unflattened = DynamicPPL.unflatten( - varinfo_linked, varinfo_linked[:] + varinfo_linked_unflattened = DynamicPPL.unflatten!!( + copy(varinfo_linked), varinfo_linked[:] ) @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) @@ -592,7 +415,7 @@ end end end - @testset "unflatten type stability" begin + @testset "unflatten!! type stability" begin @model function demo(y) x ~ Normal() y ~ Normal(x, 1) @@ -604,13 +427,7 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the inconcrete `SimpleVarInfo` types, since checking for type - # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} - continue - end - @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + @inferred DynamicPPL.unflatten!!(varinfo, varinfo[:]) end end @@ -630,13 +447,9 @@ end varinfos = DynamicPPL.TestUtils.setup_varinfos( model, model(), vns; include_threadsafe=true ) - varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. - vns_supported_standard = [ + vns_supported = [ [@varname(s)], [@varname(m)], [@varname(x[1])], @@ -661,25 +474,10 @@ end [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. 
- vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. - vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported varinfo_subset = subset(varinfo, VarName[]) @@ -731,15 +529,6 @@ end @test varinfo_subset[:] == ground_truth end end - - # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end end @testset "merge" begin @@ -838,9 +627,9 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_single = setindex!!(vi_single, 1.0, vn) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + vi_double = setindex!!(vi_double, [0.5, 0.6], vn) @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end @@ -851,8 +640,9 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(true, n))) isa + typeof(float(1)) # `Int`. 
- @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(1, n))) isa typeof(float(1)) end end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl new file mode 100644 index 000000000..68ab834a6 --- /dev/null +++ b/test/varnamedtuple.jl @@ -0,0 +1,1125 @@ +module VarNamedTupleTests + +using Combinatorics: Combinatorics +using OrderedCollections: OrderedDict +using Test: @inferred, @test, @test_throws, @testset +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple, subset +using DynamicPPL.VarNamedTuples: + PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! +using AbstractPPL: AbstractPPL, VarName, concretize, prefix +using BangBang: setindex!!, empty!! + +""" + test_invariants(vnt::VarNamedTuple; skip=()) + +Test properties that should hold for all VarNamedTuples. + +Uses @test for all the tests. Intended to be called inside a @testset. + +`skip` is a tuple of symbols indicating which tests are to be skipped. +""" +function test_invariants(vnt::VarNamedTuple; skip=()) + # These will be needed repeatedly. + vnt_keys = keys(vnt) + vnt_values = values(vnt) + + # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. + for k in vnt_keys + @test haskey(vnt, k) + v = getindex(vnt, k) + # ArrayLikeBlocks and PartialArrays are implementation details, and should not be + # exposed through getindex. + @test !(v isa ArrayLikeBlock) + @test !(v isa PartialArray) + vnt2 = setindex!!(copy(vnt), v, k) + equality = (vnt == vnt2) + # The value may be `missing` if vnt itself has values that are missing. + @test equality === true || equality === missing + @test isequal(vnt, vnt2) + @test hash(vnt) == hash(vnt2) + end + + # Check that the printed representation can be parsed back to an equal VarNamedTuple. + # The below eval test is a bit fragile: If any elements in vnt don't respect the same + # reconstructability-from-repr property, this will fail. 
+ # Likewise, if any element's repr uses types that are not in scope in this
+ # module, it will fail.
+ if !(:parseeval in skip)
+ vnt3 = eval(Meta.parse(repr(vnt)))
+ equality = (vnt == vnt3)
+ # The value may be `missing` if vnt itself has values that are missing.
+ @test equality === true || equality === missing
+ @test isequal(vnt, vnt3)
+ @test hash(vnt) == hash(vnt3)
+ end
+
+ # Check that merge with an empty VarNamedTuple is a no-op.
+ @test isequal(merge(vnt, VarNamedTuple()), vnt)
+ @test isequal(merge(VarNamedTuple(), vnt), vnt)
+
+ # Check that the VNT can be constructed back from its keys and values.
+ vnt4 = VarNamedTuple()
+ for (k, v) in zip(vnt_keys, vnt_values)
+ vnt4 = setindex!!(vnt4, v, k)
+ end
+ @test isequal(vnt, vnt4)
+
+ # Check that vnt isempty only if it has no keys
+ was_empty = isempty(vnt)
+ @test isequal(was_empty, isempty(vnt_keys))
+ @test isequal(was_empty, isempty(vnt_values))
+
+ # Check that vnt can be emptied
+ @test empty(vnt) === VarNamedTuple()
+ emptied_vnt = empty!!(copy(vnt))
+ @test isempty(emptied_vnt)
+ @test isempty(keys(emptied_vnt))
+ @test isempty(values(emptied_vnt))
+
+ # Check that the copy protected the original vnt from being modified.
+ @test isempty(vnt) == was_empty
+
+ # Check that map is a no-op when using identity functions.
+ @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt)
+ @test isequal(map_values!!(identity, copy(vnt)), vnt)
+
+ # Check that subsetting works as expected.
+ @test isequal(subset(vnt, vnt_keys), vnt)
+ @test isequal(subset(vnt, VarName[]), VarNamedTuple())
+end
+
+""" A type that has a size but is not an Array.
Used in ArrayLikeBlock tests.""" +struct SizedThing{T<:Tuple} + size::T +end +Base.size(st::SizedThing) = st.size + +@testset "VarNamedTuple" begin + @testset "Construction" begin + vnt1 = VarNamedTuple() + test_invariants(vnt1) + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) + vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) + test_invariants(vnt1) + + vnt2 = VarNamedTuple(; + a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) + ) + test_invariants(vnt2) + @test vnt1 == vnt2 + + vnt3 = VarNamedTuple((; + a=1.0, b=[1, 2, 3], c=VarNamedTuple((; d=VarNamedTuple((; e="a")))) + )) + test_invariants(vnt3) + @test vnt1 == vnt3 + + vnt4 = VarNamedTuple( + OrderedDict( + @varname(a) => 1.0, @varname(b) => [1, 2, 3], @varname(c.d.e) => "a" + ), + ) + test_invariants(vnt4) + @test vnt1 == vnt4 + + vnt5 = VarNamedTuple(( + (@varname(a), 1.0), (@varname(b), [1, 2, 3]), (@varname(c.d.e), "a") + )) + test_invariants(vnt5) + @test vnt1 == vnt5 + + pa1 = PartialArray{Float64,1}() + pa1 = setindex!!(pa1, 1.0, 16) + pa2 = PartialArray{Float64,1}(; min_size=(16,)) + pa2 = setindex!!(pa2, 1.0, 16) + pa3 = PartialArray{Float64,1}(16 => 1.0) + pa4 = PartialArray{Float64,1}((16,) => 1.0) + @test pa1 == pa2 + @test pa1 == pa3 + @test pa1 == pa4 + + pa1 = PartialArray{String,3}() + pa1 = setindex!!(pa1, "a", 2, 3, 4) + pa1 = setindex!!(pa1, "b", 1, 2, 4) + pa2 = PartialArray{String,3}(; min_size=(16, 16, 16)) + pa2 = setindex!!(pa2, "a", 2, 3, 4) + pa2 = setindex!!(pa2, "b", 1, 2, 4) + pa3 = PartialArray{String,3}((2, 3, 4) => "a", (1, 2, 4) => "b") + @test pa1 == pa2 + @test pa1 == pa3 + + @test_throws BoundsError PartialArray{Int,1}((0,) => 1) + @test_throws BoundsError PartialArray{Int,1}((1, 2) => 1) + @test_throws MethodError PartialArray{Int,1}((1,) => "a") + @test_throws MethodError PartialArray{Int,1}((1,) => 1; min_size=(2, 2)) + end + + @testset "Basic sets and gets" begin + vnt = VarNamedTuple() + vnt = 
@inferred(setindex!!(vnt, 32.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 32.0 + @test haskey(vnt, @varname(a)) + @test !haskey(vnt, @varname(b)) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(b))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 2 + @test haskey(vnt, @varname(b)) + @test haskey(vnt, @varname(b[1])) + @test haskey(vnt, @varname(b[1:3])) + @test !haskey(vnt, @varname(b[4])) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 64.0, @varname(a))) + @test @inferred(getindex(vnt, @varname(a))) == 64.0 + @test @inferred(getindex(vnt, @varname(b))) == [1, 2, 3] + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 15, @varname(b[2]))) + @test @inferred(getindex(vnt, @varname(b))) == [1, 15, 3] + @test @inferred(getindex(vnt, @varname(b[2]))) == 15 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, [10], @varname(c.x.y))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [10] + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 11, @varname(c.x.y[1]))) + @test @inferred(getindex(vnt, @varname(c.x.y))) == [11] + @test @inferred(getindex(vnt, @varname(c.x.y[1]))) == 11 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, -1.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -1.0 + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, -2.0, @varname(d[4]))) + @test @inferred(getindex(vnt, @varname(d[4]))) == -2.0 + test_invariants(vnt) + + # These can't be @inferred because `d` now has an abstract element type. Note that this + # does not ruin type stability for other varnames that don't involve `d`. 
+ vnt = setindex!!(vnt, "a", @varname(d[5])) + @test getindex(vnt, @varname(d[5])) == "a" + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 1.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 1.0 + @test haskey(vnt, @varname(e.f[3].g.h[2].i)) + @test !haskey(vnt, @varname(e.f[2].g.h[2].i)) + test_invariants(vnt) + + vnt = @inferred(setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i))) + @test @inferred(getindex(vnt, @varname(e.f[3].g.h[2].i))) == 2.0 + test_invariants(vnt) + + vec = fill(1.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[1:4]))) + @test @inferred(getindex(vnt, @varname(j[1:4]))) == vec + @test @inferred(getindex(vnt, @varname(j[2]))) == vec[2] + @test haskey(vnt, @varname(j[4])) + @test !haskey(vnt, @varname(j[5])) + @test_throws BoundsError getindex(vnt, @varname(j[5])) + test_invariants(vnt) + + vec = fill(2.0, 4) + vnt = @inferred(setindex!!(vnt, vec, @varname(j[2:5]))) + @test @inferred(getindex(vnt, @varname(j[1]))) == 1.0 + @test @inferred(getindex(vnt, @varname(j[2:5]))) == vec + @test haskey(vnt, @varname(j[5])) + test_invariants(vnt) + + arr = fill(2.0, (4, 2)) + vn = @varname(k.l[2:5, 3, 1:2, 2]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2) + test_invariants(vnt) + + # Not enough, or too many, indices. + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3])) + @test_throws BoundsError setindex!!(vnt, 0.0, @varname(k.l[1, 2, 3, 4, 5])) + + arr = fill(3.0, (3, 3)) + vn = @varname(k.l[1, 1:3, 1:3, 1]) + vnt = @inferred(setindex!!(vnt, arr, vn)) + @test @inferred(getindex(vnt, vn)) == arr + # A subset of the elements set just now. + @test @inferred(getindex(vnt, @varname(k.l[1, 1:2, 1:2, 1]))) == fill(3.0, 2, 2) + # A subset of the elements set previously. 
+ @test @inferred(getindex(vnt, @varname(k.l[2, 3, 1:2, 2]))) == fill(2.0, 2)
+ @test !haskey(vnt, @varname(k.l[2, 3, 3, 2]))
+ test_invariants(vnt)
+
+ vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[2])))
+ vnt = @inferred(setindex!!(vnt, 1.0, @varname(m[3])))
+ @test @inferred(getindex(vnt, @varname(m[2:3]))) == [1.0, 1.0]
+ @test !haskey(vnt, @varname(m[1]))
+ test_invariants(vnt)
+
+ # The below tests are mostly significant for the type stability aspect. For the last
+ # test to pass, PartialArray needs to actively tighten its eltype when possible.
+ vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[1].a)))
+ @test @inferred(getindex(vnt, @varname(n[1].a))) == 1.0
+ vnt = @inferred(setindex!!(vnt, 1.0, @varname(n[2].a)))
+ @test @inferred(getindex(vnt, @varname(n[2].a))) == 1.0
+ # This can't be type stable, because n[1] has inhomogeneous types.
+ vnt = setindex!!(vnt, 1.0, @varname(n[1].b))
+ @test getindex(vnt, @varname(n[1].b)) == 1.0
+ # The setindex!! call can't be type stable either, but it should return a
+ # VarNamedTuple with a concrete element type, and hence getindex can be inferred.
+ vnt = setindex!!(vnt, 1.0, @varname(n[2].b))
+ @test @inferred(getindex(vnt, @varname(n[2].b))) == 1.0
+ test_invariants(vnt)
+
+ # Some funky Symbols in VarNames
+ # TODO(mhauru) This still isn't as robust as it should be; for instance Symbol(":")
+ # fails the eval(Meta.parse(repr(vnt))) == vnt test because NamedTuple show doesn't
+ # respect the eval-property.
+ vn1 = VarName{Symbol("a b c")}()
+ vnt = @inferred(setindex!!(vnt, 2, vn1))
+ @test @inferred(getindex(vnt, vn1)) == 2
+ test_invariants(vnt)
+ vn2 = VarName{Symbol("1")}()
+ vnt = @inferred(setindex!!(vnt, 3, vn2))
+ @test @inferred(getindex(vnt, vn2)) == 3
+ test_invariants(vnt)
+ vn3 = VarName{Symbol("?!")}()
+ vnt = @inferred(setindex!!(vnt, 4, vn3))
+ @test @inferred(getindex(vnt, vn3)) == 4
+ test_invariants(vnt)
+ vnt = VarNamedTuple()
+ vn4 = prefix(prefix(vn1, vn2), vn3)
+ vnt = @inferred(setindex!!(vnt, 5, vn4))
+ @test @inferred(getindex(vnt, vn4)) == 5
+ test_invariants(vnt)
+ vn5 = prefix(prefix(vn3, vn2), vn1)
+ vnt = @inferred(setindex!!(vnt, 6, vn5))
+ @test @inferred(getindex(vnt, vn5)) == 6
+ test_invariants(vnt)
+
+ # TODO(penelopeysm) Colon tests fail
+ # vnt = VarNamedTuple()
+ # x = [1, 2, 3]
+ # vn = concretize(@varname(y[:]), x)
+ # vnt = @inferred(setindex!!(vnt, x, vn))
+ # @test haskey(vnt, vn)
+ # @test @inferred(getindex(vnt, vn)) == x
+ # test_invariants(vnt)
+
+ # vnt = VarNamedTuple()
+ # vnt = @inferred(setindex!!(vnt, SizedThing((3,)), vn))
+ # @test haskey(vnt, vn)
+ # @test vn in keys(vnt)
+ # @test @inferred(getindex(vnt, vn)) == SizedThing((3,))
+ # # TODO(mhauru) The below skip is needed because AbstractPPL's ConcretizedSlice
+ # # objects don't respect the eval(Meta.parse(repr(...))) == ... property.
+ # test_invariants(vnt; skip=(:parseeval,)) + + # TODO(penelopeysm) Colon tests fail + # vnt = VarNamedTuple() + # y = fill("a", (3, 2, 4)) + # x = y[:, 2, :] + # a = (; b=[nothing, nothing, (; c=(; d=reshape(y, (1, 3, 2, 4, 1))))]) + # vn = @varname(a.b[3].c.d[1, 3:5, 2, :, 1]) + # vn = concretize(vn, a) + # vnt = @inferred(setindex!!(vnt, x, vn)) + # @test haskey(vnt, vn) + # @test @inferred(getindex(vnt, vn)) == x + # test_invariants(vnt) + + # Indices on indices + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a[1][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, 1, @varname(ab[1:2][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, [1], @varname(b[1].c[1]))) + @test @inferred(getindex(vnt, @varname(b[1].c[1]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(e[3, 2].f[2, 2][10, 10]))) + @test @inferred(getindex(vnt, @varname(e[3, 2].f[2, 2][10, 10]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(g[3, 2][10, 10].h[2, 2]))) + @test @inferred(getindex(vnt, @varname(g[3, 2][10, 10].h[2, 2]))) == [1] + end + + @testset "equality and hash" begin + # Test all combinations of having or not having the below values set, and having + # them set to any of the possible_values, and check that isequal and == return the + # expected value. + # NOTE: Be very careful adding new values to these sets. The below test has three + # nested loops over Combinatorics.combinations, the run time can explode very, very + # quickly. 
+ varnames = (@varname(b[1]), @varname(b[3]), @varname(c.d[2].e)) + possible_values = (missing, 1, -0.0, 0.0) + for vn_set in Combinatorics.combinations(varnames) + valuesets1 = Combinatorics.with_replacement_combinations( + possible_values, length(vn_set) + ) + valuesets2 = Combinatorics.with_replacement_combinations( + possible_values, length(vn_set) + ) + for vset1 in valuesets1, vset2 in valuesets2 + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_isequal = true + expected_doubleequal = true + for (vn, v1, v2) in zip(vn_set, vset1, vset2) + vnt1 = setindex!!(vnt1, v1, vn) + vnt2 = setindex!!(vnt2, v2, vn) + expected_isequal = expected_isequal & isequal(v1, v2) + expected_doubleequal = expected_doubleequal & (v1 == v2) + end + test_invariants(vnt1) + test_invariants(vnt2) + @test isequal(vnt1, vnt2) == expected_isequal + @test (vnt1 == vnt2) === expected_doubleequal + if expected_isequal + @test hash(vnt1) == hash(vnt2) + end + end + end + end + + @testset "merge" begin + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, 1.0, @varname(a)) + vnt2 = setindex!!(vnt2, 2.0, @varname(b)) + vnt1 = setindex!!(vnt1, 1, @varname(c)) + vnt2 = setindex!!(vnt2, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 1.0, @varname(a)) + expected_merge = setindex!!(expected_merge, 2, @varname(c)) + expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + expected_merge = VarNamedTuple() + vnt1 = setindex!!(vnt1, [1], @varname(d.a)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.b)) + vnt1 = setindex!!(vnt1, [1], @varname(d.c)) + vnt2 = setindex!!(vnt2, [2, 2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [1], @varname(d.a)) + expected_merge = setindex!!(expected_merge, [2, 
2], @varname(d.c)) + expected_merge = setindex!!(expected_merge, [2, 2], @varname(d.b)) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[2])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[1])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[2])) + vnt1 = setindex!!(vnt1, 1, @varname(e.a[3])) + vnt2 = setindex!!(vnt2, 2, @varname(e.a[3])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.a[3])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, fill(1, 4), @varname(e.a[7:10])) + vnt2 = setindex!!(vnt2, fill(2, 4), @varname(e.a[8:11])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.a[7])) + expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, 1, @varname(e.b[1][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[2][13])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.b[1][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[2][13])) + vnt1 = setindex!!(vnt1, 1, @varname(e.b[3][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[3][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[3][13])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, 1, @varname(e.b[4][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[4][14])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.b[4][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[4][14])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) + expected_merge = setindex!!( + expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) + ) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 
1][14, 13])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13])) + expected_merge = setindex!!( + expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1][14, 13]) + ) + expected_merge = setindex!!( + expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13]) + ) + @test merge(vnt1, vnt2) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) + + # PartialArrays with different sizes. + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257])) + vnt2 = setindex!!(vnt2, 2, @varname(a[1])) + vnt2 = setindex!!(vnt2, 2, @varname(a[2])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[1])) + expected_merge_12 = setindex!!(expected_merge_12, 2, @varname(a[2])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) + @test @inferred(merge(vnt2, vnt1)) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) + + vnt1 = VarNamedTuple() + vnt2 = VarNamedTuple() + vnt1 = setindex!!(vnt1, 1, @varname(a[1, 1])) + vnt1 = setindex!!(vnt1, 1, @varname(a[257, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 1])) + vnt2 = setindex!!(vnt2, :2, @varname(a[1, 257])) + expected_merge_12 = VarNamedTuple() + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 1])) + expected_merge_12 = setindex!!(expected_merge_12, 1, @varname(a[257, 1])) + expected_merge_12 = setindex!!(expected_merge_12, :2, @varname(a[1, 257])) + @test merge(vnt1, vnt2) == expected_merge_12 + expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) + @test merge(vnt2, vnt1) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) + end + + @testset "subset" begin + vnt = VarNamedTuple() + vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, 
[1, 2, 3], @varname(b))
+ vnt = setindex!!(vnt, [10], @varname(c.x.y))
+ vnt = setindex!!(vnt, :1, @varname(d[1]))
+ vnt = setindex!!(vnt, :2, @varname(d[2]))
+ vnt = setindex!!(vnt, :3, @varname(d[3]))
+ vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i))
+ vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14]))
+ test_invariants(vnt)
+
+ # TODO(mhauru) I'm a bit saddened by the lack of type stability for subset: Its
+ # return type always infers as VarNamedTuple. Improving this would require a
+ # different implementation of subset.
+ @test subset(vnt, VarName[]) == VarNamedTuple()
+ @test subset(vnt, (@varname(z),)) == VarNamedTuple()
+ @test subset(vnt, (@varname(d[4]),)) == VarNamedTuple()
+ @test subset(vnt, (@varname(d[1, 1]),)) == VarNamedTuple()
+ @test subset(vnt, [@varname(a)]) == VarNamedTuple(; a=1.0)
+ @test subset(vnt, [@varname(b), @varname(d[1])]) ==
+ VarNamedTuple((@varname(b) => [1, 2, 3], @varname(d[1]) => :1))
+ @test subset(vnt, [@varname(d[2:3])]) ==
+ VarNamedTuple((@varname(d[2]) => :2, @varname(d[3]) => :3))
+ @test subset(vnt, [@varname(d)]) == VarNamedTuple((
+ @varname(d[1]) => :1, @varname(d[2]) => :2, @varname(d[3]) => :3
+ ))
+ @test subset(vnt, [@varname(c.x.y)]) == VarNamedTuple((@varname(c.x.y) => [10],))
+ @test subset(vnt, [@varname(c)]) == VarNamedTuple((@varname(c.x.y) => [10],))
+ @test subset(vnt, [@varname(e.f[3, 3].g.h[2, 4, 1].i)]) ==
+ VarNamedTuple((@varname(e.f[3, 3].g.h[2, 4, 1].i) => 2.0,))
+ @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:14])]) ==
+ VarNamedTuple((@varname(p[2, 1][2:4, 5:5, 11:14]) => SizedThing((3, 1, 4)),))
+ # Cutting the last range a bit short should mean that nothing is returned.
+ @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:13])]) == VarNamedTuple()
+ end
+
+ @testset "keys and values" begin
+ vnt = VarNamedTuple()
+ @test @inferred(keys(vnt)) == VarName[]
+ @test @inferred(values(vnt)) == Any[]
+
+ vnt = setindex!!(vnt, 1.0, @varname(a))
+ # TODO(mhauru) Note that the below passes @inferred, but the later ones don't.
+ # We should improve type stability of keys().
+ @test @inferred(keys(vnt)) == [@varname(a)]
+ @test @inferred(values(vnt)) == [1.0]
+
+ vnt = setindex!!(vnt, [1, 2, 3], @varname(b))
+ @test keys(vnt) == [@varname(a), @varname(b)]
+ @test values(vnt) == [1.0, [1, 2, 3]]
+
+ vnt = setindex!!(vnt, 15, @varname(b[2]))
+ @test keys(vnt) == [@varname(a), @varname(b)]
+ @test values(vnt) == [1.0, [1, 15, 3]]
+
+ vnt = setindex!!(vnt, [10], @varname(c.x.y))
+ @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)]
+ @test values(vnt) == [1.0, [1, 15, 3], [10]]
+
+ vnt = setindex!!(vnt, -1.0, @varname(d[4]))
+ @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])]
+ @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0]
+
+ vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i))
+ @test keys(vnt) == [
+ @varname(a),
+ @varname(b),
+ @varname(c.x.y),
+ @varname(d[4]),
+ @varname(e.f[3, 3].g.h[2, 4, 1].i),
+ ]
+ @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0]
+
+ vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4]))
+ @test keys(vnt) == [
+ @varname(a),
+ @varname(b),
+ @varname(c.x.y),
+ @varname(d[4]),
+ @varname(e.f[3, 3].g.h[2, 4, 1].i),
+ @varname(j[1]),
+ @varname(j[2]),
+ @varname(j[3]),
+ @varname(j[4]),
+ ]
+ @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)...]
+ + vnt = setindex!!(vnt, "a", @varname(j[6])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + ] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)..., "a"] + + vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + ] + @test values(vnt) == [1.0, [1, 15, 3], [10], -1.0, 2.0, fill(1.0, 4)..., "a", 1.0] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + ] + @test values(vnt) == [ + 1.0, + [1, 15, 3], + [10], + -1.0, + 2.0, + fill(1.0, 4)..., + "a", + 1.0, + SizedThing((3, 1, 4)), + ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + @varname(p[2, 1][2:4, 5:5, 11:14]), + ] + @test values(vnt) == [ + 1.0, + [1, 15, 3], + [10], + -1.0, + 2.0, + fill(1.0, 4)..., + "a", + 1.0, + SizedThing((3, 1, 4)), + SizedThing((3, 1, 4)), + ] + test_invariants(vnt) + end + + @testset "length" begin + # Type inference for length fails in some cases on Julia versions < 1.11 + inference_broken = VERSION < v"1.11" + + vnt = VarNamedTuple() + @test @inferred(length(vnt)) == 0 + + 
vnt = setindex!!(vnt, 1.0, @varname(a)) + @test @inferred(length(vnt)) == 1 + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + @test @inferred(length(vnt)) == 2 + + vnt = setindex!!(vnt, 15, @varname(b[2])) + @test @inferred(length(vnt)) == 2 + + vnt = setindex!!(vnt, [10, 11], @varname(c.x.y)) + @test @inferred(length(vnt)) == 3 + + vnt = setindex!!(vnt, -1.0, @varname(d[4])) + @test @inferred(length(vnt)) == 4 broken = inference_broken + + vnt = setindex!!(vnt, ["a", "b"], @varname(d[1:2])) + @test @inferred(length(vnt)) == 6 broken = inference_broken + + vnt = setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i)) + vnt = setindex!!(vnt, 3.0, @varname(e.f[3].g.h[2].j)) + @test @inferred(length(vnt)) == 8 broken = inference_broken + + vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 2:4, 2, 1:2, 3])) + @test @inferred(length(vnt)) == 14 broken = inference_broken + + vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) + @test @inferred(length(vnt)) == 14 broken = inference_broken + + vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) + @test @inferred(length(vnt)) == 16 broken = inference_broken + test_invariants(vnt) + end + + @testset "empty" begin + # test_invariants already checks that many different kinds of VarNamedTuples can be + # emptied with empty and empty!!. What remains to check here is that + # 1) isempty gives the expected results: + vnt = VarNamedTuple() + @test @inferred(isempty(vnt)) == true + vnt = setindex!!(vnt, 1.0, @varname(a)) + @test @inferred(isempty(vnt)) == false + test_invariants(vnt) + + vnt = VarNamedTuple() + vnt = setindex!!(vnt, [], @varname(a[1])) + @test @inferred(isempty(vnt)) == false + test_invariants(vnt) + + # 2) empty!! 
keeps PartialArrays in place: + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(a[1:3]))) + vnt = @inferred(empty!!(vnt)) + @test !haskey(vnt, @varname(a[1])) + @test !haskey(vnt, @varname(a[1:3])) + @test haskey(vnt, @varname(a)) + @test_throws BoundsError getindex(vnt, @varname(a[1])) + @test_throws BoundsError getindex(vnt, @varname(a[1:3])) + @test getindex(vnt, @varname(a)) == [] + vnt = @inferred(setindex!!(vnt, [1, 2, 3], @varname(a[2:4]))) + @test @inferred(getindex(vnt, @varname(a[2:4]))) == [1, 2, 3] + @test haskey(vnt, @varname(a[2:4])) + @test !haskey(vnt, @varname(a[1])) + test_invariants(vnt) + end + + @testset "densification" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 1)) + test_invariants(vnt) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 2)) + test_invariants(vnt) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 1)) + test_invariants(vnt) + + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 2]))) + @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 2)) + vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[3, 3]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + test_invariants(vnt) + + vnt = 
VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((2,)), @varname(x[1:2]))) + @test_throws ArgumentError @inferred(getindex(vnt, @varname(x))) + test_invariants(vnt) + end + + @testset "printing" begin + vnt = VarNamedTuple() + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == "VarNamedTuple()" + + vnt = setindex!!(vnt, "s", @varname(a)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """VarNamedTuple(a = "s",)""" + + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """VarNamedTuple(a = "s", b = [1, 2, 3])""" + + vnt = setindex!!(vnt, :dada, @varname(c[2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + @test output == """ + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada))""" + + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3][2, 2].f.g[1:2])) + io = IOBuffer() + show(io, vnt) + output = String(take!(io)) + # Depending on what's in scope, and maybe sometimes even the Julia version, + # sometimes types in the output are fully qualified, sometimes not. To avoid + # brittle tests, we normalise the output: + output = replace(output, "DynamicPPL." => "", "VarNamedTuples." => "") + @test output == """ + VarNamedTuple(a = "s", b = [1, 2, 3], \ + c = PartialArray{Symbol,1}((2,) => :dada), \ + d = VarNamedTuple(\ + e = PartialArray{PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}}, 2},1}((3,) => \ + PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}},2}((2, 2) => VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 17.0),),))),))""" + test_invariants(vnt) + end + + @testset "block variables" begin + # Tests for setting and getting block variables, i.e. 
variables that have a non-zero + # size in a PartialArray, but are not Arrays themselves. + expected_err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), @varname(x[2:4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(x[2:4])) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == SizedThing((3,)) + @test !haskey(vnt, @varname(x[2:3])) + @test_throws expected_err getindex(vnt, @varname(x[2:3])) + @test !haskey(vnt, @varname(x[3])) + @test_throws expected_err getindex(vnt, @varname(x[3])) + @test !haskey(vnt, @varname(x[1])) + @test !haskey(vnt, @varname(x[5])) + vnt = setindex!!(vnt, 1.0, @varname(x[1])) + vnt = setindex!!(vnt, 1.0, @varname(x[5])) + test_invariants(vnt) + @test haskey(vnt, @varname(x[1])) + @test haskey(vnt, @varname(x[5])) + @test_throws expected_err getindex(vnt, @varname(x[1:4])) + @test_throws expected_err getindex(vnt, @varname(x[2:5])) + + # Setting any of these indices should remove the block variable x[2:4]. + @testset "index = $index" for index in (2, 3, 4, 2:3, 3:5) + # Test setting different types of values. + vals = if index isa Int + (2.0,) + else + (fill(2.0, length(index)), SizedThing((length(index),))) + end + @testset "val = $val" for val in vals + vn = @varname(x[index]) + vnt2 = copy(vnt) + vnt2 = setindex!!(vnt2, val, vn) + test_invariants(vnt) + @test !haskey(vnt2, @varname(x[2:4])) + @test_throws BoundsError getindex(vnt2, @varname(x[2:4])) + other_index = index in (2, 2:3) ? 4 : 2 + @test !haskey(vnt2, @varname(x[other_index])) + @test_throws BoundsError getindex(vnt2, @varname(x[other_index])) + @test haskey(vnt2, vn) + @test getindex(vnt2, vn) == val + @test haskey(vnt2, @varname(x[1])) + @test_throws BoundsError getindex(vnt2, @varname(x[1:4])) + end + end + + # Extra checks, mostly for type stability and to confirm that multidimensional + # blocks work too. 
+ val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[1:2, 1:2])) + @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val + @test !haskey(vnt, @varname(y.z[1, 1])) + @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test !haskey(vnt, @varname(y.z[1:2, 1:2])) + @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test haskey(vnt, @varname(y.z[4:5, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val + + # A lot like above, but with extra indices that are not ranges. 
+ val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 1:2, 3, 1:2, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1, 3, 1, 4])) + @test_throws expected_err getindex(vnt, @varname(y.z[2, 1, 3, 1, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test_throws BoundsError getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[3, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) + # Type inference fails on this one for Julia versions < 1.11 + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val + end + + @testset "map and friends" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a))) + vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) + vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d))) + vnt = @inferred(setindex!!(vnt, "a", @varname(e.f[3].g.h[2].i))) + # The below can't be type stable because the element type of `h` depends on whether + # we are setting `h[2].j` (which overwrites the earlier `h[2]`) or some other + # `h[index].j` (which would leave both `h[2].i` and `h[index].j` in the same array). + vnt = setindex!!(vnt, 5.0, @varname(e.f[3].g.h[2].j)) + vnt = @inferred( + setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4])) + ) + # TODO(penelopeysm) Reenable colons. 
+ + # colon_vn = concretize(@varname(v[:]), [0, 0]) + # vnt = @inferred(setindex!!(vnt, SizedThing((2,)), colon_vn)) + vnt = @inferred(setindex!!(vnt, "", @varname(w[4][3][2, 1]))) + # TODO(mhauru) The below skip is needed because AbstractPPL's ConcretizedSlice + # objects don't respect the eval(Meta.parse(repr(...))) == ... property. + test_invariants(vnt) + + struct AnotherSizedThing{T<:Tuple} + size::T + end + Base.size(st::AnotherSizedThing) = st.size + + call_counter = 0 + function f_val(val) + call_counter += 1 + if val isa Int + return val + 10 + elseif val isa AbstractVector{Int} + return val .+ 10 + elseif val isa Float64 + return val + 1.0 + elseif val isa AbstractVector{Float64} + return val .- 1.0 + elseif val isa String + return string(val, "b") + elseif val isa SizedThing + return AnotherSizedThing(size(val)) + else + error("Unexpected value type $(typeof(val))") + end + end + + f_pair(pair) = f_val(pair.second) + + val_reduction = mapreduce(pair -> pair.second, vcat, vnt; init=Any[]) + @test val_reduction == vcat( + Any[], + 1, + [2, 2], + [3.0], + "a", + 5.0, + SizedThing((2, 2)), + # The below would have come from colon_vn + # SizedThing((2,)), + "", + ) + key_reduction = mapreduce(pair -> pair.first, vcat, vnt; init=Any[]) + @test key_reduction == vcat( + @varname(a), + @varname(b[1]), + @varname(b[2]), + @varname(c.d), + @varname(e.f[3].g.h[2].i), + @varname(e.f[3].g.h[2].j), + @varname(y.z[3, 2:3, 3, 2:3, 4]), + # colon_vn, + @varname(w[4][3][2, 1]), + ) + + call_counter = 0 + reduction = mapreduce(f_pair, vcat, vnt; init=Any[]) + @test reduction == vcat( + Any[], + 11, + [12, 12], + [2.0], + "ab", + 6.0, + AnotherSizedThing((2, 2)), + # The below would have come from colon_vn + # AnotherSizedThing((2,)), + "b", + ) + # Check that f_pair gets called exactly once per element. 
+ @test call_counter == length(keys(vnt)) + + # TODO(mhauru) This should hopefully be type stable, but fails to be so because of + # some complex VarNames being too much for constant propagation. See comment in + # src/varnamedtuple.jl for more. + call_counter = 0 + vnt_mapped = map_pairs!!(f_pair, copy(vnt)) + # Check that f_pair gets called exactly once per element. + @test call_counter == length(keys(vnt)) + @test vnt_mapped == map_values!!(f_val, copy(vnt)) + test_invariants(vnt_mapped; skip=(:parseeval,)) + @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 + @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] + @test @inferred(getindex(vnt_mapped, @varname(c.d))) == [2.0] + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 + @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + # @test @inferred(getindex(vnt_mapped, colon_vn)) == AnotherSizedThing((2,)) + @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" + + call_counter = 0 + vnt_applied = copy(vnt) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(a))) + @test call_counter == 1 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] + + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(b[1:2]))) + # Unlike map_pairs!!, apply!! operates on the whole value at once, rather than + # element-wise, so this is only one more call. 
+ @test call_counter == 2 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] + + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(c.d))) + @test call_counter == 3 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] + + vnt_applied = begin + # The @inferred fails on Julia 1.10. + @static if VERSION < v"1.11" + apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i)) + else + @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) + end + end + @test call_counter == 4 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 + + vnt_applied = begin + # The @inferred fails on Julia 1.10. + @static if VERSION < v"1.11" + apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j)) + else + @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) + end + end + @test call_counter == 5 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 + + # This can't be type stable because y.z might have many elements set, and we can't + # know at compile time that this sets the only one, thus allowing the element type + # to be AnotherSizedThing. 
+ vnt_applied = apply!!(f_val, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) + @test call_counter == 6 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + + # vnt_applied = apply!!(f_val, vnt_applied, colon_vn) + # @test call_counter == 7 + # test_invariants(vnt_applied; skip=(:parseeval,)) + # @test @inferred(getindex(vnt_applied, colon_vn)) == AnotherSizedThing((2,)) + + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) + @test call_counter == 7 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" + + # map a function that maps every key => value pair to key => key. + # For this, use a simpler VarNamedTuple, because block variables don't work with + # this mapping function. It also allows us to check type stability. + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a))) + vnt = @inferred(setindex!!(vnt, 2, @varname(b[2]))) + vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d))) + vnt = @inferred(setindex!!(vnt, :oi, @varname(y.z[3, 2, 3, 2, 4]))) + vnt = @inferred(setindex!!(vnt, "", @varname(w[4][2, 1]))) + + get_key(pair) = pair.first + vnt_key_mapped = @inferred(map_pairs!!(get_key, copy(vnt))) + vnt_key_mapped_expected = VarNamedTuple() + for k in keys(vnt) + vnt_key_mapped_expected = setindex!!(vnt_key_mapped_expected, k, k) + end + @test vnt_key_mapped == vnt_key_mapped_expected + end +end + +end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl deleted file mode 100644 index 9a4ef12c3..000000000 --- a/test/varnamedvector.jl +++ /dev/null @@ -1,711 +0,0 @@ -replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens) - -increase_size_for_test(x::Real) = [x] -increase_size_for_test(x::AbstractArray) = repeat(x, 2) - -decrease_size_for_test(x::Real) = x -decrease_size_for_test(x::AbstractVector) = first(x) 
-decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) - -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.varnames)) - # If the container is concrete, we need to make sure that the varname types match. - # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then - # we need `vn` to also be of this type. - # => If the varname types don't match, we need to relax the container type. - return any(keys(vnv)) do vn_present - typeof(vn_present) !== typeof(val) - end - end - - return false -end -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.vals)) - return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) - end - - return false -end -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return if isconcretetype(eltype(vnv.transforms)) - # If the container is concrete, we need to make sure that the sizes match. - # => If the sizes don't match, we need to relax the container type. - any(keys(vnv)) do vn_present - size(vnv[vn_present]) != size(val) - end - elseif eltype(vnv.transforms) !== Any - # If it's not concrete AND it's not `Any`, then we should just make it `Any`. - true - else - # Otherwise, it's `Any`, so we don't need to relax the container type. 
- false - end -end -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -""" - relax_container_types(vnv::VarNamedVector, vn::VarName, val) - relax_container_types(vnv::VarNamedVector, vns, val) - -Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. - -This attempts to avoid unnecessary container type relaxations by checking whether -the container types of `vnv` are already compatible with `vn` and `val`. - -# Notes -For example, if `vn` is not compatible with the current keys in `vnv`, then -the underlying types will be changed to `VarName` to accommodate `vn`. - -Similarly: -- If `val` is not compatible with the current values in `vnv`, then - the underlying value type will be changed to `Real`. -- If `val` requires a transformation that is not compatible with the current - transformations type in `vnv`, then the underlying transformation type will - be changed to `Any`. 
-""" -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return relax_container_types(vnv, [vn], [val]) -end -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) - if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index) - varnames_new = convert(Vector{VarName}, vnv.varnames) - else - varname_to_index_new = vnv.varname_to_index - varnames_new = vnv.varnames - end - - transforms_new = if need_transforms_relaxation(vnv, vns, vals) - convert(Vector{Any}, vnv.transforms) - else - vnv.transforms - end - - vals_new = if need_values_relaxation(vnv, vns, vals) - convert(Vector{Real}, vnv.vals) - else - vnv.vals - end - - return DynamicPPL.VarNamedVector( - varname_to_index_new, - varnames_new, - vnv.ranges, - vals_new, - transforms_new, - vnv.is_unconstrained, - vnv.num_inactive, - ) -end - -@testset "VarNamedVector" begin - # Test element-related operations: - # - `getindex` - # - `setindex!` - # - `push!` - # - `update!` - # - `insert!` - # - `reset!` - # - `_internal!` versions of the above - # - !! versions of the above - # - # And these are all be tested for different types of values: - # - scalar - # - vector - # - matrix - - # Test operations on `VarNamedVector`: - # - `empty!` - # - `iterate` - # - `convert` to - # - `AbstractDict` - test_pairs = OrderedDict( - @varname(x[1]) => rand(), - @varname(x[2]) => rand(2), - @varname(x[3]) => rand(2, 3), - @varname(y[1]) => rand(), - @varname(y[2]) => rand(2), - @varname(y[3]) => rand(2, 3), - @varname(z[1]) => rand(1:10), - @varname(z[2]) => rand(1:10, 2), - @varname(z[3]) => rand(1:10, 2, 3), - ) - test_vns = collect(keys(test_pairs)) - test_vals = collect(values(test_pairs)) - - @testset "constructor: no args" begin - # Empty. - vnv = DynamicPPL.VarNamedVector() - @test isempty(vnv) - @test eltype(vnv) == Union{} - - # Empty with types. 
- vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() - @test isempty(vnv) - @test eltype(vnv) == Float64 - end - - test_varnames_iter = combinations(test_vns, 2) - @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter - val_left = test_pairs[vn_left] - val_right = test_pairs[vn_right] - vnv_base = DynamicPPL.VarNamedVector([vn_left, vn_right], [val_left, val_right]) - - # We'll need the transformations later. - # TODO: Should we test other transformations than just `ReshapeTransform`? - from_vec_left = DynamicPPL.from_vec_transform(val_left) - from_vec_right = DynamicPPL.from_vec_transform(val_right) - to_vec_left = inverse(from_vec_left) - to_vec_right = inverse(from_vec_right) - - # Compare to alternative constructors. - vnv_from_dict = DynamicPPL.VarNamedVector( - OrderedDict(vn_left => val_left, vn_right => val_right) - ) - @test vnv_base == vnv_from_dict - - # We want the types of fields such as `varnames` and `transforms` to specialize - # whenever possible + some functionality, e.g. `push!`, is only sensible - # if the underlying containers can support it. - # Expected behavior - should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) - should_have_restricted_transform_type = size(val_left) == size(val_right) - # Actual behavior - has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) - has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) - - @testset "type specialization" begin - @test !should_have_restricted_varname_type || has_restricted_varname_type - @test !should_have_restricted_transform_type || has_restricted_transform_type - end - - @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) - @test DynamicPPL.length_internal(vnv_base) == length(val_left) + length(val_right) - @test length(vnv_base) == 2 - - @test !isempty(vnv_base) - - @testset "empty!" 
begin - vnv = deepcopy(vnv_base) - empty!(vnv) - @test isempty(vnv) - end - - @testset "similar" begin - vnv = similar(vnv_base) - @test isempty(vnv) - @test typeof(vnv) == typeof(vnv_base) - end - - @testset "getindex" begin - # With `VarName` index. - @test vnv_base[vn_left] == val_left - @test vnv_base[vn_right] == val_right - end - - @testset "getindex_internal" begin - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_left) == - to_vec_left(val_left) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_right) == - to_vec_right(val_right) - end - - @testset "getindex_internal with Ints" begin - for (i, val) in enumerate(to_vec_left(val_left)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, i) == val - end - offset = length(to_vec_left(val_left)) - for (i, val) in enumerate(to_vec_right(val_right)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, offset + i) == val - end - end - - @testset "update!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!!" 
begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "delete!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - @test !haskey(vnv, vn_left) - @test haskey(vnv, vn_right) - delete!(vnv, vn_right) - @test !haskey(vnv, vn_right) - end - - @testset "insert!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert!!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!!" 
begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "merge" begin - # When there are no inactive entries, `merge` on itself result in the same. - @test merge(vnv_base, vnv_base) == vnv_base - - # Merging with empty should result in the same. - @test merge(vnv_base, similar(vnv_base)) == vnv_base - @test merge(similar(vnv_base), vnv_base) == vnv_base - - # With differences. - vnv_left_only = deepcopy(vnv_base) - delete!(vnv_left_only, vn_right) - vnv_right_only = deepcopy(vnv_base) - delete!(vnv_right_only, vn_left) - - # `(x,)` and `(x, y)` should be `(x, y)`. - @test merge(vnv_left_only, vnv_base) == vnv_base - # `(x, y)` and `(x,)` should be `(x, y)`. - @test merge(vnv_base, vnv_left_only) == vnv_base - # `(x, y)` and `(y,)` should be `(x, y)`. - @test merge(vnv_base, vnv_right_only) == vnv_base - # `(y,)` and `(x, y)` should be `(y, x)`. - vnv_merged = merge(vnv_right_only, vnv_base) - @test vnv_merged != vnv_base - @test collect(keys(vnv_merged)) == [vn_right, vn_left] - end - - @testset "push!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - vnv_copy = deepcopy(vnv) - push!(vnv, (vn => val)) - @test vnv[vn] == val - end - end - - @testset "setindex_internal!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. 
-            increment(x) = x .+ 10
-            vnv = deepcopy(vnv_base)
-            vnv = DynamicPPL.loosen_types!!(
-                vnv, typeof(vn_left), eltype(vnv), typeof(increment)
-            )
-            DynamicPPL.setindex_internal!(
-                vnv, to_vec_left(val_left .+ 100), vn_left, increment
-            )
-            @test vnv[vn_left] == to_vec_left(val_left .+ 110)
-
-            vnv = DynamicPPL.loosen_types!!(
-                vnv, typeof(vn_right), eltype(vnv), typeof(increment)
-            )
-            DynamicPPL.setindex_internal!(
-                vnv, to_vec_right(val_right .+ 100), vn_right, increment
-            )
-            @test vnv[vn_right] == to_vec_right(val_right .+ 110)
-
-            # Adding new values.
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn" for vn in test_vns
-                val = test_pairs[vn]
-                from_vec_vn = DynamicPPL.from_vec_transform(val)
-                to_vec_vn = inverse(from_vec_vn)
-                DynamicPPL.setindex_internal!(vnv, to_vec_vn(val), vn, from_vec_vn)
-                @test vnv[vn] == val
-            end
-        end
-
-        @testset "setindex_internal! with Ints" begin
-            vnv = deepcopy(vnv_base)
-            for i in 1:DynamicPPL.length_internal(vnv_base)
-                DynamicPPL.setindex_internal!(vnv, i, i)
-            end
-            for i in 1:DynamicPPL.length_internal(vnv_base)
-                @test DynamicPPL.getindex_internal(vnv, i) == i
-            end
-        end
-
-        @testset "setindex_internal!!" begin
-            # Not setting the transformation.
-            vnv = deepcopy(vnv_base)
-            vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left)
-            @test vnv[vn_left] == val_left .+ 100
-            vnv = DynamicPPL.setindex_internal!!(
-                vnv, to_vec_right(val_right .+ 100), vn_right
-            )
-            @test vnv[vn_right] == val_right .+ 100
-
-            # Explicitly setting the transformation.
-            # Note that unlike with setindex_internal!, we don't need loosen_types!! here.
-            increment(x) = x .+ 10
-            vnv = deepcopy(vnv_base)
-            vnv = DynamicPPL.setindex_internal!!(
-                vnv, to_vec_left(val_left .+ 100), vn_left, increment
-            )
-            @test vnv[vn_left] == to_vec_left(val_left .+ 110)
-
-            vnv = DynamicPPL.setindex_internal!!(
-                vnv, to_vec_right(val_right .+ 100), vn_right, increment
-            )
-            @test vnv[vn_right] == to_vec_right(val_right .+ 110)
-
-            # Adding new values.
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn" for vn in test_vns
-                val = test_pairs[vn]
-                from_vec_vn = DynamicPPL.from_vec_transform(val)
-                to_vec_vn = inverse(from_vec_vn)
-                vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_vn(val), vn, from_vec_vn)
-                @test vnv[vn] == val
-            end
-        end
-
-        @testset "setindex! and reset!" begin
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn" for vn in test_vns
-                val = test_pairs[vn]
-                expected_length = if haskey(vnv, vn)
-                    # If it's already present, the resulting length will be unchanged.
-                    DynamicPPL.length_internal(vnv)
-                else
-                    DynamicPPL.length_internal(vnv) + length(val)
-                end
-
-                vnv[vn] = val .+ 1
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-
-                # There should be no redundant values in the underlying vector.
-                @test !DynamicPPL.has_inactive(vnv)
-            end
-
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn (increased size)" for vn in test_vns
-                val_original = test_pairs[vn]
-                val = increase_size_for_test(val_original)
-                vn_already_present = haskey(vnv, vn)
-                expected_length = if vn_already_present
-                    # If it's already present, the resulting length will be altered.
-                    DynamicPPL.length_internal(vnv) + length(val) - length(val_original)
-                else
-                    DynamicPPL.length_internal(vnv) + length(val)
-                end
-
-                # Have to use reset!, because setindex! doesn't support decreasing size.
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-
-            vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
-            @testset "$vn (decreased size)" for vn in test_vns
-                val_original = test_pairs[vn]
-                val = decrease_size_for_test(val_original)
-                vn_already_present = haskey(vnv, vn)
-                expected_length = if vn_already_present
-                    # If it's already present, the resulting length will be altered.
-                    DynamicPPL.length_internal(vnv) + length(val) - length(val_original)
-                else
-                    DynamicPPL.length_internal(vnv) + length(val)
-                end
-
-                # Have to use reset!, because setindex! doesn't support decreasing size.
-                DynamicPPL.reset!(vnv, val .+ 1, vn)
-                x = DynamicPPL.getindex_internal(vnv, :)
-                @test vnv[vn] == val .+ 1
-                @test DynamicPPL.length_internal(vnv) == expected_length
-                @test length(x) == DynamicPPL.length_internal(vnv)
-                @test all(
-                    DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x)
-                )
-            end
-        end
-    end
-
-    @testset "growing and shrinking" begin
-        @testset "deterministic" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-            # Growing should not create inactive ranges.
-            for i in 1:n
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test !DynamicPPL.has_inactive(vnv)
-            end
-
-            # Same size should not create inactive ranges.
-            x = fill(true, n)
-            DynamicPPL.update_internal!(vnv, x, vn, identity)
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Shrinking should create inactive ranges.
-            for i in (n - 1):-1:1
-                x = fill(true, i)
-                DynamicPPL.update_internal!(vnv, x, vn, identity)
-                @test DynamicPPL.has_inactive(vnv)
-                @test DynamicPPL.num_inactive(vnv, vn) == n - i
-            end
-        end
-
-        @testset "random" begin
-            n = 5
-            vn = @varname(x)
-            vnv = DynamicPPL.VarNamedVector(Dict(vn => [true]))
-            @test !DynamicPPL.has_inactive(vnv)
-
-            # Insert a bunch of random-length vectors.
-            for i in 1:100
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-            end
-            # Should never be allocating more than `n` elements.
-            @test DynamicPPL.num_allocated(vnv, vn) ≤ n
-
-            # If we compaticfy, then it should always be the same size as just inserted.
-            for i in 1:10
-                x = fill(true, rand(1:n))
-                DynamicPPL.update!(vnv, x, vn)
-                DynamicPPL.contiguify!(vnv)
-                @test DynamicPPL.num_allocated(vnv, vn) == length(x)
-            end
-        end
-    end
-
-    @testset "subset" begin
-        vnv = DynamicPPL.VarNamedVector(test_pairs)
-        @test subset(vnv, test_vns) == vnv
-        @test subset(vnv, VarName[]) == DynamicPPL.VarNamedVector()
-        @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv
-
-        # Test that subset preserves transformations and unconstrainedness.
-        vn = @varname(t[1])
-        vns = vcat(test_vns, [vn])
-        vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2)
-        DynamicPPL.set_transformed!(vnv, true, @varname(t[1]))
-        @test vnv[@varname(t[1])] == [4.0]
-        @test is_transformed(vnv, @varname(t[1]))
-        @test subset(vnv, vns) == vnv
-    end
-
-    @testset "loosen and tighten types" begin
-        """
-            test_tightenability(vnv::VarNamedVector)
-
-        Test that tighten_types!! is a no-op on `vnv`.
-        """
-        function test_tightenability(vnv::DynamicPPL.VarNamedVector)
-            @test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv))
-            # TODO(mhauru) We would like to check something more stringent here, namely that
-            # the operation is compiled to a direct no-op, with no instructions at all. I
-            # don't know how to do that though, so for now we just check that it doesn't
-            # allocate.
-            @allocations(DynamicPPL.tighten_types!!(vnv)) == 0
-            return nothing
-        end
-
-        vn = @varname(a[1])
-        # Test that tighten_types!! is a no-op on an empty VarNamedVector.
-        vnv = DynamicPPL.VarNamedVector()
-        @test DynamicPPL.is_tightly_typed(vnv)
-        test_tightenability(vnv)
-        # Also check that it literally returns the same object, and both tighten and loosen
-        # are type stable.
-        @test vnv === DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
-        # Likewise for a VarNamedVector with something pushed into it.
-        vnv = DynamicPPL.VarNamedVector()
-        vnv = setindex!!(vnv, 1.0, vn)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        test_tightenability(vnv)
-        @test vnv === DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
-        # Likewise for a VarNamedVector with abstract element-types, when that is needed for
-        # the current contents because mixed types have been pushed into it. However, this
-        # time, since the types are only as tight as they can be, but not actually concrete,
-        # tighten_types!! can't be type stable.
-        vnv = DynamicPPL.VarNamedVector()
-        vnv = setindex!!(vnv, 1.0, vn)
-        vnv = setindex!!(vnv, 2, @varname(b))
-        @test !DynamicPPL.is_tightly_typed(vnv)
-        test_tightenability(vnv)
-        @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
-        # Likewise when first mixed types are pushed, but then deleted.
-        vnv = DynamicPPL.VarNamedVector()
-        vnv = setindex!!(vnv, 1.0, vn)
-        vnv = setindex!!(vnv, 2, @varname(b))
-        @test !DynamicPPL.is_tightly_typed(vnv)
-        vnv = delete!!(vnv, vn)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        test_tightenability(vnv)
-        @test vnv === DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.tighten_types!!(vnv)
-        @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any)
-
-        # Test that loosen_types!! does really loosen them and that tighten_types!! reverts
-        # that.
-        vnv = DynamicPPL.VarNamedVector()
-        vnv = setindex!!(vnv, 1.0, vn)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        k = eltype(vnv.varnames)
-        e = eltype(vnv.vals)
-        t = eltype(vnv.transforms)
-        # Loosen key type.
-        vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t)
-        @test !DynamicPPL.is_tightly_typed(vnv)
-        vnv = DynamicPPL.tighten_types!!(vnv)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        # Loosen element type
-        vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t)
-        @test !DynamicPPL.is_tightly_typed(vnv)
-        vnv = DynamicPPL.tighten_types!!(vnv)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        # Loosen transformation type
-        vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function)
-        @test !DynamicPPL.is_tightly_typed(vnv)
-        vnv = DynamicPPL.tighten_types!!(vnv)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        # Loosening to the same types as currently should do nothing.
-        vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t)
-        @test DynamicPPL.is_tightly_typed(vnv)
-        @allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0
-    end
-end
-
-@testset "VarInfo + VarNamedVector" begin
-    models = DynamicPPL.TestUtils.ALL_MODELS
-    @testset "$(model.f)" for model in models
-        # NOTE: Need to set random seed explicitly to avoid using the same seed
-        # for initialization as for sampling in the inner testset below.
-        Random.seed!(42)
-        value_true = DynamicPPL.TestUtils.rand_prior_true(model)
-        vns = DynamicPPL.TestUtils.varnames(model)
-        varnames = DynamicPPL.TestUtils.varnames(model)
-        varinfos = DynamicPPL.TestUtils.setup_varinfos(
-            model, value_true, varnames; include_threadsafe=false
-        )
-        # Filter out those which are not based on `VarNamedVector`.
-        varinfos = filter(DynamicPPL.has_varnamedvector, varinfos)
-        # Get the true log joint.
-        logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...)
-
-        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-            # Need to make sure we're using a different random seed from the
-            # one used in the above call to `rand_prior_true`.
-            Random.seed!(43)
-
-            # Are values correct?
-            DynamicPPL.TestUtils.test_values(varinfo, value_true, vns)
-
-            # Is evaluation correct?
-            varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo)))
-            # Log density should be the same.
-            @test getlogjoint(varinfo_eval) ≈ logp_true
-            # Values should be the same.
-            DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns)
-
-            # Is sampling correct?
-            varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo)))
-            # Log density should be different.
-            @test getlogjoint(varinfo_sample) != getlogjoint(varinfo)
-            # Values should be different.
-            DynamicPPL.TestUtils.test_values(
-                varinfo_sample, value_true, vns; compare=!isequal
-            )
-        end
-    end
-end