@@ -629,7 +629,6 @@ function add_return_to_last_statment(body::Expr)
629629 return Expr(body. head, new_args... )
630630end
631631
632- const FloatOrArrayType = Type{<: Union{AbstractFloat,AbstractArray} }
633632hasmissing(:: Type ) = false
634633hasmissing(:: Type{>:Missing} ) = true
635634hasmissing(:: Type{<:AbstractArray{TA}} ) where {TA} = hasmissing(TA)
@@ -754,54 +753,68 @@ function warn_empty(body)
754753 return nothing
755754end
756755
757- # TODO (mhauru) matchingvalue has methods that can accept both types and values. Why?
758- # TODO (mhauru) This function needs a more comprehensive docstring.
759756"""
760- matchingvalue(param_eltype, value)
761-
762- Convert the `value` to the correct type, given the element type of the parameters
763- being used to evaluate the model.
764- """
765- function matchingvalue(param_eltype, value)
766- T = typeof(value)
767- if hasmissing(T)
768- _value = convert(get_matching_type(param_eltype, T), value)
769- # TODO (mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
770- # are happy to return `value` as-is?
771- if _value === value
772- return deepcopy(_value)
757+ convert_model_argument(param_eltype, model_argument)
758+
759+ Convert `model_argument` to the correct type, given the element type of the parameters being
760+ used to evaluate the model. This function potentially also deep-copies `model_argument` if it
761+ contains `missing` values.
762+ """
763+ function convert_model_argument(param_eltype, model_argument)
764+ T = typeof(model_argument)
765+ # If the argument contains missing data, then we potentially need to deepcopy it. This
766+ # is because the argument may be e.g. a vector of missings, and evaluating a
767+ # tilde-statement like x[1] ~ Normal() would set x[1] = some_not_missing_value, thus
768+ # mutating x. If you then run the model again with the same argument, x[1] would no
769+ # longer be missing.
770+ return if hasmissing(T)
771+ # It is possible that we could skip the deepcopy, if the argument has to be promoted
772+ # anyway. For example, if we are running with ForwardDiff and the argument is a
773+ # Vector{Union{Missing, Float64}}, then we will convert it to a
774+ # Vector{Union{Missing, ForwardDiff.Dual{...}}} anyway, which will avoid mutating
775+ # the original argument. We can check for this by first converting and then only
776+ # deepcopying if the converted value aliases the original.
777+ converted_argument = convert(
778+ promote_model_type_argument(param_eltype, T), model_argument
779+ )
780+ if converted_argument === model_argument
781+ deepcopy(model_argument)
773782 else
774- return _value
783+ converted_argument
775784 end
776785 else
777- return value
786+ model_argument
778787 end
779788end
780-
781- function matchingvalue (param_eltype, value:: FloatOrArrayType )
782- return get_matching_type (param_eltype, value)
789+ # These methods handle arguments that are types rather than values.
790+ function convert_model_argument (param_eltype, value:: Type{<:Union{Real,AbstractArray}} )
791+ return promote_model_type_argument (param_eltype, value)
783792end
784- function matchingvalue (param_eltype, :: TypeWrap{T} ) where {T}
785- return TypeWrap{get_matching_type (param_eltype, T)}()
793+ function convert_model_argument (param_eltype, :: TypeWrap{T} ) where {T}
794+ return TypeWrap{promote_model_type_argument (param_eltype, T)}()
786795end
787796
788- # TODO (mhauru) This function needs a more comprehensive docstring. What is it for?
789797"""
790- get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
798+ promote_model_type_argument(param_eltype, ::Type{T}) where {T}
799+ promote_model_type_argument(param_eltype, ::TypeWrap{T}) where {T}
791800
792- Get the specialized version of type `T`, given an element type of the parameters
793- being used to evaluate the model.
801+ For arguments to a model that are types rather than values, promote the type `T` to
802+ match the element type of the parameters being used to evaluate the model.
794803"""
795- get_matching_type(_, :: Type{T} ) where {T} = T
796- function get_matching_type(param_eltype, :: Type{<:Union{Missing,AbstractFloat}} )
797- return Union{Missing,float_type_with_fallback(param_eltype)}
798- end
799- function get_matching_type(param_eltype, :: Type{<:AbstractFloat} )
800- return float_type_with_fallback(param_eltype)
801- end
802- function get_matching_type(param_eltype, :: Type{<:Array{T,N}} ) where {T,N}
803- return Array{get_matching_type(param_eltype, T),N}
804+ promote_model_type_argument(_, :: Type{T} ) where {T} = T
805+ function promote_model_type_argument(param_eltype, :: Type{T} ) where {T<: Real }
806+ # TODO (penelopeysm): This actually might still be over-aggressive. For example, if
807+ # `param_eltype` is `Float32` and `T` is `Vector{Int}`, then (after going through the
808+ # Array method) we will promote to `Vector{Float64}`, which seems unnecessary. However,
809+ # there's no way to actually check if `T` is the type of something that will later be
810+ # assigned to, so this is 'safe'.
811+ return Base. promote_type(param_eltype, T)
804812end
805- function get_matching_type(param_eltype, :: Type{<:Array{T}} ) where {T}
806- return Array{get_matching_type(param_eltype, T)}
813+ # NOTE(penelopeysm): This doesn't work with other types of AbstractArray. To get around
814+ # that, one could in principle use ArrayInterface.promote_eltype. However, it doesn't seem
815+ # like there is (1) demand for that; and (2) sufficiently strong adoption of ArrayInterface
816+ # to make that worth adding as a dependency.
817+ function promote_model_type_argument(param_eltype, :: Type{Array{T,N}} ) where {T,N}
818+ promoted_eltype = promote_model_type_argument(param_eltype, eltype(T))
819+ return Array{promoted_eltype,N}
807820end
0 commit comments