From 0413fbe892d38bd3608ac1668b059bd37932ebc5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 18:21:17 +0100 Subject: [PATCH 01/12] Unify {Untyped,Typed}{Vector,}VarInfo constructors --- HISTORY.md | 21 ++++++++- src/deprecated.jl | 12 +++++ src/varinfo.jl | 117 +++++++++++++++++++++++++++------------------- 3 files changed, 102 insertions(+), 48 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index a956bd188..30ebcbf57 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,10 +4,29 @@ **Breaking changes** -### VarInfo constructor +### VarInfo constructors `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. +**The `VarInfo([rng, ]model[, sampler, context, metadata])` constructor has been replaced with the following methods:** + + 1. `UntypedVarInfo([rng, ]model[, sampler, context])` + 2. `TypedVarInfo([rng, ]model[, sampler, context])` + 3. `DynamicPPL.UntypedVectorVarInfo([rng, ]model[, sampler, context])` + 4. `DynamicPPL.TypedVectorVarInfo([rng, ]model[, sampler, context])` + +**If you were not using the `metadata` argument (most likely), then you can directly replace calls to this constructor with `TypedVarInfo` instead.** +That is to say, if you were using `VarInfo(model)`, you can replace this with `TypedVarInfo(model)`. + +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `TypedVectorVarInfo` instead. +Note that the `VectorVarInfo` constructors (both `Untyped` and `Typed`) are not exported by default. + +If you were passing a non-empty metadata argument, you should use a different constructor of `VarInfo` instead. + +The reason for this change is that there were several flavours of VarInfo. +Some, like TypedVarInfo, were easy to construct because we had convenience methods for them; however, the others were more difficult. +This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. diff --git a/src/deprecated.jl b/src/deprecated.jl index 0bcaae9b7..bc00d0aec 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1 +1,13 @@ @deprecate generated_quantities(model, params) returned(model, params) + +Base.@deprecate VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) TypedVarInfo(rng, model, sampler, context) +Base.@deprecate VarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) TypedVarInfo(model, sampler, context) diff --git a/src/varinfo.jl b/src/varinfo.jl index 94b1f1c07..485f4599d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -94,8 +94,17 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end -const VectorVarInfo = VarInfo{<:VarNamedVector} +const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} +# TODO: TypedVarInfo carries no information about the type of the actual +# metadata i.e. the elements of the NamedTuple. It could be Metadata or it +# could be VarNamedVector. Calling TypedVarInfo(model) will result in a +# TypedVarInfo where the elements are Metadata. +# Resolving this ambiguity would likely require us to replace NamedTuple with +# something which carried both its keys as well as its values' types as type +# parameters. +# Note that below we also define a function TypedVectorVarInfo, which generates +# a TypedVarInfo where the metadata is a NamedTuple of VarNameVectors. const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} @@ -132,70 +141,84 @@ function metadata_to_varnamedvector(md::Metadata) ) end -function VectorVarInfo(vi::UntypedVarInfo) - md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - -function VectorVarInfo(vi::TypedVarInfo) - md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end - function has_varnamedvector(vi::VarInfo) return vi.metadata isa VarNamedVector || (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end +######################## +# VarInfo constructors # +######################## + """ - untyped_varinfo(model[, context, metadata]) + UntypedVarInfo([rng, ]model[, sampler, context, metadata]) Return an untyped varinfo object for the given `model` and `context`. # Arguments -- `model::Model`: The model for which to create the varinfo object. -- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. -- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. - Default: `Metadata()`. +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ -function untyped_varinfo( +function UntypedVarInfo( + rng::Random.AbstractRNG, model::Model, - context::AbstractContext=SamplingContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - varinfo = VarInfo(metadata) - return last( - evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) - ) + varinfo = VarInfo(Metadata()) + context = SamplingContext(rng, sampler, context) + return last(evaluate!!(model, varinfo, context)) +end +function UntypedVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) + return UntypedVarInfo(Random.default_rng(), model, args...) end -""" - typed_varinfo(model[, context, metadata]) - -Return a typed varinfo object for the given `model`, `sampler` and `context`. - -This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting -varinfo object to a typed varinfo object. - -See also: [`DynamicPPL.untyped_varinfo`](@ref) -""" -typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) +function TypedVarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return TypedVarInfo(UntypedVarInfo(rng, model, sampler, context)) +end +function TypedVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) + return TypedVarInfo(Random.default_rng(), model, args...) +end -function VarInfo( +function UntypedVectorVarInfo(vi::UntypedVarInfo) + md = metadata_to_varnamedvector(vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function UntypedVectorVarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) + return UntypedVectorVarInfo(UntypedVarInfo(rng, model, sampler, context)) +end +function UntypedVectorVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) + return UntypedVectorVarInfo(UntypedVarInfo(Random.default_rng(), model, args...)) end -function VarInfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... + +function TypedVectorVarInfo(vi::TypedVarInfo) + md = map(metadata_to_varnamedvector, vi.metadata) + lp = getlogp(vi) + return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) +end +function TypedVectorVarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), ) - return VarInfo(Random.default_rng(), model, args...) + return TypedVectorVarInfo(TypedVarInfo(rng, model, sampler, context)) +end +function TypedVectorVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) + return TypedVectorVarInfo(Random.default_rng(), model, args...) end """ @@ -749,7 +772,7 @@ end VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) -function TypedVarInfo(vi::VectorVarInfo) +function TypedVarInfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) logp = getlogp(vi) num_produce = get_num_produce(vi) @@ -1627,12 +1650,12 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) return vi end -function Base.push!(vi::VectorVarInfo, vn::VarName, val, args...) +function Base.push!(vi::UntypedVectorVarInfo, vn::VarName, val, args...) push!(getmetadata(vi, vn), vn, val, args...) return vi end -function Base.push!(vi::VectorVarInfo, pair::Pair, args...) +function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) vn, val = pair return push!(vi, vn, val, args...) end @@ -2061,8 +2084,8 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end -values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) +values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) function values_from_metadata(md::Metadata) return ( From 7b3310350d2423ef2d700cddcc66b45032f81082 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 18:38:45 +0100 Subject: [PATCH 02/12] Update invocations --- benchmarks/src/DynamicPPLBenchmarks.jl | 6 +- docs/src/api.md | 7 -- docs/src/internals/varinfo.md | 6 +- ext/DynamicPPLJETExt.jl | 4 +- src/abstract_varinfo.jl | 6 +- src/experimental.jl | 6 +- src/logdensityfunction.jl | 4 +- src/model.jl | 12 +- src/pointwise_logdensities.jl | 6 +- src/sampler.jl | 2 +- src/submodel_macro.jl | 4 +- src/values_as_in_model.jl | 2 +- src/varinfo.jl | 148 ++++++++++++------------- test/ad.jl | 2 +- test/compiler.jl | 34 +++--- test/deprecated.jl | 4 +- test/ext/DynamicPPLForwardDiffExt.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 2 +- test/model.jl | 17 ++- test/pointwise_logdensities.jl | 2 +- test/sampler.jl | 2 +- test/simple_varinfo.jl | 2 +- test/test_util.jl | 2 +- test/threadsafe.jl | 4 +- test/utils.jl | 2 +- test/varinfo.jl | 36 +++--- 26 files changed, 152 insertions(+), 172 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 4c73bf355..a0f45a81b 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: suite = BenchmarkGroup() vi = if varinfo_choice == :untyped - vi = VarInfo() - model(rng, vi) - vi + UntypedVarInfo(rng, model) elseif varinfo_choice == :typed - VarInfo(rng, model) + TypedVarInfo(rng, model) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict diff --git a/docs/src/api.md b/docs/src/api.md index 2f6376f5d..e33d49cc5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -291,13 +291,6 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. -For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: - -```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo -``` - #### `VarInfo` ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index e6e1f2619..03f6e22c0 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -79,13 +79,13 @@ For example, with the model above we have ```@example varinfo-design # Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) +varinfo_untyped = DynamicPPL.UntypedVarInfo(demo()) typeof(varinfo_untyped.metadata) ``` ```@example varinfo-design # Type-stable `VarInfo` -varinfo_typed = DynamicPPL.typed_varinfo(demo()) +varinfo_typed = DynamicPPL.TypedVarInfo(demo()) typeof(varinfo_typed.metadata) ``` @@ -154,7 +154,7 @@ For example, we want to optimize code-paths which effectively boil down to inner ```julia # Construct a `VarInfo` with types inferred from `model`. -varinfo = VarInfo(model) +varinfo = TypedVarInfo(model) # Repeatedly sample from `model`. for _ in 1:num_samples diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index aa95093f2..5d2d01662 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -27,7 +27,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true ) # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(model, context) + varinfo = DynamicPPL.TypedVarInfo(model, context) # Let's make sure that both evaluation and sampling doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model, context) + DynamicPPL.UntypedVarInfo(model, context) end end diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 44edaa4e9..c00c93391 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -251,7 +251,7 @@ julia> values_as(SimpleVarInfo(data), Vector) ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = TypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -277,7 +277,7 @@ julia> values_as(vi, Vector) ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); + vi = UntypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -354,7 +354,7 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> varinfo = VarInfo(model); +julia> varinfo = TypedVarInfo(model); julia> keys(varinfo) 4-element Vector{VarName}: diff --git a/src/experimental.jl b/src/experimental.jl index 84038803c..ff25b96e0 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -72,7 +72,7 @@ julia> # Typed varinfo cannot handle this random support model properly ┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. └ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 -julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) +julia> vi isa typeof(DynamicPPL.UntypedVarInfo(model)) true julia> # In contrast, a simple model with no random support can be handled by typed varinfo. @@ -81,7 +81,7 @@ model_with_static_support (generic function with 2 methods) julia> vi = determine_suitable_varinfo(model_with_static_support()); -julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) +julia> vi isa typeof(DynamicPPL.TypedVarInfo(model_with_static_support())) true ``` """ @@ -97,7 +97,7 @@ function determine_suitable_varinfo( # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). - DynamicPPL.typed_varinfo(model, context) + DynamicPPL.TypedVarInfo(model, context) end end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..26025d5c7 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -79,7 +79,7 @@ julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); + f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), TypedVarInfo(model)); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -109,7 +109,7 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), + varinfo::AbstractVarInfo=TypedVarInfo(model), context::AbstractContext=leafcontext(model.context); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..7ac1e1e1e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1037,7 +1037,7 @@ julia> logjoint(demo_model([1., 2.]), chain); ``` """ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model + var_info = TypedVarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1084,7 +1084,7 @@ julia> logprior(demo_model([1., 2.]), chain); ``` """ function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model + var_info = TypedVarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1131,7 +1131,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); ``` """ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = VarInfo(model) # extract variables info from the model + var_info = TypedVarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1339,7 +1339,7 @@ julia> @model function demo2(x, y) When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: ```jldoctest submodel-to_submodel -julia> vi = VarInfo(demo2(missing, 0.4)); +julia> vi = TypedVarInfo(demo2(missing, 0.4)); julia> @varname(a.x) in keys(vi) true @@ -1376,7 +1376,7 @@ julia> @model function demo2_no_prefix(x, z) return z ~ Uniform(-a, 1) end; -julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); +julia> vi = TypedVarInfo(demo2_no_prefix(missing, 0.4)); julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` true @@ -1391,7 +1391,7 @@ julia> @model function demo2(x, y, z) return z ~ Uniform(-a, b) end; -julia> vi = VarInfo(demo2(missing, missing, 0.4)); +julia> vi = TypedVarInfo(demo2(missing, missing, 0.4)); julia> @varname(sub1.x) in keys(vi) true diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..b69ccce67 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -226,12 +226,12 @@ julia> @model function demo(x) julia> m = demo([1.0, ]); -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +julia> ℓ = pointwise_logdensities(m, TypedVarInfo(m)); first(ℓ[@varname(x[1])]) -1.4189385332046727 julia> m = demo([1.0; 1.0]); -julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +julia> ℓ = pointwise_logdensities(m, TypedVarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` @@ -240,7 +240,7 @@ function pointwise_logdensities( model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} # Get the data by executing the model once - vi = VarInfo(model) + vi = TypedVarInfo(model) point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) diff --git a/src/sampler.jl b/src/sampler.jl index ff008cc93..c15ac6006 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -86,7 +86,7 @@ function default_varinfo( context::AbstractContext, ) init_sampler = initialsampler(sampler) - return VarInfo(rng, model, init_sampler, context) + return TypedVarInfo(rng, model, init_sampler, context) end function AbstractMCMC.sample( diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..3089a9a97 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -24,7 +24,7 @@ julia> @model function demo2(x, y) When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: ```jldoctest submodel -julia> vi = VarInfo(demo2(missing, 0.4)); +julia> vi = TypedVarInfo(demo2(missing, 0.4)); ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -91,7 +91,7 @@ julia> @model function demo2(x, y, z) When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and `sub2.x` will be sampled: ```jldoctest submodelprefix -julia> vi = VarInfo(demo2(missing, missing, 0.4)); +julia> vi = TypedVarInfo(demo2(missing, missing, 0.4)); ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..6d4cf019d 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -131,7 +131,7 @@ julia> @model function model_changing_support() julia> model = model_changing_support(); julia> # Construct initial type-stable `VarInfo`. - varinfo = VarInfo(rng, model); + varinfo = TypedVarInfo(rng, model); julia> # Link it so it works in unconstrained space. varinfo_linked = DynamicPPL.link(varinfo, model); diff --git a/src/varinfo.jl b/src/varinfo.jl index 485f4599d..d3d18accb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -94,6 +94,8 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) + const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} # TODO: TypedVarInfo carries no information about the type of the actual @@ -171,10 +173,71 @@ function UntypedVarInfo( context = SamplingContext(rng, sampler, context) return last(evaluate!!(model, varinfo, context)) end -function UntypedVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) - return UntypedVarInfo(Random.default_rng(), model, args...) +function UntypedVarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return UntypedVarInfo(Random.default_rng(), model, sampler, context) end +function UntypedVarInfo(model::Model, context::AbstractContext=DefaultContext()) + return UntypedVarInfo(Random.default_rng(), model, SampleFromPrior(), context) +end + +""" + TypedVarInfo(vi::UntypedVarInfo) + +This function finds all the unique `sym`s from the instances of `VarName{sym}` found in +`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the +global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as +a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each +symbol. +""" +function TypedVarInfo(vi::UntypedVarInfo) + meta = vi.metadata + new_metas = Metadata[] + # Symbols of all instances of `VarName{sym}` in `vi.vns` + syms_tuple = Tuple(syms(vi)) + for s in syms_tuple + # Find all indices in `vns` with symbol `s` + inds = findall(vn -> getsym(vn) === s, meta.vns) + n = length(inds) + # New `vns` + sym_vns = getindex.((meta.vns,), inds) + # New idcs + sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) + # New dists + sym_dists = getindex.((meta.dists,), inds) + # New orders + sym_orders = getindex.((meta.orders,), inds) + # New flags + sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + # Extract new ranges and vals + _ranges = getindex.((meta.ranges,), inds) + # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 + _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] + sym_ranges = Vector{eltype(_ranges)}(undef, n) + start = 0 + for i in 1:n + sym_ranges[i] = (start + 1):(start + length(_vals[i])) + start += length(_vals[i]) + end + sym_vals = foldl(vcat, _vals) + + push!( + new_metas, + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags + ), + ) + end + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple{syms_tuple}(Tuple(new_metas)) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end +TypedVarInfo(vi::TypedVarInfo) = vi function TypedVarInfo( rng::Random.AbstractRNG, model::Model, @@ -209,13 +272,20 @@ function TypedVectorVarInfo(vi::TypedVarInfo) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) end +function TypedVectorVarInfo(vi::UntypedVectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end function TypedVectorVarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return TypedVectorVarInfo(TypedVarInfo(rng, model, sampler, context)) + return TypedVectorVarInfo(UntypedVectorVarInfo(rng, model, sampler, context)) end function TypedVectorVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) return TypedVectorVarInfo(Random.default_rng(), model, args...) @@ -264,11 +334,6 @@ end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) -# without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - return VarInfo(rng, model, SampleFromPrior(), context) -end - #### #### Internal functions #### @@ -768,73 +833,6 @@ end #### APIs for typed and untyped VarInfo #### -# VarInfo - -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) - -function TypedVarInfo(vi::UntypedVectorVarInfo) - new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end - -""" - TypedVarInfo(vi::UntypedVarInfo) - -This function finds all the unique `sym`s from the instances of `VarName{sym}` found in -`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the -global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as -a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each -symbol. -""" -function TypedVarInfo(vi::UntypedVarInfo) - meta = vi.metadata - new_metas = Metadata[] - # Symbols of all instances of `VarName{sym}` in `vi.vns` - syms_tuple = Tuple(syms(vi)) - for s in syms_tuple - # Find all indices in `vns` with symbol `s` - inds = findall(vn -> getsym(vn) === s, meta.vns) - n = length(inds) - # New `vns` - sym_vns = getindex.((meta.vns,), inds) - # New idcs - sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) - # New dists - sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) - - # Extract new ranges and vals - _ranges = getindex.((meta.ranges,), inds) - # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64 - _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n] - sym_ranges = Vector{eltype(_ranges)}(undef, n) - start = 0 - for i in 1:n - sym_ranges[i] = (start + 1):(start + length(_vals[i])) - start += length(_vals[i]) - end - sym_vals = foldl(vcat, _vals) - - push!( - new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), - ) - end - logp = getlogp(vi) - num_produce = get_num_produce(vi) - nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -TypedVarInfo(vi::TypedVarInfo) = vi - function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) diff --git a/test/ad.jl b/test/ad.jl index a4f3dbfa7..3b665fd37 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -100,7 +100,7 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) + vi = TypedVarInfo(model) ldf = LogDensityFunction( model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) ) diff --git a/test/compiler.jl b/test/compiler.jl index a0286d405..7ce4edafc 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,7 +193,7 @@ module Issue537 end return x end model = testmodel_missing3([1.0]) - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model @@ -213,7 +213,7 @@ module Issue537 end end false lpold = lp model = testmodel_missing4([1.0]) - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) @test getlogp(varinfo) == lp == lpold # test DPPL#61 @@ -234,13 +234,13 @@ module Issue537 end end end x = [1.0, missing] - VarInfo(gdemo(x)) + TypedVarInfo(gdemo(x)) @test ismissing(x[2]) # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 - vi = VarInfo(gdemo(x)) + vi = TypedVarInfo(gdemo(x)) @test haskey(vi.metadata, :x) - vi = VarInfo(gdemo(x)) + vi = TypedVarInfo(gdemo(x)) @test haskey(vi.metadata, :x) # Non-array variables @@ -339,16 +339,16 @@ module Issue537 end return testmodel end model = makemodel(0.5)([1.0]) - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) @test getlogp(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) - vi1 = VarInfo(f1()) - vi2 = VarInfo(f2()) - vi3 = VarInfo(f3()) + vi1 = TypedVarInfo(f1()) + vi2 = TypedVarInfo(f2()) + vi3 = TypedVarInfo(f3()) @test haskey(vi1.metadata, :y) @test first(Base.keys(vi1.metadata.y)) == @varname(y) @test haskey(vi2.metadata, :y) @@ -434,28 +434,28 @@ module Issue537 end end # No observation. m = demo2(missing, missing) - vi = VarInfo(m) + vi = TypedVarInfo(m) ks = keys(vi) @test @varname(x) ∈ ks @test @varname(y) ∈ ks # Observation in top-level. m = demo2(missing, 1.0) - vi = VarInfo(m) + vi = TypedVarInfo(m) ks = keys(vi) @test @varname(x) ∈ ks @test @varname(y) ∉ ks # Observation in nested model. m = demo2(1000.0, missing) - vi = VarInfo(m) + vi = TypedVarInfo(m) ks = keys(vi) @test @varname(x) ∉ ks @test @varname(y) ∈ ks # Observe all. m = demo2(1000.0, 0.5) - vi = VarInfo(m) + vi = TypedVarInfo(m) ks = keys(vi) @test isempty(ks) @@ -479,7 +479,7 @@ module Issue537 end return z ~ Normal(sub1 + sub2 + 100, 1.0) end m = demo_useval(missing, missing) - vi = VarInfo(m) + vi = TypedVarInfo(m) ks = keys(vi) @test @varname(sub1.x) ∈ ks @test @varname(sub2.x) ∈ ks @@ -512,7 +512,7 @@ module Issue537 end ys = [randn(10), randn(10)] m = demo(ys) - vi = VarInfo(m) + vi = TypedVarInfo(m) for vn in [@varname(α), @varname(μ), @varname(σ), @varname(ar1_1.η), @varname(ar1_2.η)] @@ -600,7 +600,7 @@ module Issue537 end # Make sure that a return-value of `x = 1` isn't combined into # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 - empty_vi = VarInfo() + empty_vi = TypedVarInfo() retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} @@ -740,7 +740,7 @@ module Issue537 end @test model() isa NamedTuple{(:x, :y)} # `VarInfo` should only contain `x`. - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) @test haskey(varinfo, @varname(x)) @test !haskey(varinfo, @varname(y)) diff --git a/test/deprecated.jl b/test/deprecated.jl index 500d3eb7f..25c4fc2b9 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -29,7 +29,7 @@ return a, b end @test outer()() isa Tuple{Float64,Float64} - vi = VarInfo(outer()) + vi = TypedVarInfo(outer()) @test @varname(x) in keys(vi) @test @varname(sub.x) in keys(vi) end @@ -46,7 +46,7 @@ @test model() == y_val x_val = 1.5 - vi = VarInfo(outer(y_val)) + vi = TypedVarInfo(outer(y_val)) DynamicPPL.setindex!!(vi, x_val, @varname(x)) @test logprior(model, vi) ≈ logpdf(Normal(), x_val) @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 73a0510e9..e9b8d381c 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -13,7 +13,7 @@ using Test: @test, @testset MODEL_SIZE = 10 @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 933bfb1d1..78f972b07 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -79,7 +79,7 @@ @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed - typed_vi = VarInfo(model) + typed_vi = TypedVarInfo(model) f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, typed_vi ) diff --git a/test/model.jl b/test/model.jl index 447a9ecaa..984b8b196 100644 --- a/test/model.jl +++ b/test/model.jl @@ -36,7 +36,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = GDEMO_DEFAULT # sample from model and extract variables - vi = VarInfo(model) + vi = TypedVarInfo(model) s = vi[@varname(s)] m = vi[@varname(m)] @@ -67,7 +67,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct mapping of varname symbols to varname-parent symbols. # Here, varname_leaves is used to ensure compatibility with the # variables stored in the chain - var_info = VarInfo(model) + var_info = TypedVarInfo(model) chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) @@ -219,7 +219,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = GDEMO_DEFAULT # sample from model and extract variables - vi = VarInfo(model) + vi = TypedVarInfo(model) # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. @@ -233,8 +233,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - spl = SampleFromPrior() - vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) + vi = TypedVarInfo(model) vi = link!!(vi, model) for i in 1:10 @@ -250,7 +249,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, VectorVarInfo" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() for i in 1:10 - vi = VarInfo(model) + vi = TypedVarInfo(model) @test vi[@varname(x)] >= vi[@varname(m)] end end @@ -454,7 +453,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end for model in (outer_auto_prefix(), outer_manual_prefix()) - vi = VarInfo(model) + vi = TypedVarInfo(model) vns = Set(keys(values_as_in_model(model, false, vi))) @test vns == Set([@varname(a.x), @varname(b.x)]) end @@ -482,8 +481,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = product_dirichlet() varinfos = [ - DynamicPPL.untyped_varinfo(model), - DynamicPPL.typed_varinfo(model), + DynamicPPL.UntypedVarInfo(model), + DynamicPPL.TypedVarInfo(model), DynamicPPL.typed_simple_varinfo(model), DynamicPPL.untyped_simple_varinfo(model), ] diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 61c842638..a7c99adae 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -5,7 +5,7 @@ example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. - vi = VarInfo(model) + vi = TypedVarInfo(model) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..addcf5202 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -192,7 +192,7 @@ initial_z = 15 initial_x = [0.2, 0.5] model = constrained_uniform(n) - vi = VarInfo(model) + vi = TypedVarInfo(model) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( vi, [initial_z, initial_x], model diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 8e48814a4..60f392f92 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -92,7 +92,7 @@ SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), - VarInfo(model), + TypedVarInfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) diff --git a/test/test_util.jl b/test/test_util.jl index 87c69b5fe..2c5181b0c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -10,7 +10,7 @@ const gdemo_default = gdemo_d() # TODO(penelopeysm): Remove this (and also test/compat/ad.jl) function test_model_ad(model, logp_manual) - vi = VarInfo(model) + vi = TypedVarInfo(model) x = vi[:] # Log probabilities using the model. diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 72c439db8..e1490c280 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,6 +1,6 @@ @testset "threadsafe.jl" begin @testset "constructor" begin - vi = VarInfo(gdemo_default) + vi = TypedVarInfo(gdemo_default) threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi @@ -11,7 +11,7 @@ # TODO: Add more tests of the public API @testset "API" begin - vi = VarInfo(gdemo_default) + vi = TypedVarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) lp = getlogp(vi) diff --git a/test/utils.jl b/test/utils.jl index d683f132d..91bba082e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -7,7 +7,7 @@ end model = testmodel() - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) @test iszero(lp_before) @test getlogp(varinfo) == lp_after == 42 end diff --git a/test/varinfo.jl b/test/varinfo.jl index 74feb42f6..42e85ba8c 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -43,8 +43,7 @@ end end model = gdemo(1.0, 2.0) - vi = VarInfo(DynamicPPL.Metadata()) - model(vi, SampleFromUniform()) + vi = UntypedVarInfo(model, SampleFromUniform()) tvi = TypedVarInfo(vi) meta = vi.metadata @@ -160,7 +159,7 @@ end unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo(DynamicPPL.Metadata()) + vi = VarInfo() test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end @@ -206,16 +205,10 @@ end m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() - ) - vi_untyped = VarInfo(DynamicPPL.Metadata()) - vi_vnv = VarInfo(DynamicPPL.VarNamedVector()) - vi_vnv_typed = VarInfo( - model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector() - ) - model(vi_untyped, SampleFromPrior()) - model(vi_vnv, SampleFromPrior()) + vi_typed = TypedVarInfo(model) + vi_untyped = UntypedVarInfo(model) + vi_vnv = DynamicPPL.UntypedVectorVarInfo(model) + vi_vnv_typed = DynamicPPL.TypedVectorVarInfo(model) model_name = model == model_uv ? "univariate" : "multivariate" @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ @@ -318,7 +311,7 @@ end return x ~ filldist(MvNormal([1, 100], I), 2) end - vi = VarInfo(demo()) + vi = TypedVarInfo(demo()) vals_prev = vi.metadata.x.vals ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @@ -339,7 +332,7 @@ end `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) + var_info = TypedVarInfo(model) θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) θ_new = var_info[:] @@ -460,8 +453,7 @@ end # in the unconstrained space for `VarInfo` without having `vn` # present in the `varinfo`. ## `UntypedVarInfo` - vi = VarInfo() - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + vi = UntypedVarInfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -470,7 +462,7 @@ end @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `TypedVarInfo` - vi = VarInfo(model) + vi = TypedVarInfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -821,8 +813,8 @@ end model_left = demo_merge_different_y() model_right = demo_merge_different_z() - varinfo_left = VarInfo(model_left) - varinfo_right = VarInfo(model_right) + varinfo_left = TypedVarInfo(model_left) + varinfo_right = TypedVarInfo(model_right) varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) @@ -877,7 +869,7 @@ end return x end model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) + varinfo1 = DynamicPPL.link!!(DynamicPPL.UntypedVarInfo(model1), model1) # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) @@ -1020,7 +1012,7 @@ end @testset "issue #842" begin model = DynamicPPL.TestUtils.DEMO_MODELS[1] - varinfo = VarInfo(model) + varinfo = TypedVarInfo(model) n = length(varinfo[:]) # `Bool`. From 2ed9a59432790b004bc264503a48ec7abc1118b1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 18:54:32 +0100 Subject: [PATCH 03/12] NTVarInfo --- HISTORY.md | 6 +++ src/deprecated.jl | 3 ++ src/simple_varinfo.jl | 4 +- src/varinfo.jl | 108 +++++++++++++++++++++--------------------- 4 files changed, 65 insertions(+), 56 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 30ebcbf57..12416f1e5 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -27,6 +27,12 @@ The reason for this change is that there were several flavours of VarInfo. Some, like TypedVarInfo, were easy to construct because we had convenience methods for them; however, the others were more difficult. This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. +The `untyped_varinfo` and `typed_varinfo` functions have also been removed; you can use `UntypedVarInfo` and `TypedVarInfo` as direct replacements. + +Finally, `TypedVarInfo` is no longer a type. +It has been replaced with `NTVarInfo`. +If you were dispatching on this, you should replace it with `NTVarInfo` instead. + ### VarName prefixing behaviour The way in which VarNames in submodels are prefixed has been changed. diff --git a/src/deprecated.jl b/src/deprecated.jl index bc00d0aec..51ff992ff 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -11,3 +11,6 @@ Base.@deprecate VarInfo( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) TypedVarInfo(model, sampler, context) +Base.@deprecate VarInfo(model::Model, context::AbstractContext=DefaultContext()) TypedVarInfo( + model, context +) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 064483ddd..74cbfb231 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. $(FIELDS) # Notes -The major differences between this and `TypedVarInfo` are: +The major differences between this and `NTVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. 3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either @@ -244,7 +244,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} +function SimpleVarInfo(vi::NTVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) end function SimpleVarInfo{T}( diff --git a/src/varinfo.jl b/src/varinfo.jl index d3d18accb..0e0ddc67a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -83,7 +83,7 @@ for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If `vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `TypedVarInfo`. +symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `NTVarInfo`. Note: It is the user's responsibility to ensure that each "symbol" is visited at least once whenever the model is called, regardless of any stochastic branching. Each symbol @@ -98,16 +98,12 @@ VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} -# TODO: TypedVarInfo carries no information about the type of the actual -# metadata i.e. the elements of the NamedTuple. It could be Metadata or it -# could be VarNamedVector. Calling TypedVarInfo(model) will result in a -# TypedVarInfo where the elements are Metadata. +# TODO: NTVarInfo carries no information about the type of the actual metadata +# i.e. the elements of the NamedTuple. It could be Metadata or it could be # Resolving this ambiguity would likely require us to replace NamedTuple with # something which carried both its keys as well as its values' types as type # parameters. -# Note that below we also define a function TypedVectorVarInfo, which generates -# a TypedVarInfo where the metadata is a NamedTuple of VarNameVectors. -const TypedVarInfo = VarInfo{<:NamedTuple} +const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } @@ -145,7 +141,7 @@ end function has_varnamedvector(vi::VarInfo) return vi.metadata isa VarNamedVector || - (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) + (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end ######################## @@ -237,7 +233,12 @@ function TypedVarInfo(vi::UntypedVarInfo) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) return VarInfo(nt, Ref(logp), Ref(num_produce)) end -TypedVarInfo(vi::TypedVarInfo) = vi +function TypedVarInfo(vi::NTVarInfo) + # This function preserves the behaviour of TypedVarInfo(vi) where vi is + # already a TypedVarInfo + has_varnamedvector(vi) && error("Cannot convert TypedVectorVarInfo to TypedVarInfo") + return vi +end function TypedVarInfo( rng::Random.AbstractRNG, model::Model, @@ -267,7 +268,7 @@ function UntypedVectorVarInfo(model::Model, args::Union{AbstractSampler,Abstract return UntypedVectorVarInfo(UntypedVarInfo(Random.default_rng(), model, args...)) end -function TypedVectorVarInfo(vi::TypedVarInfo) +function TypedVectorVarInfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) @@ -297,7 +298,7 @@ end Return the length of the vector representation of `varinfo`. """ vector_length(varinfo::VarInfo) = length(varinfo.metadata) -vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) @@ -588,7 +589,7 @@ setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val Return the metadata in `vi` that belongs to `vn`. """ getmetadata(vi::VarInfo, vn::VarName) = vi.metadata -getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) +getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) """ getidx(vi::VarInfo, vn::VarName) @@ -629,7 +630,7 @@ end Return the range corresponding to `varname` in the vector representation of `varinfo`. """ vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) -function vector_getrange(vi::TypedVarInfo, vn::VarName) +function vector_getrange(vi::NTVarInfo, vn::VarName) offset = 0 for md in values(vi.metadata) # First, we need to check if `vn` is in `md`. @@ -651,8 +652,8 @@ Return the range corresponding to `varname` in the vector representation of `var function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) return map(Base.Fix1(vector_getrange, varinfo), varname) end -# Specialized version for `TypedVarInfo`. -function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) +# Specialized version for `NTVarInfo`. +function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName}) # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. @@ -712,7 +713,7 @@ end getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) # NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. # See for example https://github.com/JuliaLang/julia/pull/46381. -function getindex_internal(vi::TypedVarInfo, ::Colon) +function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end function getindex_internal(md::Metadata, ::Colon) @@ -772,10 +773,10 @@ settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) Returns a tuple of the unique symbols of random variables in `vi`. """ syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols -syms(vi::TypedVarInfo) = keys(vi.metadata) +syms(vi::NTVarInfo) = keys(vi.metadata) _getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) -_getidcs(vi::TypedVarInfo) = _getidcs(vi.metadata) +_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} exprs = [] @@ -790,12 +791,11 @@ end findinds(vnv::VarNamedVector) = 1:length(vnv.varnames) """ - all_varnames_grouped_by_symbol(vi::TypedVarInfo) + all_varnames_grouped_by_symbol(vi::NTVarInfo) Return a `NamedTuple` of the variables in `vi` grouped by symbol. """ -all_varnames_grouped_by_symbol(vi::TypedVarInfo) = - all_varnames_grouped_by_symbol(vi.metadata) +all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata) @generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names} expr = Expr(:tuple) @@ -855,8 +855,8 @@ Base.keys(vi::VarInfo) = Base.keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. -Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] -@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} +Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] +@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names} expr = Expr(:call) push!(expr.args, :vcat) @@ -919,7 +919,7 @@ _isempty(vnv::VarNamedVector) = isempty(vnv) return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...) end -function link!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -973,13 +973,13 @@ function _link!(vi::UntypedVarInfo, vns) end end -# If we try to _link! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link!(vi::NTVarInfo, vns::VarNameTuple) return _link!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::TypedVarInfo, vns::NamedTuple) +function _link!(vi::NTVarInfo, vns::NamedTuple) return _link!(vi.metadata, vi, vns) end @@ -1023,7 +1023,7 @@ end return expr end -function invlink!!(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1085,13 +1085,13 @@ function _invlink!(vi::UntypedVarInfo, vns) end end -# If we try to _invlink! a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink!(vi::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink!(vi::NTVarInfo, vns::VarNameTuple) return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::TypedVarInfo, vns::NamedTuple) +function _invlink!(vi::NTVarInfo, vns::NamedTuple) return _invlink!(vi.metadata, vi, vns) end @@ -1142,7 +1142,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) return vi end -function link(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function link(::DynamicTransformation, vi::NTVarInfo, model::Model) return _link(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1177,13 +1177,13 @@ function _link(model::Model, varinfo::VarInfo, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -# If we try to _link a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _link(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _link(model, varinfo, group_varnames_by_symbol(vns)) end -function _link(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _link_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1278,7 +1278,7 @@ function _link_metadata!!( return metadata end -function invlink(::DynamicTransformation, vi::TypedVarInfo, model::Model) +function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) return _invlink(model, vi, all_varnames_grouped_by_symbol(vi)) end @@ -1318,13 +1318,13 @@ function _invlink(model::Model, varinfo::VarInfo, vns) ) end -# If we try to _invlink a TypedVarInfo with a Tuple of VarNames, first convert it to a -# NamedTuple that matches the structure of the TypedVarInfo. -function _invlink(model::Model, varinfo::TypedVarInfo, vns::VarNameTuple) +# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a +# NamedTuple that matches the structure of the NTVarInfo. +function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) return _invlink(model, varinfo, group_varnames_by_symbol(vns)) end -function _invlink(model::Model, varinfo::TypedVarInfo, vns::NamedTuple) +function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) @@ -1415,7 +1415,7 @@ end # TODO(mhauru) The treatment of the case when some variables are linked and others are not # should be revised. It used to be the case that for UntypedVarInfo `islinked` returned -# whether the first variable was linked. For TypedVarInfo we did an OR over the first +# whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. """ @@ -1559,7 +1559,7 @@ Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) Check whether `vn` has a value in `vi`. """ Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::TypedVarInfo, vn::VarName) +function Base.haskey(vi::NTVarInfo, vn::VarName) md_haskey = map(vi.metadata) do metadata haskey(metadata, vn) end @@ -1622,12 +1622,12 @@ the `VarInfo` `vi`, mutating if it makes sense. function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist" + elseif vi isa NTVarInfo + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" end sym = getsym(vn) - if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. val = tovec(r) md = Metadata( @@ -1658,8 +1658,8 @@ function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) return push!(vi, vn, val, args...) end -# TODO(mhauru) push! can't be implemented in-place for TypedVarInfo if the symbol doesn't -# exist in the TypedVarInfo already. We could implement it in the cases where it it does +# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't +# exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. function Base.push!(meta::Metadata, vn, r, dist, num_produce) @@ -1781,7 +1781,7 @@ function set_retained_vns_del!(vi::UntypedVarInfo) end return nothing end -function set_retained_vns_del!(vi::TypedVarInfo) +function set_retained_vns_del!(vi::NTVarInfo) idcs = _getidcs(vi) return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) end @@ -1842,12 +1842,12 @@ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) return vi end -function _apply!(kernel!, vi::TypedVarInfo, values, keys) +function _apply!(kernel!, vi::NTVarInfo, values, keys) return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) end @generated function _typed_apply!( - kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, keys + kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys ) where {names} updates = map(names) do n quote From 5cc1287724f3bc9f86ef7766721bcb18e22886a8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 19:15:21 +0100 Subject: [PATCH 04/12] Fix tests --- src/DynamicPPL.jl | 1 + src/deprecated.jl | 4 +--- src/test_utils/varinfo.jl | 10 ++++------ src/varinfo.jl | 2 +- test/compiler.jl | 2 +- test/model.jl | 8 ++++---- test/test_util.jl | 6 +++--- test/varinfo.jl | 2 +- 8 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9f45718c5..053f1bcfb 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -48,6 +48,7 @@ export AbstractVarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, + NTVarInfo, push!!, empty!!, subset, diff --git a/src/deprecated.jl b/src/deprecated.jl index 51ff992ff..a54142867 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -11,6 +11,4 @@ Base.@deprecate VarInfo( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) TypedVarInfo(model, sampler, context) -Base.@deprecate VarInfo(model::Model, context::AbstractContext=DefaultContext()) TypedVarInfo( - model, context -) +Base.@deprecate VarInfo(model::Model, context::AbstractContext) TypedVarInfo(model, context) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6a655ded4..c092201de 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -27,12 +27,10 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) - vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) - model(vi_untyped_metadata) - model(vi_untyped_vnv) - vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) - vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + vi_untyped_metadata = UntypedVarInfo(model) + vi_untyped_vnv = DynamicPPL.UntypedVectorVarInfo(model) + vi_typed_metadata = TypedVarInfo(vi_untyped_metadata) + vi_typed_vnv = DynamicPPL.TypedVectorVarInfo(vi_untyped_vnv) # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0e0ddc67a..8eb4056a9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -176,7 +176,7 @@ function UntypedVarInfo( ) return UntypedVarInfo(Random.default_rng(), model, sampler, context) end -function UntypedVarInfo(model::Model, context::AbstractContext=DefaultContext()) +function UntypedVarInfo(model::Model, context::AbstractContext) return UntypedVarInfo(Random.default_rng(), model, SampleFromPrior(), context) end diff --git a/test/compiler.jl b/test/compiler.jl index 7ce4edafc..84c19e55f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -600,7 +600,7 @@ module Issue537 end # Make sure that a return-value of `x = 1` isn't combined into # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 - empty_vi = TypedVarInfo() + empty_vi = VarInfo() retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} diff --git a/test/model.jl b/test/model.jl index 984b8b196..a41accaeb 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,9 +25,9 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end -is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true -is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false +is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true +is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -399,7 +399,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = filter( - is_typed_varinfo, + is_type_stable_varinfo, DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos diff --git a/test/test_util.jl b/test/test_util.jl index 2c5181b0c..a1061d89a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -35,12 +35,12 @@ Return string representing a short description of `vi`. """ short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" +function short_varinfo_name(vi::NTVarInfo) + DynamicPPL.has_varnamedvector(vi) && return "TypedVectorVarInfo" return "TypedVarInfo" end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 42e85ba8c..a79dba816 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -101,7 +101,7 @@ end @test vi[vn] == 2 * r # TODO(mhauru) Implement these functions for other VarInfo types too. - if vi isa DynamicPPL.VectorVarInfo + if vi isa DynamicPPL.UntypedVectorVarInfo delete!(vi, vn) @test isempty(vi) vi = push!!(vi, vn, r, dist) From d8c360faf39adfbaffa2cbbbe4eae57e27884ba6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 19:35:32 +0100 Subject: [PATCH 05/12] More fixes --- benchmarks/src/DynamicPPLBenchmarks.jl | 6 ++-- docs/src/api.md | 3 +- docs/src/internals/varinfo.md | 4 +-- src/varinfo.jl | 38 +++++++++++++++++++++++++- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index a0f45a81b..8c0fc8a80 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, VarName using BenchmarkTools: BenchmarkGroup, @benchmarkable using DynamicPPL: DynamicPPL using ADTypes: ADTypes @@ -52,8 +52,8 @@ end Create a benchmark suite for `model` using the selected varinfo type and AD backend. Available varinfo choices: - • `:untyped` → uses `VarInfo()` - • `:typed` → uses `VarInfo(model)` + • `:untyped` → uses `UntypedVarInfo(model)` + • `:typed` → uses `TypedVarInfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) diff --git a/docs/src/api.md b/docs/src/api.md index e33d49cc5..38cc3af5e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -295,6 +295,7 @@ But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. ```@docs VarInfo +UntypedVarInfo TypedVarInfo ``` @@ -448,7 +449,7 @@ Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a give DynamicPPL.default_varinfo ``` -There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: +There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.TypedVarInfo`](@ref) or [`DynamicPPL.UntypedVarInfo`](@ref), depending on which supports the model: ```@docs DynamicPPL.Experimental.determine_suitable_varinfo diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index 03f6e22c0..de1a5e758 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi ```@example varinfo-design # Type-unstable -varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv = DynamicPPL.UntypedVectorVarInfo(varinfo_untyped) varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] ``` ```@example varinfo-design # Type-stable -varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed) +varinfo_typed_vnv = DynamicPPL.TypedVectorVarInfo(varinfo_typed) varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] ``` diff --git a/src/varinfo.jl b/src/varinfo.jl index 8eb4056a9..2ac379d4f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -151,7 +151,8 @@ end """ UntypedVarInfo([rng, ]model[, sampler, context, metadata]) -Return an untyped varinfo object for the given `model` and `context`. +Return a VarInfo object for the given `model` and `context`, which has just a +single `Metadata` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation @@ -181,6 +182,17 @@ function UntypedVarInfo(model::Model, context::AbstractContext) end """ + TypedVarInfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +`Metadata` structs as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. + TypedVarInfo(vi::UntypedVarInfo) This function finds all the unique `sym`s from the instances of `VarName{sym}` found in @@ -251,6 +263,18 @@ function TypedVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext} return TypedVarInfo(Random.default_rng(), model, args...) end +""" + UntypedVectorVarInfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has just a +single `VarNamedVector` as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" function UntypedVectorVarInfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) lp = getlogp(vi) @@ -268,6 +292,18 @@ function UntypedVectorVarInfo(model::Model, args::Union{AbstractSampler,Abstract return UntypedVectorVarInfo(UntypedVarInfo(Random.default_rng(), model, args...)) end +""" + TypedVectorVarInfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a +NamedTuple of `VarNamedVector`s as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" function TypedVectorVarInfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) lp = getlogp(vi) From 460fb2f7bb9451885ee2e59ea4c12555dcf22835 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 19:52:35 +0100 Subject: [PATCH 06/12] Fixes --- src/model.jl | 8 +++--- src/model_utils.jl | 12 ++++---- src/submodel_macro.jl | 12 ++++---- src/varinfo.jl | 66 +++++++++++++++++++++++++++++++++++++------ 4 files changed, 73 insertions(+), 25 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7ac1e1e1e..76ad4e08d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -433,7 +433,7 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. - keys(VarInfo(cm)) + keys(TypedVarInfo(cm)) 1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m @@ -446,7 +446,7 @@ julia> conditioned(cm)[@varname(x)] julia> conditioned(cm)[@varname(a.m)] 1.0 -julia> keys(VarInfo(cm)) # No variables are sampled +julia> keys(TypedVarInfo(cm)) # No variables are sampled VarName[] ``` """ @@ -773,7 +773,7 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. - keys(VarInfo(cm)) + keys(TypedVarInfo(cm)) 1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m @@ -786,7 +786,7 @@ julia> fixed(cm)[@varname(x)] julia> fixed(cm)[@varname(a.m)] 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled +julia> keys(TypedVarInfo(cm)) # <= no variables are sampled VarName[] ``` """ diff --git a/src/model_utils.jl b/src/model_utils.jl index ac4ec7022..290231014 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -4,7 +4,7 @@ Return `true` if all variable names in `model`/`varinfo` are in `chain`. """ -varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) +varnames_in_chain(model::Model, chain) = varnames_in_chain(TypedVarInfo(model), chain) function varnames_in_chain(varinfo::VarInfo, chain) return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo)) end @@ -16,7 +16,7 @@ end Return `out` with `true` for all variable names in `model` that are in `chain`. """ function varnames_in_chain!(model::Model, chain, out) - return varnames_in_chain!(VarInfo(model), chain, out) + return varnames_in_chain!(TypedVarInfo(model), chain, out) end function varnames_in_chain!(varinfo::VarInfo, chain, out) for vn in keys(varinfo) @@ -33,7 +33,7 @@ end Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. """ function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) - return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) + return varname_in_chain(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx) end function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) @@ -60,7 +60,7 @@ This differs from [`varname_in_chain`](@ref) in that it returns a dictionary rather than a single boolean. This can be quite useful for debugging purposes. """ function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) - return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) + return varname_in_chain!(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx, out) end function varname_in_chain!( @@ -132,7 +132,7 @@ Mutate `out` to map each variable name in `model`/`varinfo` to its value in `chain` at `chain_idx` and `iteration_idx`. """ function values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out) - return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out) + return values_from_chain(TypedVarInfo(model), chain, chain_idx, iteration_idx, out) end function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out) @@ -197,7 +197,7 @@ julia> conditioned_model() # <= results in same values as the `first(iter)` abo ``` """ function value_iterator_from_chain(model::Model, chain) - return value_iterator_from_chain(VarInfo(model), chain) + return value_iterator_from_chain(TypedVarInfo(model), chain) end function value_iterator_from_chain(vi::AbstractVarInfo, chain) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 3089a9a97..b4090e116 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -137,7 +137,7 @@ julia> # When `prefix` is unspecified, no prefix is used. @model submodel_noprefix() = @submodel a = inner() submodel_noprefix (generic function with 2 methods) -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) +julia> @varname(x) in keys(TypedVarInfo(submodel_noprefix())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -147,7 +147,7 @@ julia> # Explicitely don't use any prefix. @model submodel_prefix_false() = @submodel prefix=false a = inner() submodel_prefix_false (generic function with 2 methods) -julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) +julia> @varname(x) in keys(TypedVarInfo(submodel_prefix_false())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -157,7 +157,7 @@ julia> # Automatically determined from `a`. @model submodel_prefix_true() = @submodel prefix=true a = inner() submodel_prefix_true (generic function with 2 methods) -julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) +julia> @varname(a.x) in keys(TypedVarInfo(submodel_prefix_true())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -167,7 +167,7 @@ julia> # Using a static string. @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() submodel_prefix_string (generic function with 2 methods) -julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) +julia> @varname(var"my prefix".x) in keys(TypedVarInfo(submodel_prefix_string())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -177,7 +177,7 @@ julia> # Using string interpolation. @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() submodel_prefix_interpolation (generic function with 2 methods) -julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) +julia> @varname(inner.x) in keys(TypedVarInfo(submodel_prefix_interpolation())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression. @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() submodel_prefix_expr (generic function with 2 methods) -julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) +julia> @varname(var"3".x) in keys(TypedVarInfo(submodel_prefix_expr())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/varinfo.jl b/src/varinfo.jl index 2ac379d4f..51fb482c4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -175,10 +175,16 @@ function UntypedVarInfo( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + # No rng return UntypedVarInfo(Random.default_rng(), model, sampler, context) end +function UntypedVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return UntypedVarInfo(rng, model, SampleFromPrior(), context) +end function UntypedVarInfo(model::Model, context::AbstractContext) - return UntypedVarInfo(Random.default_rng(), model, SampleFromPrior(), context) + # No sampler, no rng + return UntypedVarInfo(model, SampleFromPrior(), context) end """ @@ -259,8 +265,21 @@ function TypedVarInfo( ) return TypedVarInfo(UntypedVarInfo(rng, model, sampler, context)) end -function TypedVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) - return TypedVarInfo(Random.default_rng(), model, args...) +function TypedVarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return TypedVarInfo(Random.default_rng(), model, sampler, context) +end +function TypedVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return TypedVarInfo(rng, model, SampleFromPrior(), context) +end +function TypedVarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return TypedVarInfo(model, SampleFromPrior(), context) end """ @@ -288,8 +307,23 @@ function UntypedVectorVarInfo( ) return UntypedVectorVarInfo(UntypedVarInfo(rng, model, sampler, context)) end -function UntypedVectorVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) - return UntypedVectorVarInfo(UntypedVarInfo(Random.default_rng(), model, args...)) +function UntypedVectorVarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return UntypedVectorVarInfo(Random.default_rng(), model, sampler, context) +end +function UntypedVectorVarInfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) + # No sampler + return UntypedVectorVarInfo(rng, model, SampleFromPrior(), context) +end +function UntypedVectorVarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return UntypedVectorVarInfo(model, SampleFromPrior(), context) end """ @@ -324,8 +358,21 @@ function TypedVectorVarInfo( ) return TypedVectorVarInfo(UntypedVectorVarInfo(rng, model, sampler, context)) end -function TypedVectorVarInfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) - return TypedVectorVarInfo(Random.default_rng(), model, args...) +function TypedVectorVarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return TypedVectorVarInfo(Random.default_rng(), model, sampler, context) +end +function TypedVectorVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return TypedVectorVarInfo(rng, model, SampleFromPrior(), context) +end +function TypedVectorVarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return TypedVectorVarInfo(model, SampleFromPrior(), context) end """ @@ -1956,7 +2003,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m); +julia> var_info = DynamicPPL.TypedVarInfo(rng, m); julia> var_info[@varname(m)] -0.6702516921145671 @@ -2020,7 +2067,8 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. +julia> var_info = DynamicPPL.TypedVarInfo(rng, m); + # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] -0.6702516921145671 From 1f7b201e49b20b523396270b9c08ded5917a3baa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 20:56:51 +0100 Subject: [PATCH 07/12] Fixes --- test/ext/DynamicPPLJETExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 78f972b07..21a817306 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -14,7 +14,7 @@ @model demo2() = x ~ Normal() @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo @model function demo3() # Just making sure that nothing strange happens when type inference fails. @@ -53,7 +53,7 @@ end # Should pass if we're only checking the tilde statements. @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.TypedVarInfo + DynamicPPL.NTVarInfo # Should fail if we're including errors in the model body. @test DynamicPPL.Experimental.determine_suitable_varinfo( demo5(); only_ddpl=false @@ -75,7 +75,7 @@ ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. - is_typed = varinfo isa DynamicPPL.TypedVarInfo + is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed From 416d729cde291c2250bb4fe7fccedb9e8ab55821 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 1 Apr 2025 21:46:14 +0100 Subject: [PATCH 08/12] Fixes --- ext/DynamicPPLMCMCChainsExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..1237ea396 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -110,7 +110,7 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - varinfo = DynamicPPL.VarInfo(model) + varinfo = DynamicPPL.TypedVarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) @@ -245,7 +245,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) + varinfo = DynamicPPL.TypedVarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. From 04e6cb8fbcf06763b23a17563ad2b8d41c9aa729 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Apr 2025 15:57:21 +0100 Subject: [PATCH 09/12] Use lowercase functions, don't deprecate VarInfo --- HISTORY.md | 46 ++++---- benchmarks/src/DynamicPPLBenchmarks.jl | 10 +- docs/src/api.md | 11 +- docs/src/internals/varinfo.md | 10 +- ext/DynamicPPLJETExt.jl | 4 +- ext/DynamicPPLMCMCChainsExt.jl | 4 +- src/DynamicPPL.jl | 3 - src/abstract_varinfo.jl | 10 +- src/deprecated.jl | 13 --- src/experimental.jl | 6 +- src/logdensityfunction.jl | 4 +- src/model.jl | 20 ++-- src/model_utils.jl | 12 +- src/pointwise_logdensities.jl | 6 +- src/sampler.jl | 2 +- src/simple_varinfo.jl | 2 +- src/submodel_macro.jl | 16 +-- src/test_utils/contexts.jl | 2 +- src/test_utils/varinfo.jl | 8 +- src/values_as_in_model.jl | 2 +- src/varinfo.jl | 146 +++++++++++++++---------- test/ad.jl | 2 +- test/compiler.jl | 32 +++--- test/deprecated.jl | 4 +- test/ext/DynamicPPLForwardDiffExt.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 2 +- test/model.jl | 21 ++-- test/pointwise_logdensities.jl | 2 +- test/sampler.jl | 2 +- test/simple_varinfo.jl | 2 +- test/test_util.jl | 18 +-- test/threadsafe.jl | 4 +- test/utils.jl | 2 +- test/varinfo.jl | 49 +++++---- 34 files changed, 257 insertions(+), 222 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 12416f1e5..1af5c2ca3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -8,30 +8,20 @@ `VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. -**The `VarInfo([rng, ]model[, sampler, context, metadata])` constructor has been replaced with the following methods:** +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - 1. `UntypedVarInfo([rng, ]model[, sampler, context])` - 2. `TypedVarInfo([rng, ]model[, sampler, context])` - 3. `DynamicPPL.UntypedVectorVarInfo([rng, ]model[, sampler, context])` - 4. `DynamicPPL.TypedVectorVarInfo([rng, ]model[, sampler, context])` +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. -**If you were not using the `metadata` argument (most likely), then you can directly replace calls to this constructor with `TypedVarInfo` instead.** -That is to say, if you were using `VarInfo(model)`, you can replace this with `TypedVarInfo(model)`. +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `TypedVectorVarInfo` instead. -Note that the `VectorVarInfo` constructors (both `Untyped` and `Typed`) are not exported by default. - -If you were passing a non-empty metadata argument, you should use a different constructor of `VarInfo` instead. - -The reason for this change is that there were several flavours of VarInfo. -Some, like TypedVarInfo, were easy to construct because we had convenience methods for them; however, the others were more difficult. -This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. - -The `untyped_varinfo` and `typed_varinfo` functions have also been removed; you can use `UntypedVarInfo` and `TypedVarInfo` as direct replacements. - -Finally, `TypedVarInfo` is no longer a type. -It has been replaced with `NTVarInfo`. -If you were dispatching on this, you should replace it with `NTVarInfo` instead. +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. ### VarName prefixing behaviour @@ -78,6 +68,20 @@ outer() | (a.x=1.0,) If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +**Other changes** + +While these are technically breaking, they are only internal changes and do not affect the public API. +The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata: + + 1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])` + 2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])` + 3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])` + 4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])` + +The reason for this change is that there were several flavours of VarInfo. +Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult. +This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing. + ## 0.35.5 Several internal methods have been removed: diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 8c0fc8a80..16338de2f 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, SimpleVarInfo, VarName using BenchmarkTools: BenchmarkGroup, @benchmarkable using DynamicPPL: DynamicPPL using ADTypes: ADTypes @@ -52,8 +52,8 @@ end Create a benchmark suite for `model` using the selected varinfo type and AD backend. Available varinfo choices: - • `:untyped` → uses `UntypedVarInfo(model)` - • `:typed` → uses `TypedVarInfo(model)` + • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` + • `:typed` → uses `DynamicPPL.typed_varinfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) @@ -67,9 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: suite = BenchmarkGroup() vi = if varinfo_choice == :untyped - UntypedVarInfo(rng, model) + DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed - TypedVarInfo(rng, model) + DynamicPPL.typed_varinfo(rng, model) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict diff --git a/docs/src/api.md b/docs/src/api.md index 38cc3af5e..f83a96886 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -295,8 +295,13 @@ But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. ```@docs VarInfo -UntypedVarInfo -TypedVarInfo +``` + +```@docs +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo +DynamicPPL.untyped_vector_varinfo +DynamicPPL.typed_vector_varinfo ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). @@ -449,7 +454,7 @@ Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a give DynamicPPL.default_varinfo ``` -There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.TypedVarInfo`](@ref) or [`DynamicPPL.UntypedVarInfo`](@ref), depending on which supports the model: +There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: ```@docs DynamicPPL.Experimental.determine_suitable_varinfo diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index de1a5e758..b04913aaf 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -79,13 +79,13 @@ For example, with the model above we have ```@example varinfo-design # Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.UntypedVarInfo(demo()) +varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) typeof(varinfo_untyped.metadata) ``` ```@example varinfo-design # Type-stable `VarInfo` -varinfo_typed = DynamicPPL.TypedVarInfo(demo()) +varinfo_typed = DynamicPPL.typed_varinfo(demo()) typeof(varinfo_typed.metadata) ``` @@ -154,7 +154,7 @@ For example, we want to optimize code-paths which effectively boil down to inner ```julia # Construct a `VarInfo` with types inferred from `model`. -varinfo = TypedVarInfo(model) +varinfo = VarInfo(model) # Repeatedly sample from `model`. for _ in 1:num_samples @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi ```@example varinfo-design # Type-unstable -varinfo_untyped_vnv = DynamicPPL.UntypedVectorVarInfo(varinfo_untyped) +varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] ``` ```@example varinfo-design # Type-stable -varinfo_typed_vnv = DynamicPPL.TypedVectorVarInfo(varinfo_typed) +varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] ``` diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 5d2d01662..aa95093f2 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -27,7 +27,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true ) # First we try with the typed varinfo. - varinfo = DynamicPPL.TypedVarInfo(model, context) + varinfo = DynamicPPL.typed_varinfo(model, context) # Let's make sure that both evaluation and sampling doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.UntypedVarInfo(model, context) + DynamicPPL.untyped_varinfo(model, context) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 1237ea396..7fcbd6a7c 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -110,7 +110,7 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - varinfo = DynamicPPL.TypedVarInfo(model) + varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) @@ -245,7 +245,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.TypedVarInfo(model) + varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 053f1bcfb..51fa53079 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,10 +45,7 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - UntypedVarInfo, - TypedVarInfo, SimpleVarInfo, - NTVarInfo, push!!, empty!!, subset, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c00c93391..f11b8a3ec 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector) 2.0 ``` -`TypedVarInfo`: +`VarInfo` with `NamedTuple` of `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = TypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -273,11 +273,11 @@ julia> values_as(vi, Vector) 2.0 ``` -`UntypedVarInfo`: +`VarInfo` with `Metadata`: ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = UntypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; @@ -354,7 +354,7 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> varinfo = TypedVarInfo(model); +julia> varinfo = VarInfo(model); julia> keys(varinfo) 4-element Vector{VarName}: diff --git a/src/deprecated.jl b/src/deprecated.jl index a54142867..0bcaae9b7 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,14 +1 @@ @deprecate generated_quantities(model, params) returned(model, params) - -Base.@deprecate VarInfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) TypedVarInfo(rng, model, sampler, context) -Base.@deprecate VarInfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) TypedVarInfo(model, sampler, context) -Base.@deprecate VarInfo(model::Model, context::AbstractContext) TypedVarInfo(model, context) diff --git a/src/experimental.jl b/src/experimental.jl index ff25b96e0..84038803c 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -72,7 +72,7 @@ julia> # Typed varinfo cannot handle this random support model properly ┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. └ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 -julia> vi isa typeof(DynamicPPL.UntypedVarInfo(model)) +julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) true julia> # In contrast, a simple model with no random support can be handled by typed varinfo. @@ -81,7 +81,7 @@ model_with_static_support (generic function with 2 methods) julia> vi = determine_suitable_varinfo(model_with_static_support()); -julia> vi isa typeof(DynamicPPL.TypedVarInfo(model_with_static_support())) +julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) true ``` """ @@ -97,7 +97,7 @@ function determine_suitable_varinfo( # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). - DynamicPPL.TypedVarInfo(model, context) + DynamicPPL.typed_varinfo(model, context) end end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 26025d5c7..a42855f05 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -79,7 +79,7 @@ julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), TypedVarInfo(model)); + f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -109,7 +109,7 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=TypedVarInfo(model), + varinfo::AbstractVarInfo=VarInfo(model), context::AbstractContext=leafcontext(model.context); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) diff --git a/src/model.jl b/src/model.jl index 76ad4e08d..b4d5f6bb7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -433,7 +433,7 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. - keys(TypedVarInfo(cm)) + keys(VarInfo(cm)) 1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m @@ -446,7 +446,7 @@ julia> conditioned(cm)[@varname(x)] julia> conditioned(cm)[@varname(a.m)] 1.0 -julia> keys(TypedVarInfo(cm)) # No variables are sampled +julia> keys(VarInfo(cm)) # No variables are sampled VarName[] ``` """ @@ -773,7 +773,7 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. - keys(TypedVarInfo(cm)) + keys(VarInfo(cm)) 1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: a.m @@ -786,7 +786,7 @@ julia> fixed(cm)[@varname(x)] julia> fixed(cm)[@varname(a.m)] 1.0 -julia> keys(TypedVarInfo(cm)) # <= no variables are sampled +julia> keys(VarInfo(cm)) # <= no variables are sampled VarName[] ``` """ @@ -1037,7 +1037,7 @@ julia> logjoint(demo_model([1., 2.]), chain); ``` """ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = TypedVarInfo(model) # extract variables info from the model + var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1084,7 +1084,7 @@ julia> logprior(demo_model([1., 2.]), chain); ``` """ function logprior(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = TypedVarInfo(model) # extract variables info from the model + var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1131,7 +1131,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); ``` """ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) - var_info = TypedVarInfo(model) # extract variables info from the model + var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) argvals_dict = OrderedDict( vn_parent => @@ -1339,7 +1339,7 @@ julia> @model function demo2(x, y) When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: ```jldoctest submodel-to_submodel -julia> vi = TypedVarInfo(demo2(missing, 0.4)); +julia> vi = VarInfo(demo2(missing, 0.4)); julia> @varname(a.x) in keys(vi) true @@ -1376,7 +1376,7 @@ julia> @model function demo2_no_prefix(x, z) return z ~ Uniform(-a, 1) end; -julia> vi = TypedVarInfo(demo2_no_prefix(missing, 0.4)); +julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` true @@ -1391,7 +1391,7 @@ julia> @model function demo2(x, y, z) return z ~ Uniform(-a, b) end; -julia> vi = TypedVarInfo(demo2(missing, missing, 0.4)); +julia> vi = VarInfo(demo2(missing, missing, 0.4)); julia> @varname(sub1.x) in keys(vi) true diff --git a/src/model_utils.jl b/src/model_utils.jl index 290231014..ac4ec7022 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -4,7 +4,7 @@ Return `true` if all variable names in `model`/`varinfo` are in `chain`. """ -varnames_in_chain(model::Model, chain) = varnames_in_chain(TypedVarInfo(model), chain) +varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) function varnames_in_chain(varinfo::VarInfo, chain) return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo)) end @@ -16,7 +16,7 @@ end Return `out` with `true` for all variable names in `model` that are in `chain`. """ function varnames_in_chain!(model::Model, chain, out) - return varnames_in_chain!(TypedVarInfo(model), chain, out) + return varnames_in_chain!(VarInfo(model), chain, out) end function varnames_in_chain!(varinfo::VarInfo, chain, out) for vn in keys(varinfo) @@ -33,7 +33,7 @@ end Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. """ function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) - return varname_in_chain(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx) + return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) end function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) @@ -60,7 +60,7 @@ This differs from [`varname_in_chain`](@ref) in that it returns a dictionary rather than a single boolean. This can be quite useful for debugging purposes. """ function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) - return varname_in_chain!(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx, out) + return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) end function varname_in_chain!( @@ -132,7 +132,7 @@ Mutate `out` to map each variable name in `model`/`varinfo` to its value in `chain` at `chain_idx` and `iteration_idx`. """ function values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out) - return values_from_chain(TypedVarInfo(model), chain, chain_idx, iteration_idx, out) + return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out) end function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out) @@ -197,7 +197,7 @@ julia> conditioned_model() # <= results in same values as the `first(iter)` abo ``` """ function value_iterator_from_chain(model::Model, chain) - return value_iterator_from_chain(TypedVarInfo(model), chain) + return value_iterator_from_chain(VarInfo(model), chain) end function value_iterator_from_chain(vi::AbstractVarInfo, chain) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index b69ccce67..cb9ea4894 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -226,12 +226,12 @@ julia> @model function demo(x) julia> m = demo([1.0, ]); -julia> ℓ = pointwise_logdensities(m, TypedVarInfo(m)); first(ℓ[@varname(x[1])]) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) -1.4189385332046727 julia> m = demo([1.0; 1.0]); -julia> ℓ = pointwise_logdensities(m, TypedVarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` @@ -240,7 +240,7 @@ function pointwise_logdensities( model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() ) where {T} # Get the data by executing the model once - vi = TypedVarInfo(model) + vi = VarInfo(model) point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) diff --git a/src/sampler.jl b/src/sampler.jl index c15ac6006..49d910fec 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -86,7 +86,7 @@ function default_varinfo( context::AbstractContext, ) init_sampler = initialsampler(sampler) - return TypedVarInfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler, context) end function AbstractMCMC.sample( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 74cbfb231..abf14b8fc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -244,7 +244,7 @@ function SimpleVarInfo{T}( end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} +function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) end function SimpleVarInfo{T}( diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index b4090e116..f6b9c4479 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -24,7 +24,7 @@ julia> @model function demo2(x, y) When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: ```jldoctest submodel -julia> vi = TypedVarInfo(demo2(missing, 0.4)); +julia> vi = VarInfo(demo2(missing, 0.4)); ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -91,7 +91,7 @@ julia> @model function demo2(x, y, z) When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and `sub2.x` will be sampled: ```jldoctest submodelprefix -julia> vi = TypedVarInfo(demo2(missing, missing, 0.4)); +julia> vi = VarInfo(demo2(missing, missing, 0.4)); ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -137,7 +137,7 @@ julia> # When `prefix` is unspecified, no prefix is used. @model submodel_noprefix() = @submodel a = inner() submodel_noprefix (generic function with 2 methods) -julia> @varname(x) in keys(TypedVarInfo(submodel_noprefix())) +julia> @varname(x) in keys(VarInfo(submodel_noprefix())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -147,7 +147,7 @@ julia> # Explicitely don't use any prefix. @model submodel_prefix_false() = @submodel prefix=false a = inner() submodel_prefix_false (generic function with 2 methods) -julia> @varname(x) in keys(TypedVarInfo(submodel_prefix_false())) +julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -157,7 +157,7 @@ julia> # Automatically determined from `a`. @model submodel_prefix_true() = @submodel prefix=true a = inner() submodel_prefix_true (generic function with 2 methods) -julia> @varname(a.x) in keys(TypedVarInfo(submodel_prefix_true())) +julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -167,7 +167,7 @@ julia> # Using a static string. @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() submodel_prefix_string (generic function with 2 methods) -julia> @varname(var"my prefix".x) in keys(TypedVarInfo(submodel_prefix_string())) +julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -177,7 +177,7 @@ julia> # Using string interpolation. @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() submodel_prefix_interpolation (generic function with 2 methods) -julia> @varname(inner.x) in keys(TypedVarInfo(submodel_prefix_interpolation())) +julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 @@ -187,7 +187,7 @@ julia> # Or using some arbitrary expression. @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() submodel_prefix_expr (generic function with 2 methods) -julia> @varname(var"3".x) in keys(TypedVarInfo(submodel_prefix_expr())) +julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) ┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. │ caller = ip:0x0 └ @ Core :-1 diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 5150be64b..7404a9af7 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) # Typed varinfo. - varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped) + varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index c092201de..539872143 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -27,10 +27,10 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped_metadata = UntypedVarInfo(model) - vi_untyped_vnv = DynamicPPL.UntypedVectorVarInfo(model) - vi_typed_metadata = TypedVarInfo(vi_untyped_metadata) - vi_typed_vnv = DynamicPPL.TypedVectorVarInfo(vi_untyped_vnv) + vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_typed_metadata = DynamicPPL.typed_varinfo(model) + vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 6d4cf019d..d3bfd697a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -131,7 +131,7 @@ julia> @model function model_changing_support() julia> model = model_changing_support(); julia> # Construct initial type-stable `VarInfo`. - varinfo = TypedVarInfo(rng, model); + varinfo = VarInfo(rng, model); julia> # Link it so it works in unconstrained space. varinfo_linked = DynamicPPL.link(varinfo, model); diff --git a/src/varinfo.jl b/src/varinfo.jl index 51fb482c4..2791386f1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -95,11 +95,40 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo num_produce::Base.RefValue{Int} end VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +function VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(rng, model, sampler, context) +end +function VarInfo( + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + # No rng + return VarInfo(Random.default_rng(), model, sampler, context) +end +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No rng + return VarInfo(Random.default_rng(), model, sampler, context) +end +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) + # No sampler + return VarInfo(rng, model, SampleFromPrior(), context) +end +function VarInfo(model::Model, context::AbstractContext) + # No sampler, no rng + return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} # TODO: NTVarInfo carries no information about the type of the actual metadata # i.e. the elements of the NamedTuple. It could be Metadata or it could be +# VarNamedVector. # Resolving this ambiguity would likely require us to replace NamedTuple with # something which carried both its keys as well as its values' types as type # parameters. @@ -149,7 +178,7 @@ end ######################## """ - UntypedVarInfo([rng, ]model[, sampler, context, metadata]) + untyped_varinfo([rng, ]model[, sampler, context, metadata]) Return a VarInfo object for the given `model` and `context`, which has just a single `Metadata` as its metadata field. @@ -160,7 +189,7 @@ single `Metadata` as its metadata field. - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. - `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ -function UntypedVarInfo( +function untyped_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), @@ -170,36 +199,25 @@ function UntypedVarInfo( context = SamplingContext(rng, sampler, context) return last(evaluate!!(model, varinfo, context)) end -function UntypedVarInfo( +function untyped_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) # No rng - return UntypedVarInfo(Random.default_rng(), model, sampler, context) + return untyped_varinfo(Random.default_rng(), model, sampler, context) end -function UntypedVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) # No sampler - return UntypedVarInfo(rng, model, SampleFromPrior(), context) + return untyped_varinfo(rng, model, SampleFromPrior(), context) end -function UntypedVarInfo(model::Model, context::AbstractContext) +function untyped_varinfo(model::Model, context::AbstractContext) # No sampler, no rng - return UntypedVarInfo(model, SampleFromPrior(), context) + return untyped_varinfo(model, SampleFromPrior(), context) end """ - TypedVarInfo([rng, ]model[, sampler, context, metadata]) - -Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of -`Metadata` structs as its metadata field. - -# Arguments -- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation -- `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. - - TypedVarInfo(vi::UntypedVarInfo) + typed_varinfo(vi::UntypedVarInfo) This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the @@ -207,7 +225,7 @@ global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `meta a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol. """ -function TypedVarInfo(vi::UntypedVarInfo) +function typed_varinfo(vi::UntypedVarInfo) meta = vi.metadata new_metas = Metadata[] # Symbols of all instances of `VarName{sym}` in `vi.vns` @@ -251,39 +269,53 @@ function TypedVarInfo(vi::UntypedVarInfo) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) return VarInfo(nt, Ref(logp), Ref(num_produce)) end -function TypedVarInfo(vi::NTVarInfo) - # This function preserves the behaviour of TypedVarInfo(vi) where vi is - # already a TypedVarInfo - has_varnamedvector(vi) && error("Cannot convert TypedVectorVarInfo to TypedVarInfo") +function typed_varinfo(vi::NTVarInfo) + # This function preserves the behaviour of typed_varinfo(vi) where vi is + # already a NTVarInfo + has_varnamedvector(vi) && error( + "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata", + ) return vi end -function TypedVarInfo( +""" + typed_varinfo([rng, ]model[, sampler, context, metadata]) + +Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +`Metadata` structs as its metadata field. + +# Arguments +- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation +- `model::Model`: The model for which to create the varinfo object +- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. +""" +function typed_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return TypedVarInfo(UntypedVarInfo(rng, model, sampler, context)) + return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) end -function TypedVarInfo( +function typed_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) # No rng - return TypedVarInfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(Random.default_rng(), model, sampler, context) end -function TypedVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) # No sampler - return TypedVarInfo(rng, model, SampleFromPrior(), context) + return typed_varinfo(rng, model, SampleFromPrior(), context) end -function TypedVarInfo(model::Model, context::AbstractContext) +function typed_varinfo(model::Model, context::AbstractContext) # No sampler, no rng - return TypedVarInfo(model, SampleFromPrior(), context) + return typed_varinfo(model, SampleFromPrior(), context) end """ - UntypedVectorVarInfo([rng, ]model[, sampler, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) Return a VarInfo object for the given `model` and `context`, which has just a single `VarNamedVector` as its metadata field. @@ -294,40 +326,40 @@ single `VarNamedVector` as its metadata field. - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. - `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ -function UntypedVectorVarInfo(vi::UntypedVarInfo) +function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) end -function UntypedVectorVarInfo( +function untyped_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return UntypedVectorVarInfo(UntypedVarInfo(rng, model, sampler, context)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) end -function UntypedVectorVarInfo( +function untyped_vector_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) # No rng - return UntypedVectorVarInfo(Random.default_rng(), model, sampler, context) + return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) end -function UntypedVectorVarInfo( +function untyped_vector_varinfo( rng::Random.AbstractRNG, model::Model, context::AbstractContext ) # No sampler - return UntypedVectorVarInfo(rng, model, SampleFromPrior(), context) + return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) end -function UntypedVectorVarInfo(model::Model, context::AbstractContext) +function untyped_vector_varinfo(model::Model, context::AbstractContext) # No sampler, no rng - return UntypedVectorVarInfo(model, SampleFromPrior(), context) + return untyped_vector_varinfo(model, SampleFromPrior(), context) end """ - TypedVectorVarInfo([rng, ]model[, sampler, context, metadata]) + typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -338,41 +370,43 @@ NamedTuple of `VarNamedVector`s as its metadata field. - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. - `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ -function TypedVectorVarInfo(vi::NTVarInfo) +function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) end -function TypedVectorVarInfo(vi::UntypedVectorVarInfo) +function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) logp = getlogp(vi) num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) return VarInfo(nt, Ref(logp), Ref(num_produce)) end -function TypedVectorVarInfo( +function typed_vector_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return TypedVectorVarInfo(UntypedVectorVarInfo(rng, model, sampler, context)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) end -function TypedVectorVarInfo( +function typed_vector_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) # No rng - return TypedVectorVarInfo(Random.default_rng(), model, sampler, context) + return typed_vector_varinfo(Random.default_rng(), model, sampler, context) end -function TypedVectorVarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function typed_vector_varinfo( + rng::Random.AbstractRNG, model::Model, context::AbstractContext +) # No sampler - return TypedVectorVarInfo(rng, model, SampleFromPrior(), context) + return typed_vector_varinfo(rng, model, SampleFromPrior(), context) end -function TypedVectorVarInfo(model::Model, context::AbstractContext) +function typed_vector_varinfo(model::Model, context::AbstractContext) # No sampler, no rng - return TypedVectorVarInfo(model, SampleFromPrior(), context) + return typed_vector_varinfo(model, SampleFromPrior(), context) end """ @@ -2003,7 +2037,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.TypedVarInfo(rng, m); +julia> var_info = DynamicPPL.VarInfo(rng, m); julia> var_info[@varname(m)] -0.6702516921145671 @@ -2067,7 +2101,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.TypedVarInfo(rng, m); +julia> var_info = DynamicPPL.VarInfo(rng, m); # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. julia> var_info[@varname(m)] diff --git a/test/ad.jl b/test/ad.jl index 3b665fd37..a4f3dbfa7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -100,7 +100,7 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) - vi = TypedVarInfo(model) + vi = VarInfo(model) ldf = LogDensityFunction( model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) ) diff --git a/test/compiler.jl b/test/compiler.jl index 84c19e55f..a0286d405 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,7 +193,7 @@ module Issue537 end return x end model = testmodel_missing3([1.0]) - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model @@ -213,7 +213,7 @@ module Issue537 end end false lpold = lp model = testmodel_missing4([1.0]) - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) @test getlogp(varinfo) == lp == lpold # test DPPL#61 @@ -234,13 +234,13 @@ module Issue537 end end end x = [1.0, missing] - TypedVarInfo(gdemo(x)) + VarInfo(gdemo(x)) @test ismissing(x[2]) # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 - vi = TypedVarInfo(gdemo(x)) + vi = VarInfo(gdemo(x)) @test haskey(vi.metadata, :x) - vi = TypedVarInfo(gdemo(x)) + vi = VarInfo(gdemo(x)) @test haskey(vi.metadata, :x) # Non-array variables @@ -339,16 +339,16 @@ module Issue537 end return testmodel end model = makemodel(0.5)([1.0]) - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) @test getlogp(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) - vi1 = TypedVarInfo(f1()) - vi2 = TypedVarInfo(f2()) - vi3 = TypedVarInfo(f3()) + vi1 = VarInfo(f1()) + vi2 = VarInfo(f2()) + vi3 = VarInfo(f3()) @test haskey(vi1.metadata, :y) @test first(Base.keys(vi1.metadata.y)) == @varname(y) @test haskey(vi2.metadata, :y) @@ -434,28 +434,28 @@ module Issue537 end end # No observation. m = demo2(missing, missing) - vi = TypedVarInfo(m) + vi = VarInfo(m) ks = keys(vi) @test @varname(x) ∈ ks @test @varname(y) ∈ ks # Observation in top-level. m = demo2(missing, 1.0) - vi = TypedVarInfo(m) + vi = VarInfo(m) ks = keys(vi) @test @varname(x) ∈ ks @test @varname(y) ∉ ks # Observation in nested model. m = demo2(1000.0, missing) - vi = TypedVarInfo(m) + vi = VarInfo(m) ks = keys(vi) @test @varname(x) ∉ ks @test @varname(y) ∈ ks # Observe all. m = demo2(1000.0, 0.5) - vi = TypedVarInfo(m) + vi = VarInfo(m) ks = keys(vi) @test isempty(ks) @@ -479,7 +479,7 @@ module Issue537 end return z ~ Normal(sub1 + sub2 + 100, 1.0) end m = demo_useval(missing, missing) - vi = TypedVarInfo(m) + vi = VarInfo(m) ks = keys(vi) @test @varname(sub1.x) ∈ ks @test @varname(sub2.x) ∈ ks @@ -512,7 +512,7 @@ module Issue537 end ys = [randn(10), randn(10)] m = demo(ys) - vi = TypedVarInfo(m) + vi = VarInfo(m) for vn in [@varname(α), @varname(μ), @varname(σ), @varname(ar1_1.η), @varname(ar1_2.η)] @@ -740,7 +740,7 @@ module Issue537 end @test model() isa NamedTuple{(:x, :y)} # `VarInfo` should only contain `x`. - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) @test haskey(varinfo, @varname(x)) @test !haskey(varinfo, @varname(y)) diff --git a/test/deprecated.jl b/test/deprecated.jl index 25c4fc2b9..500d3eb7f 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -29,7 +29,7 @@ return a, b end @test outer()() isa Tuple{Float64,Float64} - vi = TypedVarInfo(outer()) + vi = VarInfo(outer()) @test @varname(x) in keys(vi) @test @varname(sub.x) in keys(vi) end @@ -46,7 +46,7 @@ @test model() == y_val x_val = 1.5 - vi = TypedVarInfo(outer(y_val)) + vi = VarInfo(outer(y_val)) DynamicPPL.setindex!!(vi, x_val, @varname(x)) @test logprior(model, vi) ≈ logpdf(Normal(), x_val) @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index e9b8d381c..73a0510e9 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -13,7 +13,7 @@ using Test: @test, @testset MODEL_SIZE = 10 @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 21a817306..86329a51d 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -79,7 +79,7 @@ @test is_typed # If the test failed, check why it didn't infer a typed varinfo if !is_typed - typed_vi = TypedVarInfo(model) + typed_vi = DynamicPPL.typed_varinfo(model) f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, typed_vi ) diff --git a/test/model.jl b/test/model.jl index a41accaeb..dd5a35fe6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -36,7 +36,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = GDEMO_DEFAULT # sample from model and extract variables - vi = TypedVarInfo(model) + vi = VarInfo(model) s = vi[@varname(s)] m = vi[@varname(m)] @@ -67,7 +67,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct mapping of varname symbols to varname-parent symbols. # Here, varname_leaves is used to ensure compatibility with the # variables stored in the chain - var_info = TypedVarInfo(model) + var_info = VarInfo(model) chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) @@ -219,7 +219,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = GDEMO_DEFAULT # sample from model and extract variables - vi = TypedVarInfo(model) + vi = VarInfo(model) # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. @@ -233,7 +233,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - vi = TypedVarInfo(model) + vi = VarInfo(model) vi = link!!(vi, model) for i in 1:10 @@ -249,8 +249,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "Dynamic constraints, VectorVarInfo" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() for i in 1:10 - vi = TypedVarInfo(model) - @test vi[@varname(x)] >= vi[@varname(m)] + for vi_constructor in + [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] + vi = vi_constructor(model) + @test vi[@varname(x)] >= vi[@varname(m)] + end end end @@ -453,7 +456,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end for model in (outer_auto_prefix(), outer_manual_prefix()) - vi = TypedVarInfo(model) + vi = VarInfo(model) vns = Set(keys(values_as_in_model(model, false, vi))) @test vns == Set([@varname(a.x), @varname(b.x)]) end @@ -481,8 +484,8 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = product_dirichlet() varinfos = [ - DynamicPPL.UntypedVarInfo(model), - DynamicPPL.TypedVarInfo(model), + DynamicPPL.untyped_varinfo(model), + DynamicPPL.typed_varinfo(model), DynamicPPL.typed_simple_varinfo(model), DynamicPPL.untyped_simple_varinfo(model), ] diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index a7c99adae..61c842638 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -5,7 +5,7 @@ example_values = DynamicPPL.TestUtils.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. - vi = TypedVarInfo(model) + vi = VarInfo(model) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end diff --git a/test/sampler.jl b/test/sampler.jl index addcf5202..8c4f1ed96 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -192,7 +192,7 @@ initial_z = 15 initial_x = [0.2, 0.5] model = constrained_uniform(n) - vi = TypedVarInfo(model) + vi = VarInfo(model) @test_throws ArgumentError DynamicPPL.initialize_parameters!!( vi, [initial_z, initial_x], model diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 60f392f92..aa3b592f7 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -92,7 +92,7 @@ SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), - TypedVarInfo(model), + DynamicPPL.typed_varinfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) diff --git a/test/test_util.jl b/test/test_util.jl index a1061d89a..902dd7230 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -10,7 +10,7 @@ const gdemo_default = gdemo_d() # TODO(penelopeysm): Remove this (and also test/compat/ad.jl) function test_model_ad(model, logp_manual) - vi = TypedVarInfo(model) + vi = VarInfo(model) x = vi[:] # Log probabilities using the model. @@ -33,13 +33,17 @@ end Return string representing a short description of `vi`. """ -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::NTVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVectorVarInfo" - return "TypedVarInfo" +function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) + return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +function short_varinfo_name(vi::DynamicPPL.NTVarInfo) + return if DynamicPPL.has_varnamedvector(vi) + "TypedVectorVarInfo" + else + "TypedVarInfo" + end +end +short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" diff --git a/test/threadsafe.jl b/test/threadsafe.jl index e1490c280..72c439db8 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,6 +1,6 @@ @testset "threadsafe.jl" begin @testset "constructor" begin - vi = TypedVarInfo(gdemo_default) + vi = VarInfo(gdemo_default) threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi @@ -11,7 +11,7 @@ # TODO: Add more tests of the public API @testset "API" begin - vi = TypedVarInfo(gdemo_default) + vi = VarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) lp = getlogp(vi) diff --git a/test/utils.jl b/test/utils.jl index 91bba082e..d683f132d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -7,7 +7,7 @@ end model = testmodel() - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) @test iszero(lp_before) @test getlogp(varinfo) == lp_after == 42 end diff --git a/test/varinfo.jl b/test/varinfo.jl index a79dba816..777917aa6 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -34,7 +34,7 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) end @testset "varinfo.jl" begin - @testset "TypedVarInfo with Metadata" begin + @testset "VarInfo with NT of Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -43,8 +43,8 @@ end end model = gdemo(1.0, 2.0) - vi = UntypedVarInfo(model, SampleFromUniform()) - tvi = TypedVarInfo(vi) + vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -115,7 +115,7 @@ end vi = VarInfo() test_base!!(vi) - test_base!!(TypedVarInfo(vi)) + test_base!!(DynamicPPL.typed_varinfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -134,7 +134,7 @@ end vi = VarInfo() test_varinfo_logp!(vi) - test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) test_varinfo_logp!(SimpleVarInfo()) test_varinfo_logp!(SimpleVarInfo(Dict())) test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) @@ -161,15 +161,15 @@ end end vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) + test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) end - @testset "push!! to TypedVarInfo" begin + @testset "push!! to VarInfo with NT of Metadata" begin vn_x = @varname x vn_y = @varname y untyped_vi = VarInfo() untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = TypedVarInfo(untyped_vi) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) @test typed_vi[vn_x] == 1.0 @test typed_vi[vn_y] == 2.0 @@ -205,10 +205,10 @@ end m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = TypedVarInfo(model) - vi_untyped = UntypedVarInfo(model) - vi_vnv = DynamicPPL.UntypedVectorVarInfo(model) - vi_vnv_typed = DynamicPPL.TypedVectorVarInfo(model) + vi_typed = DynamicPPL.typed_varinfo(model) + vi_untyped = DynamicPPL.untyped_varinfo(model) + vi_vnv = DynamicPPL.untyped_vector_varinfo(model) + vi_vnv_typed = DynamicPPL.typed_vector_varinfo(model) model_name = model == model_uv ? "univariate" : "multivariate" @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ @@ -311,7 +311,7 @@ end return x ~ filldist(MvNormal([1, 100], I), 2) end - vi = TypedVarInfo(demo()) + vi = VarInfo(demo()) vals_prev = vi.metadata.x.vals ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @@ -332,7 +332,7 @@ end `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. """ function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = TypedVarInfo(model) + var_info = VarInfo(model) θ_old = var_info[:] DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) θ_new = var_info[:] @@ -398,7 +398,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = TypedVarInfo(vi) + vi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) @@ -452,8 +452,9 @@ end # Need to run once since we can't specify that we want to _sample_ # in the unconstrained space for `VarInfo` without having `vn` # present in the `varinfo`. - ## `UntypedVarInfo` - vi = UntypedVarInfo(model) + + ## `untyped_varinfo` + vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -461,8 +462,8 @@ end x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - ## `TypedVarInfo` - vi = TypedVarInfo(model) + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) @@ -813,8 +814,8 @@ end model_left = demo_merge_different_y() model_right = demo_merge_different_z() - varinfo_left = TypedVarInfo(model_left) - varinfo_right = TypedVarInfo(model_right) + varinfo_left = VarInfo(model_left) + varinfo_right = VarInfo(model_right) varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) @@ -869,7 +870,7 @@ end return x end model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.UntypedVarInfo(model1), model1) + varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) @@ -971,7 +972,7 @@ end @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] @test DynamicPPL.get_num_produce(vi) == 3 - vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 DynamicPPL.increment_num_produce!(vi) @@ -1012,7 +1013,7 @@ end @testset "issue #842" begin model = DynamicPPL.TestUtils.DEMO_MODELS[1] - varinfo = TypedVarInfo(model) + varinfo = VarInfo(model) n = length(varinfo[:]) # `Bool`. From 52a017adff31869cf4a2f1fe4fb6ba87e4fd79c8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Apr 2025 17:06:17 +0100 Subject: [PATCH 10/12] Rewrite VarInfo docstring --- src/varinfo.jl | 59 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 2791386f1..7be24bd17 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -69,25 +69,34 @@ end ########### """ -``` -struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo - metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} -end -``` + struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + metadata::Tmeta + logp::Base.RefValue{Tlogp} + num_produce::Base.RefValue{Int} + end + +A light wrapper over some kind of metadata. -A light wrapper over one or more instances of `Metadata`. Let `vi` be an instance of -`VarInfo`. If `vi isa VarInfo{<:Metadata}`, then only one `Metadata` instance is used -for all the sybmols. `VarInfo{<:Metadata}` is aliased `UntypedVarInfo`. If -`vi isa VarInfo{<:NamedTuple}`, then `vi.metadata` is a `NamedTuple` that maps each -symbol used on the LHS of `~` in the model to its `Metadata` instance. The latter allows -for the type specialization of `vi` after the first sampling iteration when all the -symbols have been observed. `VarInfo{<:NamedTuple}` is aliased `NTVarInfo`. +The type of the metadata can be one of a number of options. It may either be a +`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps +symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers +to a Julia variable and may consist of one or more `VarName`s which appear on +the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both +have the same symbol `x`. -Note: It is the user's responsibility to ensure that each "symbol" is visited at least -once whenever the model is called, regardless of any stochastic branching. Each symbol -refers to a Julia variable and can be a hierarchical array of many random variables, e.g. `x[1] ~ ...` and `x[2] ~ ...` both have the same symbol `x`. +Several type aliases are provided for these forms of VarInfos: +- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` +- `VarInfo{<:NamedTuple}` is `NTVarInfo` + +The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability +of model evaluation. However, the element type of NamedTuples are not contained +in its type itself: thus, there is no way to use the type system to determine +whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. + +Note that for NTVarInfo, it is the user's responsibility to ensure that each +symbol is visited at least once during model evaluation, regardless of any +stochastic branching. """ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo metadata::Tmeta @@ -95,6 +104,22 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo num_produce::Base.RefValue{Int} end VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +""" + VarInfo([rng, ]model[, sampler, context]) + +Generate a `VarInfo` object for the given `model`, by evaluating it once using +the given `rng`, `sampler`, and `context`. + +!!! warning + + This function currently returns a `VarInfo` with its metadata field set to + a `NamedTuple` of `Metadata`. This is an implementation detail. In general, + this function may return any kind of object that satisfies the + `AbstractVarInfo` interface. If you require precise control over the type + of `VarInfo` returned, use the internal functions `untyped_varinfo`, + `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` + instead. +""" function VarInfo( rng::Random.AbstractRNG, model::Model, From 46a9d42a4d8098eede89647ffcda57ea3b964ee4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Apr 2025 19:12:57 +0100 Subject: [PATCH 11/12] Fix methods --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 7be24bd17..d77355313 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -136,7 +136,7 @@ function VarInfo( # No rng return VarInfo(Random.default_rng(), model, sampler, context) end -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(model::Model, sampler::AbstractSampler, context::AbstractContext) # No rng return VarInfo(Random.default_rng(), model, sampler, context) end From dbddb2f2aa37e6d1102f88c693f0c5169af2769c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 8 Apr 2025 20:55:48 +0100 Subject: [PATCH 12/12] Fix methods (really) --- src/varinfo.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d77355313..360857ef7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -136,10 +136,6 @@ function VarInfo( # No rng return VarInfo(Random.default_rng(), model, sampler, context) end -function VarInfo(model::Model, sampler::AbstractSampler, context::AbstractContext) - # No rng - return VarInfo(Random.default_rng(), model, sampler, context) -end function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) # No sampler return VarInfo(rng, model, SampleFromPrior(), context)