@@ -629,7 +629,6 @@ function add_return_to_last_statment(body::Expr)
629629 return Expr (body. head, new_args... )
630630end
631631
632- const FloatOrArrayType = Type{<: Union{AbstractFloat,AbstractArray} }
633632hasmissing (:: Type ) = false
634633hasmissing (:: Type{>:Missing} ) = true
635634hasmissing (:: Type{<:AbstractArray{TA}} ) where {TA} = hasmissing (TA)
@@ -754,54 +753,73 @@ function warn_empty(body)
754753 return nothing
755754end
756755
757- # TODO (mhauru) matchingvalue has methods that can accept both types and values. Why?
758- # TODO (mhauru) This function needs a more comprehensive docstring.
759756"""
760- matchingvalue(param_eltype, value)
761-
762- Convert the `value` to the correct type, given the element type of the parameters
763- being used to evaluate the model.
764- """
765- function matchingvalue (param_eltype, value)
766- T = typeof (value)
767- if hasmissing (T)
768- _value = convert (get_matching_type (param_eltype, T), value)
769- # TODO (mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
770- # are happy to return `value` as-is?
771- if _value === value
772- return deepcopy (_value)
757+ convert_model_argument(param_eltype, model_argument)
758+
759+ Convert `model_argument` to the correct type, given the element type of the parameters being
760+ used to evaluate the model. This function potentially also deep-copies `model_argument` if it
761+ contains `missing` values.
762+ """
763+ function convert_model_argument (param_eltype, model_argument)
764+ T = typeof (model_argument)
765+ # If the argument contains missing data, then we potentially need to deepcopy it. This
766+ # is because the argument may be e.g. a vector of missings, and evaluating a
767+ # tilde-statement like x[1] ~ Normal() would set x[1] = some_not_missing_value, thus
768+ # mutating x. If you then run the model again with the same argument, x[1] would no
769+ # longer be missing.
770+ return if hasmissing (T)
771+ # It is possible that we could skip the deepcopy, if the argument has to be promoted
772+ # anyway. For example, if we are running with ForwardDiff and the argument is a
773+ # Vector{Union{Missing, Float64}}, then we will convert it to a
774+ # Vector{Union{Missing, ForwardDiff.Dual{...}}} anyway, which will avoid mutating
775+ # the original argument. We can check for this by first converting and then only
776+ # deepcopying if the converted value aliases the original.
777+ converted_argument = convert (
778+ promote_model_type_argument (param_eltype, T), model_argument
779+ )
780+ if converted_argument === model_argument
781+ deepcopy (model_argument)
773782 else
774- return _value
783+ converted_argument
775784 end
776785 else
777- return value
786+ model_argument
778787 end
779788end
780-
781- function matchingvalue (param_eltype, value:: FloatOrArrayType )
782- return get_matching_type (param_eltype, value)
783- end
784- function matchingvalue (param_eltype, :: TypeWrap{T} ) where {T}
785- return TypeWrap {get_matching_type(param_eltype, T)} ()
786- end
787-
788- # TODO (mhauru) This function needs a more comprehensive docstring. What is it for?
789- """
790- get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
791-
792- Get the specialized version of type `T`, given an element type of the parameters
793- being used to evaluate the model.
794- """
795- get_matching_type (_, :: Type{T} ) where {T} = T
796- function get_matching_type (param_eltype, :: Type{<:Union{Missing,AbstractFloat}} )
797- return Union{Missing,float_type_with_fallback (param_eltype)}
798- end
799- function get_matching_type (param_eltype, :: Type{<:AbstractFloat} )
800- return float_type_with_fallback (param_eltype)
801- end
802- function get_matching_type (param_eltype, :: Type{<:Array{T,N}} ) where {T,N}
803- return Array{get_matching_type (param_eltype, T),N}
804- end
805- function get_matching_type (param_eltype, :: Type{<:Array{T}} ) where {T}
806- return Array{get_matching_type (param_eltype, T)}
789+ # These methods handle arguments that are types rather than values.
790+ function convert_model_argument (param_eltype, t:: Type{<:Union{Real,AbstractArray}} )
791+ return promote_model_type_argument (param_eltype, t)
792+ end
793+ function convert_model_argument (param_eltype, :: TypeWrap{T} ) where {T}
794+ return TypeWrap {promote_model_type_argument(param_eltype, T)} ()
795+ end
796+ # If the parameter element type is `Any`, then we don't need to do any conversion.
797+ # Extra methods to avoid method ambiguity.
798+ convert_model_argument (:: Type{Any} , model_argument) = model_argument
799+ convert_model_argument (:: Type{Any} , t:: Type{<:Union{Real,AbstractArray}} ) = t
800+ convert_model_argument (:: Type{Any} , t:: TypeWrap{T} ) where {T} = t
801+
802+ """
803+ promote_model_type_argument(param_eltype, ::Type{T}) where {T}
804+ promote_model_type_argument(param_eltype, ::TypeWrap{T}) where {T}
805+
806+ For arguments to a model that are types rather than values, promote the type `T` to
807+ match the element type of the parameters being used to evaluate the model.
808+ """
809+ promote_model_type_argument (_, :: Type{T} ) where {T} = T
810+ function promote_model_type_argument (param_eltype, :: Type{T} ) where {T<: Real }
811+ # TODO (penelopeysm): This actually might still be over-aggressive. For example, if
812+ # `param_eltype` is `Float32` and `T` is `Vector{Int}`, then (after going through the
813+ # Array method) we will promote to `Vector{Float64}`, which seems unnecessary. However,
814+ # there's no way to actually check if `T` is the type of something that will later be
815+ # assigned to, so this is 'safe'.
816+ return Base. promote_type (param_eltype, T)
817+ end
818+ # NOTE(penelopeysm): This doesn't work with other types of AbstractArray. To get around
819+ # that, one could in principle use ArrayInterface.promote_eltype. However, it doesn't seem
820+ # like there is (1) demand for that; and (2) sufficiently strong adoption of ArrayInterface
821+ # to make that worth adding as a dependency.
822+ function promote_model_type_argument (param_eltype, :: Type{Array{T,N}} ) where {T,N}
823+ promoted_eltype = promote_model_type_argument (param_eltype, eltype (T))
824+ return Array{promoted_eltype,N}
807825end
0 commit comments