Remove eltype, matchingvalue, get_matching_type #1015


Draft: wants to merge 5 commits into base branch main.
23 changes: 0 additions & 23 deletions src/abstract_varinfo.jl
@@ -611,29 +611,6 @@ julia> values_as(vi, Vector)
"""
function values_as end

"""
eltype(vi::AbstractVarInfo)

Return the `eltype` of the values returned by `vi[:]`.

!!! warning
This should generally not be called explicitly, as it's only used in
[`matchingvalue`](@ref) to determine the default type to use in place of
type-parameters passed to the model.

This method is considered legacy, and is likely to be deprecated in the future.
"""
function Base.eltype(vi::AbstractVarInfo)
T = Base.promote_op(getindex, typeof(vi), Colon)
if T === Union{}
# In this case `getindex(vi, :)` errors
# Let us throw a more descriptive error message
# Ref https://github.com/TuringLang/Turing.jl/issues/2151
return eltype(vi[:])
end
return eltype(T)
end

"""
has_varnamedvector(varinfo::VarInfo)

56 changes: 0 additions & 56 deletions src/compiler.jl
@@ -583,12 +583,6 @@ function add_return_to_last_statment(body::Expr)
return Expr(body.head, new_args...)
end

const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
hasmissing(::Type) = false
hasmissing(::Type{>:Missing}) = true
hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA)
hasmissing(::Type{Union{}}) = false # issue #368

"""
TypeWrap{T}

@@ -707,53 +701,3 @@ function warn_empty(body)
end
return nothing
end

# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
@penelopeysm (Member, Author) commented on Aug 7, 2025:

It's because matchingvalue gets called on all the model function's arguments, and types can be arguments to the model as well, e.g.

@model function f(x, T) ... end
model = f(1.0, Float64)
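
For concreteness, a made-up body of that shape (hypothetical, just to make the example above runnable; both arguments go through matchingvalue):

using DynamicPPL, Distributions

@model function f(x, T)
    y = Vector{T}(undef, 1)
    y[1] ~ Normal()
    x ~ Normal(y[1])
end
model = f(1.0, Float64)  # matchingvalue is called on 1.0 and on Float64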

# TODO(mhauru) This function needs a more comprehensive docstring.
"""
matchingvalue(vi, value)

Convert the `value` to the correct type for the `vi` object.
"""
function matchingvalue(vi, value)
T = typeof(value)
if hasmissing(T)
_value = convert(get_matching_type(vi, T), value)
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
# are happy to return `value` as-is?
Comment on lines -722 to -723
@penelopeysm (Member, Author) commented on Aug 7, 2025:

This change was made here:

#191

The motivation is here:

TuringLang/Turing.jl#1464 (comment)

This has to do with some subtle mutation behaviour. For example

@model function f(x)
    x[1] ~ Normal()
end

If model = f([1.0]), the tilde statement is an observe, and thus even if you reassign to x[1] it doesn't change the value of x. This is the !hasmissing branch, and since overwriting is a no-op, we don't need to deepcopy it.

If model = f([missing]), the tilde statement is now an assume, and when you run the model it will sample a new value for x[1] and set that value in x. Then, if you rerun the model, x[1] is no longer missing. This is the case where the deepcopy is triggered.
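
A minimal sketch of the aliasing this guards against (assuming the usual DynamicPPL missing-data semantics; not a test from this PR):

using DynamicPPL, Distributions

@model function f(x)
    x[1] ~ Normal()
end

data = [missing]
model = f(data)
model()   # x[1] is an assume here, so a value is sampled and written into the model's copy of x
data[1]   # still `missing`: matchingvalue copied/converted the argument, so reruns behave the same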

@penelopeysm (Member, Author) commented on Aug 7, 2025:

So apart from the deepcopy to avoid aliasing, the other place where matchingvalue does something meaningful is

@model function f(y, ::Type{T}=Float64) where {T}
    x = Vector{T}(undef, length(y))
    for i in eachindex(y)
        x[i] ~ Normal()
        y[i] ~ Normal(x[i])
    end
end
model = f([1.0])

If you just evaluate this normally with floats, it's all good. Nothing special needs to happen.

If you evaluate this with ReverseDiff, then things need to change. Specifically:

  1. x needs to become a vector of TrackedReals rather than a vector of Floats.
  2. In order to accomplish this, the ARGUMENT to the model needs to change: even though T SEEMS to be specified as Float64, in fact, matchingvalue hijacks it to turn it into TrackedReal when calling model().
  3. How does matchingvalue know that it needs to become a TrackedReal? Simple - when you call logdensity_and_gradient it calls unflatten to set the parameters (which will be TrackedReals) in the varinfo. matchingvalue then looks inside the varinfo to see if the varinfo contains TrackedReals! Hence eltype(vi) 🙃

It actually gets a bit more complicated. When you define the model, the @model macro already hijacks it to turn T into TypeWrap{Float64}(), and then when you actually evaluate the model matchingvalue hijacks it even further to turn it into TypeWrap{TrackedReal}(). Not sure why TypeWrap is needed but apparently it's something to do with avoiding DataType.
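
To make that concrete, a rough sketch of the chain on current main (these are exactly the functions this PR removes; with a plain Float64 varinfo the "hijack" is a no-op, but the same calls return a TrackedReal type once unflatten has filled vi with tracked values):

using DynamicPPL, Distributions

@model function g(y, ::Type{T}=Float64) where {T}
    x = Vector{T}(undef, length(y))
    for i in eachindex(y)
        x[i] ~ Normal()
        y[i] ~ Normal(x[i])
    end
end

vi = VarInfo(g([1.0]))
DynamicPPL.get_matching_type(vi, Float64)                     # Float64 here, a TrackedReal under ReverseDiff
DynamicPPL.matchingvalue(vi, DynamicPPL.TypeWrap{Float64}())  # TypeWrap{...}() carrying that matched type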

ForwardDiff actually works just fine on this PR. I don't know why, but I also remember a talk I gave where we were surprised that ForwardDiff NUTS worked fine without the special ::Type{T}=Float64 machinery, so this observation is consistent with that.

So this whole thing pretty much only exists to make ReverseDiff happy.

To get around this, I propose that we drop compatibility with ReverseDiff.

@penelopeysm (Member, Author) commented on Aug 7, 2025:

Actually, for most models, ForwardDiff and ReverseDiff still work because of this special nice behaviour:

julia> x = Float64[1.0, 2.0]
2-element Vector{Float64}:
 1.0
 2.0

julia> x[1] = ForwardDiff.Dual(3.0) # x[1] ~ dist doesn't do this
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 0})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

julia> x = Accessors.set(x, (@optic _[1]), ForwardDiff.Dual(3.0)) # x[1] ~ dist actually does this!
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 0}}:
 Dual{Nothing}(3.0)
 Dual{Nothing}(2.0)

There is only one erroring test in CI, which happens because the model explicitly includes the assignment x[i] = ... rather than a tilde-statement x[i] ~ .... Changing the assignment to use Accessors.set makes it work just fine.

BUT there are correctness issues with ReverseDiff (not errors), and I have no clue where those stem from. And really interestingly, it's only a problem for one of the demo models, not any of the others, even though many of them use the Type{T} syntax.
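
Sketch of that failure mode and the Accessors-based workaround (a toy snippet, not the CI test itself; the actual change is in the test/ad.jl diff below):

using Accessors, ForwardDiff

x = zeros(2)
# x[1] = ForwardDiff.Dual(3.0)                                        # the explicit-assignment failure mode
x = Accessors.set(x, (Accessors.@optic _[1]), ForwardDiff.Dual(3.0))  # re-binds x with a widened eltype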

if _value === value
return deepcopy(_value)
else
return _value
end
else
return value
end
end

function matchingvalue(vi, value::FloatOrArrayType)
return get_matching_type(vi, value)
end
function matchingvalue(vi, ::TypeWrap{T}) where {T}
return TypeWrap{get_matching_type(vi, T)}()
end

# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
"""
get_matching_type(vi, ::TypeWrap{T}) where {T}

Get the specialized version of type `T` for `vi`.
"""
get_matching_type(_, ::Type{T}) where {T} = T
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
return Union{Missing,float_type_with_fallback(eltype(vi))}
end
function get_matching_type(vi, ::Type{<:AbstractFloat})
return float_type_with_fallback(eltype(vi))
end
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(vi, T),N}
end
function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
return Array{get_matching_type(vi, T)}
end
7 changes: 5 additions & 2 deletions src/model.jl
@@ -923,6 +923,9 @@ end

is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")

# TODO(penelopeysm) fix
maybe_deepcopy(x) = deepcopy(x)

"""
make_evaluate_args_and_kwargs(model, varinfo)

@@ -933,9 +936,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
) where {_F,argnames}
unwrap_args = [
if is_splat_symbol(var)
:($matchingvalue(varinfo, model.args.$var)...)
:($(maybe_deepcopy)(model.args.$var)...)
else
:($matchingvalue(varinfo, model.args.$var))
:($(maybe_deepcopy)(model.args.$var))
end for var in argnames
]
return quote
3 changes: 0 additions & 3 deletions src/simple_varinfo.jl
@@ -412,9 +412,6 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{
SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}}
}

# Necessary for `matchingvalue` to work properly.
Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
return SimpleVarInfo(
11 changes: 8 additions & 3 deletions test/ad.jl
@@ -82,16 +82,21 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
@model function state_space(y, TT)
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
# create latent variable
x = Vector{T}(undef, TT)
# create latent variable -- Have to use typeof(α) here to ensure that
# AD works fine. Not sure if this is a generally good workaround.
x = Vector{typeof(α)}(undef, TT)
# As an alternative to the above, we could do this:
# using Accessors
# x = Accessors.set(x, (Accessors.@optic _[1]), α)
x[1] = α
for t in 2:TT
# and likewise for this line -- use Accessors
x[t] = x[t - 1] + η[t - 1] * τ
end
# measurement model
21 changes: 7 additions & 14 deletions test/compiler.jl
@@ -1,3 +1,8 @@
module DynamicPPLCompilerTests

using DynamicPPL, Distributions, Random, Test
using LinearAlgebra: I

macro custom(expr)
(Meta.isexpr(expr, :call, 3) && expr.args[1] === :~) || error("incorrect macro usage")
quote
@@ -661,20 +666,6 @@ module Issue537 end
@test demo_ret_with_ret()() === Val(1)
end

@testset "issue #368: hasmissing dispatch" begin
@test !DynamicPPL.hasmissing(typeof(Union{}[]))

# (nested) arrays with `Missing` eltypes
@test DynamicPPL.hasmissing(Vector{Union{Missing,Float64}})
@test DynamicPPL.hasmissing(Matrix{Union{Missing,Real}})
@test DynamicPPL.hasmissing(Vector{Matrix{Union{Missing,Float32}}})

# no `Missing`
@test !DynamicPPL.hasmissing(Vector{Float64})
@test !DynamicPPL.hasmissing(Matrix{Real})
@test !DynamicPPL.hasmissing(Vector{Matrix{Float32}})
end

@testset "issue #393: anonymous argument with type parameter" begin
@model f_393(::Val{ispredict}=Val(false)) where {ispredict} = ispredict ? 0 : 1
@test f_393()() == 1
@@ -794,3 +785,5 @@ module Issue537 end
@test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
end
end

end # module DynamicPPLCompilerTests