-
Notifications
You must be signed in to change notification settings - Fork 36
Remove eltype
, matchingvalue
, get_matching_type
#1015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
96cf6ef
52dad64
f9d3431
069bc40
4a29a2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -583,12 +583,6 @@ function add_return_to_last_statment(body::Expr) | |
return Expr(body.head, new_args...) | ||
end | ||
|
||
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} | ||
hasmissing(::Type) = false | ||
hasmissing(::Type{>:Missing}) = true | ||
hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA) | ||
hasmissing(::Type{Union{}}) = false # issue #368 | ||
|
||
""" | ||
TypeWrap{T} | ||
|
||
|
@@ -707,53 +701,3 @@ function warn_empty(body) | |
end | ||
return nothing | ||
end | ||
|
||
# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? | ||
# TODO(mhauru) This function needs a more comprehensive docstring. | ||
""" | ||
matchingvalue(vi, value) | ||
|
||
Convert the `value` to the correct type for the `vi` object. | ||
""" | ||
function matchingvalue(vi, value) | ||
T = typeof(value) | ||
if hasmissing(T) | ||
_value = convert(get_matching_type(vi, T), value) | ||
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we | ||
# are happy to return `value` as-is? | ||
Comment on lines
-722
to
-723
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change was made here: The motivation is here: TuringLang/Turing.jl#1464 (comment) This has to do with some subtle mutation behaviour. For example @model function f(x)
x[1] ~ Normal()
end If If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So apart from the deepcopy to avoid aliasing, the other place where @model function f(y, ::Type{T}=Float64) where {T}
x = Vector{T}(undef, length(y))
for i in eachindex(y)
x[i] ~ Normal()
y[i] ~ Normal(x[i])
end
end
model = f([1.0]) If you just evaluate this normally with floats, it's all good. Nothing special needs to happen. If you evaluate this with ReverseDiff, then things need to change. Specifically:
It actually gets a bit more complicated. When you define the model, the ForwardDiff actually works just fine on this PR. I don't know why, but I also remember there was a talk I gave where we were surprised that actually ForwardDiff NUTS worked fine without special So this whole thing pretty much only exists to make ReverseDiff happy.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, for most models, ForwardDiff and ReverseDiff still work because of this special nice behaviour: julia> x = Float64[1.0, 2.0]
2-element Vector{Float64}:
1.0
2.0
julia> x[1] = ForwardDiff.Dual(3.0) # x[1] ~ dist doesn't do this
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 0})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.
julia> x = Accessors.set(x, (@optic _[1]), ForwardDiff.Dual(3.0)) # x[1] ~ dist actually does this!
2-element Vector{ForwardDiff.Dual{Nothing, Float64, 0}}:
Dual{Nothing}(3.0)
Dual{Nothing}(2.0) There is only one erroring test in CI, which happens because the model explicitly includes the assignment BUT there are correctness issues with ReverseDiff (not errors), and I have no clue where those stem from. And really interestingly, it's only a problem for one of the demo models, not any of the others, even though many of them use the |
||
if _value === value | ||
return deepcopy(_value) | ||
else | ||
return _value | ||
end | ||
else | ||
return value | ||
end | ||
end | ||
|
||
function matchingvalue(vi, value::FloatOrArrayType) | ||
return get_matching_type(vi, value) | ||
end | ||
function matchingvalue(vi, ::TypeWrap{T}) where {T} | ||
return TypeWrap{get_matching_type(vi, T)}() | ||
end | ||
|
||
# TODO(mhauru) This function needs a more comprehensive docstring. What is it for? | ||
""" | ||
get_matching_type(vi, ::TypeWrap{T}) where {T} | ||
|
||
Get the specialized version of type `T` for `vi`. | ||
""" | ||
get_matching_type(_, ::Type{T}) where {T} = T | ||
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) | ||
return Union{Missing,float_type_with_fallback(eltype(vi))} | ||
end | ||
function get_matching_type(vi, ::Type{<:AbstractFloat}) | ||
return float_type_with_fallback(eltype(vi)) | ||
end | ||
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} | ||
return Array{get_matching_type(vi, T),N} | ||
end | ||
function get_matching_type(vi, ::Type{<:Array{T}}) where {T} | ||
return Array{get_matching_type(vi, T)} | ||
end |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's because
matchingvalue
gets called on all the model function's arguments, and types can be arguments to the model as well, e.g.