Skip to content

Commit 82e9e2e

Browse files
committed
Clean up matchingvalue etc
1 parent a21b21d commit 82e9e2e

File tree

5 files changed

+122
-48
lines changed

5 files changed

+122
-48
lines changed

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.10
4+
5+
Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively.
6+
The behaviour of `promote_model_type_argument` has also been slightly changed in some edge cases: for example, `promote_model_type_argument(ForwardDiff.Dual{Nothing,Float64,0}, Vector{Real})` now returns `Vector{ForwardDiff.Dual{Nothing,Real,0}}` instead of `Vector{ForwardDiff.Dual{Nothing,Float64,0}}`.
7+
In other words, abstract types in the type argument are now properly respected.
8+
9+
This should have almost no impact on end users (unless you were passing `::Type{T}=Vector{Real}` into the model, with an abstract eltype).
10+
311
## 0.39.9
412

513
The internals of `LogDensityFunction` have been changed slightly so that you do not need to specify `function_annotation` when performing AD with Enzyme.jl.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.9"
3+
version = "0.39.10"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/compiler.jl

Lines changed: 66 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,6 @@ function add_return_to_last_statment(body::Expr)
629629
return Expr(body.head, new_args...)
630630
end
631631

632-
const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
633632
hasmissing(::Type) = false
634633
hasmissing(::Type{>:Missing}) = true
635634
hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA)
@@ -754,54 +753,76 @@ function warn_empty(body)
754753
return nothing
755754
end
756755

757-
# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
758-
# TODO(mhauru) This function needs a more comprehensive docstring.
759756
"""
760-
matchingvalue(param_eltype, value)
761-
762-
Convert the `value` to the correct type, given the element type of the parameters
763-
being used to evaluate the model.
764-
"""
765-
function matchingvalue(param_eltype, value)
766-
T = typeof(value)
767-
if hasmissing(T)
768-
_value = convert(get_matching_type(param_eltype, T), value)
769-
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
770-
# are happy to return `value` as-is?
771-
if _value === value
772-
return deepcopy(_value)
757+
convert_model_argument(param_eltype, model_argument)
758+
759+
Convert `model_argument` to the correct type, given the element type of the parameters being
760+
used to evaluate the model. This function potentially also deep-copies `model_argument` if it
761+
contains `missing` values.
762+
"""
763+
function convert_model_argument(param_eltype, model_argument)
764+
T = typeof(model_argument)
765+
# If the argument contains missing data, then we potentially need to deepcopy it. This
766+
# is because the argument may be e.g. a vector of missings, and evaluating a
767+
# tilde-statement like x[1] ~ Normal() would set x[1] = some_not_missing_value, thus
768+
# mutating x. If you then run the model again with the same argument, x[1] would no
769+
# longer be missing.
770+
return if hasmissing(T)
771+
# It is possible that we could skip the deepcopy, if the argument has to be promoted
772+
# anyway. For example, if we are running with ForwardDiff and the argument is a
773+
# Vector{Union{Missing, Float64}}, then we will convert it to a
774+
# Vector{Union{Missing, ForwardDiff.Dual{...}}} anyway, which will avoid mutating
775+
# the original argument. We can check for this by first converting and then only
776+
# deepcopying if the converted value aliases the original.
777+
converted_argument = convert(
778+
promote_model_type_argument(param_eltype, T), model_argument
779+
)
780+
if converted_argument === model_argument
781+
deepcopy(model_argument)
773782
else
774-
return _value
783+
converted_argument
775784
end
776785
else
777-
return value
786+
model_argument
778787
end
779788
end
780-
781-
function matchingvalue(param_eltype, value::FloatOrArrayType)
782-
return get_matching_type(param_eltype, value)
783-
end
784-
function matchingvalue(param_eltype, ::TypeWrap{T}) where {T}
785-
return TypeWrap{get_matching_type(param_eltype, T)}()
786-
end
787-
788-
# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
789-
"""
790-
get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
791-
792-
Get the specialized version of type `T`, given an element type of the parameters
793-
being used to evaluate the model.
794-
"""
795-
get_matching_type(_, ::Type{T}) where {T} = T
796-
function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}})
797-
return Union{Missing,float_type_with_fallback(param_eltype)}
798-
end
799-
function get_matching_type(param_eltype, ::Type{<:AbstractFloat})
800-
return float_type_with_fallback(param_eltype)
801-
end
802-
function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N}
803-
return Array{get_matching_type(param_eltype, T),N}
804-
end
805-
function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T}
806-
return Array{get_matching_type(param_eltype, T)}
789+
# These methods handle arguments that are types rather than values.
790+
function convert_model_argument(param_eltype, t::Type{<:Union{Real,AbstractArray}})
791+
return promote_model_type_argument(param_eltype, t)
792+
end
793+
function convert_model_argument(param_eltype, ::TypeWrap{T}) where {T}
794+
return TypeWrap{promote_model_type_argument(param_eltype, T)}()
795+
end
796+
# If the parameter element type is `Any`, then we don't need to do any conversion (but we
797+
# might need to deepcopy).
798+
function convert_model_argument(::Type{Any}, model_argument::T) where {T}
799+
return hasmissing(T) ? deepcopy(model_argument) : model_argument
800+
end
801+
# Extra methods to avoid method ambiguity.
802+
convert_model_argument(::Type{Any}, t::Type{<:Union{Real,AbstractArray}}) = t
803+
convert_model_argument(::Type{Any}, t::TypeWrap{T}) where {T} = t
804+
805+
"""
806+
promote_model_type_argument(param_eltype, ::Type{T}) where {T}
807+
promote_model_type_argument(param_eltype, ::TypeWrap{T}) where {T}
808+
809+
For arguments to a model that are types rather than values, promote the type `T` to
810+
match the element type of the parameters being used to evaluate the model.
811+
"""
812+
promote_model_type_argument(_, ::Type{T}) where {T} = T
813+
function promote_model_type_argument(param_eltype, ::Type{T}) where {T<:Real}
814+
# TODO(penelopeysm): This actually might still be over-aggressive. For example, if
815+
# `param_eltype` is `Float32` and `T` is `Vector{Int}`, then (after going through the
816+
# Array method) we will promote to `Vector{Float64}`, which seems unnecessary. However,
817+
# there's no way to actually check if `T` is the type of something that will later be
818+
# assigned to, so this is 'safe'.
819+
return Base.promote_type(param_eltype, T)
820+
end
821+
# NOTE(penelopeysm): This doesn't work with other types of AbstractArray. To get around
822+
# that, one could in principle use ArrayInterface.promote_eltype. However, it doesn't seem
823+
# like there is (1) demand for that; and (2) sufficiently strong adoption of ArrayInterface
824+
# to make that worth adding as a dependency.
825+
function promote_model_type_argument(param_eltype, ::Type{Array{T,N}}) where {T,N}
826+
promoted_eltype = promote_model_type_argument(param_eltype, eltype(T))
827+
return Array{promoted_eltype,N}
807828
end

src/model.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,12 +1010,14 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
10101010
unwrap_args = [
10111011
if is_splat_symbol(var)
10121012
:(
1013-
$matchingvalue(
1013+
$convert_model_argument(
10141014
$get_param_eltype(varinfo, model.context), model.args.$var
10151015
)...
10161016
)
10171017
else
1018-
:($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var))
1018+
:($convert_model_argument(
1019+
$get_param_eltype(varinfo, model.context), model.args.$var
1020+
))
10191021
end for var in argnames
10201022
]
10211023
return quote

test/compiler.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,4 +836,47 @@ module Issue537 end
836836
@test vi isa VarInfo
837837
@test vi[@varname(m)] isa Real
838838
end
839+
840+
@testset "convert_model_argument" begin
841+
tdual = ForwardDiff.Dual{Nothing,Float64,0}
842+
# no-op
843+
@test DynamicPPL.convert_model_argument(Float64, 1.0) == 1.0
844+
@testset "shouldn't promote types of value arguments" begin
845+
# i.e. this shouldn't become a dual.
846+
@test DynamicPPL.convert_model_argument(tdual, 1.0) == 1.0
847+
end
848+
@testset "arrays" begin
849+
# convert_model_argument should make sure to not deepcopy arrays if not needed
850+
x = [1.0]
851+
@test DynamicPPL.convert_model_argument(Float64, x) === x
852+
# but if there's a missing in the array, it should
853+
y = [1.0, missing]
854+
y_converted = DynamicPPL.convert_model_argument(Float64, y)
855+
@test y_converted !== y
856+
@test isequal(y_converted, y)
857+
end
858+
@testset "type arguments" begin
859+
# These tests with types / TypeWrap as the second argument also test
860+
# `promote_model_type_argument`.
861+
function test_type_conversion(
862+
::Type{input}, ::Type{target}
863+
) where {input,target}
864+
converted_type = DynamicPPL.convert_model_argument(tdual, input)
865+
@test converted_type == target
866+
typewrap = DynamicPPL.TypeWrap{input}()
867+
converted_typewrap = DynamicPPL.convert_model_argument(tdual, typewrap)
868+
@test converted_typewrap == DynamicPPL.TypeWrap{target}()
869+
end
870+
test_type_conversion(Float64, tdual)
871+
test_type_conversion(Real, ForwardDiff.Dual{Nothing,Real,0})
872+
test_type_conversion(Vector{Float64}, Vector{tdual})
873+
test_type_conversion(Vector{Real}, Vector{ForwardDiff.Dual{Nothing,Real,0}})
874+
test_type_conversion(Matrix{Float64}, Matrix{tdual})
875+
test_type_conversion(Matrix{Real}, Matrix{ForwardDiff.Dual{Nothing,Real,0}})
876+
test_type_conversion(Vector{Vector{Float64}}, Vector{Vector{tdual}})
877+
test_type_conversion(
878+
Vector{Vector{Real}}, Vector{Vector{ForwardDiff.Dual{Nothing,Real,0}}}
879+
)
880+
end
881+
end
839882
end

0 commit comments

Comments
 (0)