diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7940f20e6..b8a410be6 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -611,29 +611,6 @@ julia> values_as(vi, Vector) """ function values_as end -""" - eltype(vi::AbstractVarInfo) - -Return the `eltype` of the values returned by `vi[:]`. - -!!! warning - This should generally not be called explicitly, as it's only used in - [`matchingvalue`](@ref) to determine the default type to use in place of - type-parameters passed to the model. - - This method is considered legacy, and is likely to be deprecated in the future. -""" -function Base.eltype(vi::AbstractVarInfo) - T = Base.promote_op(getindex, typeof(vi), Colon) - if T === Union{} - # In this case `getindex(vi, :)` errors - # Let us throw a more descriptive error message - # Ref https://github.com/TuringLang/Turing.jl/issues/2151 - return eltype(vi[:]) - end - return eltype(T) -end - """ has_varnamedvector(varinfo::VarInfo) diff --git a/src/compiler.jl b/src/compiler.jl index 6384eaa7c..f900d0ff7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -583,12 +583,6 @@ function add_return_to_last_statment(body::Expr) return Expr(body.head, new_args...) end -const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} -hasmissing(::Type) = false -hasmissing(::Type{>:Missing}) = true -hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA) -hasmissing(::Type{Union{}}) = false # issue #368 - """ TypeWrap{T} @@ -707,53 +701,3 @@ function warn_empty(body) end return nothing end - -# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? -# TODO(mhauru) This function needs a more comprehensive docstring. -""" - matchingvalue(vi, value) - -Convert the `value` to the correct type for the `vi` object. -""" -function matchingvalue(vi, value) - T = typeof(value) - if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) - # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we - # are happy to return `value` as-is? - if _value === value - return deepcopy(_value) - else - return _value - end - else - return value - end -end - -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) -end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() -end - -# TODO(mhauru) This function needs a more comprehensive docstring. What is it for? -""" - get_matching_type(vi, ::TypeWrap{T}) where {T} - -Get the specialized version of type `T` for `vi`. -""" -get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} -end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) -end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} -end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} -end diff --git a/src/model.jl b/src/model.jl index ac9968cf2..f7cc65c96 100644 --- a/src/model.jl +++ b/src/model.jl @@ -923,6 +923,9 @@ end is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") +# TODO(penelopeysm) fix +maybe_deepcopy(x) = deepcopy(x) + """ make_evaluate_args_and_kwargs(model, varinfo) @@ -933,9 +936,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :($(maybe_deepcopy)(model.args.$var)...) else - :($matchingvalue(varinfo, model.args.$var)) + :($(maybe_deepcopy)(model.args.$var)) end for var in argnames ] return quote diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad22bf52d..51df17733 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -412,9 +412,6 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{ SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } -# Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V - # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) return SimpleVarInfo( diff --git a/test/ad.jl b/test/ad.jl index 371e79b06..2de022aa3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -82,16 +82,21 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest t = 1:0.05:8 σ = 0.3 y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} + @model function state_space(y, TT) # Priors α ~ Normal(y[1], 0.001) τ ~ Exponential(1) η ~ filldist(Normal(0, 1), TT - 1) σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) + # create latent variable -- Have to use typeof(α) here to ensure that + # AD works fine. Not sure if this is a generally good workaround. + x = Vector{typeof(α)}(undef, TT) + # As an alternative to the above, we could do this: + # using Accessors + # x = Accessors.set(x, (Accessors.@optic _[1]), α) x[1] = α for t in 2:TT + # and likewise for this line -- use Accessors x[t] = x[t - 1] + η[t - 1] * τ end # measurement model diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..1da876359 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,3 +1,8 @@ +module DynamicPPLCompilerTests + +using DynamicPPL, Distributions, Random, Test +using LinearAlgebra: I + macro custom(expr) (Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage") quote @@ -661,20 +666,6 @@ module Issue537 end @test demo_ret_with_ret()() === Val(1) end - @testset "issue #368: hasmissing dispatch" begin - @test !DynamicPPL.hasmissing(typeof(Union{}[])) - - # (nested) arrays with `Missing` eltypes - @test DynamicPPL.hasmissing(Vector{Union{Missing,Float64}}) - @test DynamicPPL.hasmissing(Matrix{Union{Missing,Real}}) - @test DynamicPPL.hasmissing(Vector{Matrix{Union{Missing,Float32}}}) - - # no `Missing` - @test !DynamicPPL.hasmissing(Vector{Float64}) - @test !DynamicPPL.hasmissing(Matrix{Real}) - @test !DynamicPPL.hasmissing(Vector{Matrix{Float32}}) - end - @testset "issue #393: anonymous argument with type parameter" begin @model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1 @test f_393()() == 1 @@ -794,3 +785,5 @@ module Issue537 end @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end end + +end # module DynamicPPLCompilerTests