diff --git a/src/varinfo.jl b/src/varinfo.jl index 07b1869c9..cea7cd0c2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -6,6 +6,10 @@ # VarInfo metadata # #################### +abstract type DynamicPPLAny end +Base.convert(::Type{DynamicPPLAny}, x) = x +Base.zero(::Type{DynamicPPLAny}) = zero(LogProbType) + """ The `Metadata` struct stores some metadata about the parameters of the model. This helps query certain information about a variable, such as its distribution, which samplers @@ -37,7 +41,7 @@ struct Metadata{ TIdcs<:Dict{<:VarName,Int}, TDists<:AbstractVector{<:Distribution}, TVN<:AbstractVector{<:VarName}, - TVal<:AbstractVector{<:Real}, + TVal<:AbstractVector{<:DynamicPPLAny}, } # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` idcs::TIdcs # Dict{<:VarName,Int} @@ -51,7 +55,7 @@ struct Metadata{ # Vector of values of all the univariate, multivariate and matrix variables # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals::TVal # AbstractVector{<:Real} + vals::TVal # AbstractVector{<:DynamicPPLAny} # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} @@ -409,7 +413,7 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) Construct an empty type unstable instance of `Metadata`. """ function Metadata() - vals = Vector{Real}() + vals = Vector{DynamicPPLAny}() flags = Dict{String,BitVector}() flags["del"] = BitVector() flags["trans"] = BitVector() @@ -576,8 +580,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) T_right = eltype(metadata_right.vals) T = promote_type(T_left, T_right) # TODO: Is this necessary? - if !(T <: Real) - T = Real + if !(T <: DynamicPPLAny) + T = DynamicPPLAny end # Determine `eltype` of `dists`. @@ -766,8 +770,8 @@ getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon() function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end -function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) - return float(Real)[] +function getindex_internal(::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) + return DynamicPPLAny[] end function getindex_internal(md::Metadata, ::Colon) return mapreduce(