From 8a91ab79deefb717b14d1a1194b17eb7280cb182 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 30 Jul 2025 15:58:58 +0100 Subject: [PATCH 1/2] Change default Metadata value type to Any --- src/varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 07b1869c9..43940685e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -37,7 +37,7 @@ struct Metadata{ TIdcs<:Dict{<:VarName,Int}, TDists<:AbstractVector{<:Distribution}, TVN<:AbstractVector{<:VarName}, - TVal<:AbstractVector{<:Real}, + TVal<:AbstractVector{<:Any}, } # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs::TIdcs # Dict{<:VarName,Int} @@ -51,7 +51,7 @@ struct Metadata{ # Vector of values of all the univariate, multivariate and matrix variables # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals::TVal # AbstractVector{<:Real} + vals::TVal # AbstractVector{<:Any} # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} @@ -409,7 +409,7 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) Construct an empty type unstable instance of `Metadata`. """ function Metadata() - vals = Vector{Real}() + vals = Vector{Any}() flags = Dict{String,BitVector}() flags["del"] = BitVector() flags["trans"] = BitVector() @@ -576,8 +576,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) T_right = eltype(metadata_right.vals) T = promote_type(T_left, T_right) # TODO: Is this necessary? - if !(T <: Real) - T = Real + if !(T <: Any) + T = Any end # Determine `eltype` of `dists`. @@ -767,7 +767,7 @@ function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) - return float(Real)[] + return Any[] end function getindex_internal(md::Metadata, ::Colon) return mapreduce( From 2a0fbf2955620f264ff905046c3420b776830cb3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 30 Jul 2025 16:47:42 +0100 Subject: [PATCH 2/2] let's try with `DynamicPPLAny` --- src/varinfo.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 43940685e..cea7cd0c2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -6,6 +6,10 @@ # VarInfo metadata # #################### +abstract type DynamicPPLAny end +Base.convert(::Type{DynamicPPLAny}, x) = x +Base.zero(::Type{DynamicPPLAny}) = zero(LogProbType) + """ The `Metadata` struct stores some metadata about the parameters of the model. This helps query certain information about a variable, such as its distribution, which samplers @@ -37,7 +41,7 @@ struct Metadata{ TIdcs<:Dict{<:VarName,Int}, TDists<:AbstractVector{<:Distribution}, TVN<:AbstractVector{<:VarName}, - TVal<:AbstractVector{<:Any}, + TVal<:AbstractVector{<:DynamicPPLAny}, } # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs::TIdcs # Dict{<:VarName,Int} @@ -51,7 +55,7 @@ struct Metadata{ # Vector of values of all the univariate, multivariate and matrix variables # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals::TVal # AbstractVector{<:Any} + vals::TVal # AbstractVector{<:DynamicPPLAny} # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} @@ -409,7 +413,7 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) Construct an empty type unstable instance of `Metadata`. """ function Metadata() - vals = Vector{Any}() + vals = Vector{DynamicPPLAny}() flags = Dict{String,BitVector}() flags["del"] = BitVector() flags["trans"] = BitVector() @@ -576,8 +580,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) T_right = eltype(metadata_right.vals) T = promote_type(T_left, T_right) # TODO: Is this necessary? - if !(T <: Any) - T = Any + if !(T <: DynamicPPLAny) + T = DynamicPPLAny end # Determine `eltype` of `dists`. @@ -766,8 +770,8 @@ getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon() function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end -function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) - return Any[] +function getindex_internal(::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) + return DynamicPPLAny[] end function getindex_internal(md::Metadata, ::Colon) return mapreduce(