diff --git a/HISTORY.md b/HISTORY.md
index 8dc37e9db..775850973 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,13 @@
 # DynamicPPL Changelog
 
+## 0.39.10
+
+Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively.
+The behaviour of `promote_model_type_argument` has also been slightly changed in some edge cases: for example, `promote_model_type_argument(ForwardDiff.Dual{Nothing,Float64,0}, Vector{Real})` now returns `Vector{ForwardDiff.Dual{Nothing,Real,0}}` instead of `Vector{ForwardDiff.Dual{Nothing,Float64,0}}`.
+In other words, abstract types in the type argument are now properly respected.
+
+This should have almost no impact on end users (unless you were passing `::Type{T}=Vector{Real}` into the model, with an abstract eltype).
+
 ## 0.39.9
 
 The internals of `LogDensityFunction` have been changed slightly so that you do not need to specify `function_annotation` when performing AD with Enzyme.jl.
diff --git a/Project.toml b/Project.toml
index a163f29b8..988de202e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.39.9"
+version = "0.39.10"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
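To make the changelog entry concrete, here is a small sketch (not part of the diff) of the new promotion behaviour. It assumes DynamicPPL 0.39.10 and ForwardDiff are loaded; `promote_model_type_argument` is internal, so it is called through the module:

```julia
using DynamicPPL, ForwardDiff

# A zero-length Dual of the kind ForwardDiff uses during AD.
dual = ForwardDiff.Dual{Nothing,Float64,0}

# Abstract eltypes in the type argument are now respected:
DynamicPPL.promote_model_type_argument(dual, Vector{Real})
# returns Vector{ForwardDiff.Dual{Nothing,Real,0}}; before this release the
# abstract eltype collapsed, giving Vector{ForwardDiff.Dual{Nothing,Float64,0}}

# Concrete eltypes promote as before:
DynamicPPL.promote_model_type_argument(dual, Vector{Float64})
# returns Vector{ForwardDiff.Dual{Nothing,Float64,0}}
```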
diff --git a/src/compiler.jl b/src/compiler.jl
index f1e92e369..84a9a4857 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -629,7 +629,6 @@ function add_return_to_last_statment(body::Expr)
     return Expr(body.head, new_args...)
 end
 
-const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
 hasmissing(::Type) = false
 hasmissing(::Type{>:Missing}) = true
 hasmissing(::Type{<:AbstractArray{TA}}) where {TA} = hasmissing(TA)
@@ -754,54 +753,80 @@ function warn_empty(body)
     return nothing
 end
 
-# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
-# TODO(mhauru) This function needs a more comprehensive docstring.
 """
-    matchingvalue(param_eltype, value)
-
-Convert the `value` to the correct type, given the element type of the parameters
-being used to evaluate the model.
-"""
-function matchingvalue(param_eltype, value)
-    T = typeof(value)
-    if hasmissing(T)
-        _value = convert(get_matching_type(param_eltype, T), value)
-        # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
-        # are happy to return `value` as-is?
-        if _value === value
-            return deepcopy(_value)
+    convert_model_argument(param_eltype, model_argument)
+
+Convert `model_argument` to the correct type, given the element type of the parameters being
+used to evaluate the model. This function potentially also deep-copies `model_argument` if it
+contains `missing` values.
+"""
+function convert_model_argument(param_eltype, model_argument)
+    T = typeof(model_argument)
+    # If the argument contains missing data, then we potentially need to deepcopy it. This
+    # is because the argument may be e.g. a vector of missings, and evaluating a
+    # tilde-statement like x[1] ~ Normal() would set x[1] = some_not_missing_value, thus
+    # mutating x. If you then run the model again with the same argument, x[1] would no
+    # longer be missing.
+    return if hasmissing(T)
+        # It is possible that we could skip the deepcopy, if the argument has to be promoted
+        # anyway. For example, if we are running with ForwardDiff and the argument is a
+        # Vector{Union{Missing, Float64}}, then we will convert it to a
+        # Vector{Union{Missing, ForwardDiff.Dual{...}}} anyway, which will avoid mutating
+        # the original argument. We can check for this by first converting and then only
+        # deepcopying if the converted value aliases the original.
+        # Note that indiscriminately deepcopying can not only lead to reduced performance,
+        # but sometimes also incorrect behaviour with ReverseDiff.jl, because ReverseDiff
+        # expects to be able to track array mutations. See e.g.
+        # https://github.com/TuringLang/DynamicPPL.jl/pull/1015#issuecomment-3166011534
+        converted_argument = convert(
+            promote_model_type_argument(param_eltype, T), model_argument
+        )
+        if converted_argument === model_argument
+            deepcopy(model_argument)
         else
-            return _value
+            converted_argument
         end
     else
-        return value
+        model_argument
     end
 end
-
-function matchingvalue(param_eltype, value::FloatOrArrayType)
-    return get_matching_type(param_eltype, value)
-end
-function matchingvalue(param_eltype, ::TypeWrap{T}) where {T}
-    return TypeWrap{get_matching_type(param_eltype, T)}()
-end
-
-# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
-"""
-    get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
-
-Get the specialized version of type `T`, given an element type of the parameters
-being used to evaluate the model.
-"""
-get_matching_type(_, ::Type{T}) where {T} = T
-function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}})
-    return Union{Missing,float_type_with_fallback(param_eltype)}
-end
-function get_matching_type(param_eltype, ::Type{<:AbstractFloat})
-    return float_type_with_fallback(param_eltype)
-end
-function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N}
-    return Array{get_matching_type(param_eltype, T),N}
-end
-function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T}
-    return Array{get_matching_type(param_eltype, T)}
+# These methods handle arguments that are types rather than values.
+function convert_model_argument(param_eltype, t::Type{<:Union{Real,AbstractArray}})
+    return promote_model_type_argument(param_eltype, t)
+end
+function convert_model_argument(param_eltype, ::TypeWrap{T}) where {T}
+    return TypeWrap{promote_model_type_argument(param_eltype, T)}()
+end
+# If the parameter element type is `Any`, then we don't need to do any conversion (but we
+# might need to deepcopy).
+function convert_model_argument(::Type{Any}, model_argument::T) where {T}
+    return hasmissing(T) ? deepcopy(model_argument) : model_argument
+end
+# Extra methods to avoid method ambiguity.
+convert_model_argument(::Type{Any}, t::Type{<:Union{Real,AbstractArray}}) = t
+convert_model_argument(::Type{Any}, t::TypeWrap{T}) where {T} = t
+
+"""
+    promote_model_type_argument(param_eltype, ::Type{T}) where {T}
+    promote_model_type_argument(param_eltype, ::TypeWrap{T}) where {T}
+
+For arguments to a model that are types rather than values, promote the type `T` to
+match the element type of the parameters being used to evaluate the model.
+"""
+promote_model_type_argument(_, ::Type{T}) where {T} = T
+function promote_model_type_argument(param_eltype, ::Type{T}) where {T<:Real}
+    # TODO(penelopeysm): This actually might still be over-aggressive. For example, if
+    # `param_eltype` is `Float32` and `T` is `Vector{Int}`, then (after going through the
    # Array method) we will promote to `Vector{Float32}`, which seems unnecessary. However,
+    # there's no way to actually check if `T` is the type of something that will later be
+    # assigned to, so this is 'safe'.
+    return Base.promote_type(param_eltype, T)
+end
+# NOTE(penelopeysm): This doesn't work with other types of AbstractArray. To get around
+# that, one could in principle use ArrayInterface.promote_eltype. However, it doesn't seem
+# like there is (1) demand for that; and (2) sufficiently strong adoption of ArrayInterface
+# to make that worth adding as a dependency.
+function promote_model_type_argument(param_eltype, ::Type{Array{T,N}}) where {T,N}
+    promoted_eltype = promote_model_type_argument(param_eltype, T)
+    return Array{promoted_eltype,N}
 end
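A hypothetical usage sketch of the copy-avoidance logic above (these are internal functions; the calls mirror the tests added later in this diff and are shown only to make the aliasing check concrete):

```julia
using DynamicPPL

# No missings: the argument is passed through without copying.
x = [1.0, 2.0]
@assert DynamicPPL.convert_model_argument(Float64, x) === x

# With missings and no eltype promotion, `convert` aliases the input, so a
# deepcopy is made to protect the caller's array from tilde-statement mutation.
y = [1.0, missing]
y2 = DynamicPPL.convert_model_argument(Float64, y)
@assert y2 !== y && isequal(y2, y)
```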
diff --git a/src/model.jl b/src/model.jl
index 8bfeaf6a1..b7f797944 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1010,12 +1010,14 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
     unwrap_args = [
         if is_splat_symbol(var)
             :(
-                $matchingvalue(
+                $convert_model_argument(
                     $get_param_eltype(varinfo, model.context), model.args.$var
                 )...
             )
         else
-            :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var))
+            :($convert_model_argument(
+                $get_param_eltype(varinfo, model.context), model.args.$var
+            ))
         end for var in argnames
     ]
     return quote
@@ -1095,30 +1097,85 @@ Base.rand(::Type{T}, model::Model) where {T} = rand(Random.default_rng(), T, mod
 Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model)
 
 """
+    logjoint(model::Model, params)
     logjoint(model::Model, varinfo::AbstractVarInfo)
 
-Return the log joint probability of variables `varinfo` for the probabilistic `model`.
+Return the log joint probability of variables `params` for the probabilistic `model`, or the
+log joint of the parameters stored in `varinfo` if a varinfo is provided.
 
-Note that this probability always refers to the parameters in unlinked space, i.e.,
-the return value of `logjoint` does not depend on whether `VarInfo` has been linked
-or not.
+Note that this probability always refers to the parameters in unlinked space, i.e., the
+return value of `logjoint` does not depend on whether `VarInfo` has been linked or not.
 
-See [`logprior`](@ref) and [`loglikelihood`](@ref).
+See also [`logprior`](@ref) and [`loglikelihood`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logjoint(demo([1.0]), (m = 100.0, ))
+-9902.33787706641
+
+julia> # Using an `OrderedDict`.
+       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-9902.33787706641
+
+julia> # Truth.
+       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
+-9902.33787706641
+```
 """
 function logjoint(model::Model, varinfo::AbstractVarInfo)
     return getlogjoint(last(evaluate!!(model, varinfo)))
 end
+function logjoint(model::Model, params)
+    vi = OnlyAccsVarInfo(
+        AccumulatorTuple(LogPriorAccumulator(), LogLikelihoodAccumulator())
+    )
+    ctx = InitFromParams(params, nothing)
+    return getlogjoint(last(init!!(model, vi, ctx)))
+end
 
 """
+    logprior(model::Model, params)
     logprior(model::Model, varinfo::AbstractVarInfo)
 
-Return the log prior probability of variables `varinfo` for the probabilistic `model`.
+Return the log prior probability of variables `params` for the probabilistic `model`, or the
+log prior of the parameters stored in `varinfo` if a varinfo is provided.
 
-Note that this probability always refers to the parameters in unlinked space, i.e.,
-the return value of `logprior` does not depend on whether `VarInfo` has been linked
-or not.
+Note that this probability always refers to the parameters in unlinked space, i.e., the
+return value of `logprior` does not depend on whether `VarInfo` has been linked or not.
 
 See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       logprior(demo([1.0]), (m = 100.0, ))
+-5000.918938533205
+
+julia> # Using an `OrderedDict`.
+       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-5000.918938533205
+
+julia> # Truth.
+       logpdf(Normal(), 100.0)
+-5000.918938533205
+```
 """
 function logprior(model::Model, varinfo::AbstractVarInfo)
     # Remove other accumulators from varinfo, since they are unnecessary.
@@ -1130,13 +1187,43 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
     varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,))
     return getlogprior(last(evaluate!!(model, varinfo)))
 end
+function logprior(model::Model, params)
+    vi = OnlyAccsVarInfo(AccumulatorTuple(LogPriorAccumulator()))
+    ctx = InitFromParams(params, nothing)
+    return getlogprior(last(init!!(model, vi, ctx)))
+end
 
 """
+    loglikelihood(model::Model, params)
     loglikelihood(model::Model, varinfo::AbstractVarInfo)
 
-Return the log likelihood of variables `varinfo` for the probabilistic `model`.
+Return the log likelihood of variables `params` for the probabilistic `model`, or the log
+likelihood of the parameters stored in `varinfo` if a varinfo is provided.
 
 See also [`logjoint`](@ref) and [`logprior`](@ref).
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> @model function demo(x)
+           m ~ Normal()
+           for i in eachindex(x)
+               x[i] ~ Normal(m, 1.0)
+           end
+       end
+demo (generic function with 2 methods)
+
+julia> # Using a `NamedTuple`.
+       loglikelihood(demo([1.0]), (m = 100.0, ))
+-4901.418938533205
+
+julia> # Using an `OrderedDict`.
+       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
+-4901.418938533205
+
+julia> # Truth.
+       logpdf(Normal(100.0, 1.0), 1.0)
+-4901.418938533205
+```
 """
 function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
     # Remove other accumulators from varinfo, since they are unnecessary.
@@ -1148,6 +1235,11 @@
     varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,))
     return getloglikelihood(last(evaluate!!(model, varinfo)))
 end
+function Distributions.loglikelihood(model::Model, params)
+    vi = OnlyAccsVarInfo(AccumulatorTuple(LogLikelihoodAccumulator()))
+    ctx = InitFromParams(params, nothing)
+    return getloglikelihood(last(init!!(model, vi, ctx)))
+end
 
 # Implemented & documented in DynamicPPLMCMCChainsExt
 function predict end
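The new `params`-accepting methods replace the `SimpleVarInfo`-based ones removed from src/simple_varinfo.jl below. A quick sanity check mirroring the doctests (the decomposition identity at the end is an illustration, not part of the diff):

```julia
using DynamicPPL, Distributions

@model function demo(x)
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m, 1.0)
    end
end

model = demo([1.0])
params = (m=100.0,)

# log joint = log prior + log likelihood, up to floating-point rounding.
@assert logjoint(model, params) ≈ logprior(model, params) + loglikelihood(model, params)
```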
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index 9d3fb1925..ec02f4c94 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -408,7 +408,8 @@ const SimpleOrThreadSafeSimple{T,V,C} = Union{
 }
 
-# Necessary for `matchingvalue` to work properly.
-Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
+# Necessary for `convert_model_argument` to work properly.
+Base.eltype(svi::SimpleVarInfo) = infer_nested_eltype(typeof(svi.values))
+Base.eltype(tsvi::ThreadSafeVarInfo{<:SimpleVarInfo}) = eltype(tsvi.varinfo)
 
 # `subset`
 function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
@@ -502,105 +503,6 @@ function values_as(vi::SimpleVarInfo, ::Type{T}) where {T}
     return values_as(vi.values, T)
 end
 
-"""
-    logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log joint probability of variables `θ` for the probabilistic `model`.
-
-See [`logprior`](@ref) and [`loglikelihood`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       logjoint(demo([1.0]), (m = 100.0, ))
--9902.33787706641
-
-julia> # Using a `OrderedDict`.
-       logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--9902.33787706641
-
-julia> # Truth.
-       logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
--9902.33787706641
-```
-"""
-logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    logjoint(model, SimpleVarInfo(θ))
-
-"""
-    logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log prior probability of variables `θ` for the probabilistic `model`.
-
-See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       logprior(demo([1.0]), (m = 100.0, ))
--5000.918938533205
-
-julia> # Using a `OrderedDict`.
-       logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--5000.918938533205
-
-julia> # Truth.
-       logpdf(Normal(), 100.0)
--5000.918938533205
-```
-"""
-logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    logprior(model, SimpleVarInfo(θ))
-
-"""
-    loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict})
-
-Return the log likelihood of variables `θ` for the probabilistic `model`.
-
-See also [`logjoint`](@ref) and [`logprior`](@ref).
-
-# Examples
-```jldoctest; setup=:(using Distributions)
-julia> @model function demo(x)
-           m ~ Normal()
-           for i in eachindex(x)
-               x[i] ~ Normal(m, 1.0)
-           end
-       end
-demo (generic function with 2 methods)
-
-julia> # Using a `NamedTuple`.
-       loglikelihood(demo([1.0]), (m = 100.0, ))
--4901.418938533205
-
-julia> # Using a `OrderedDict`.
-       loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
--4901.418938533205
-
-julia> # Truth.
-       logpdf(Normal(100.0, 1.0), 1.0)
--4901.418938533205
-```
-"""
-Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) =
-    loglikelihood(model, SimpleVarInfo(θ))
-
 # Allow usage of `NamedBijector` too.
 function link!!(
     t::StaticTransformation{<:Bijectors.NamedTransform},
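A sketch of what the new `Base.eltype` definition returns, assuming `infer_nested_eltype` recurses through the stored values (internal behaviour, shown purely for illustration):

```julia
using DynamicPPL

# eltype is now inferred from the values actually stored, rather than read off
# a type parameter of the (possibly thread-safe-wrapped) SimpleVarInfo.
svi = SimpleVarInfo((m=1.0, s=[2.0, 3.0]))
eltype(svi)  # expected: Float64
```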
diff --git a/test/compiler.jl b/test/compiler.jl
index c701bce29..62b6b9b2b 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -836,4 +836,47 @@ module Issue537 end
         @test vi isa VarInfo
         @test vi[@varname(m)] isa Real
     end
+
+    @testset "convert_model_argument" begin
+        tdual = ForwardDiff.Dual{Nothing,Float64,0}
+        # no-op
+        @test DynamicPPL.convert_model_argument(Float64, 1.0) == 1.0
+        @testset "shouldn't promote types of value arguments" begin
+            # i.e. this shouldn't become a dual.
+            @test DynamicPPL.convert_model_argument(tdual, 1.0) == 1.0
+        end
+        @testset "arrays" begin
+            # convert_model_argument should make sure to not deepcopy arrays if not needed
+            x = [1.0]
+            @test DynamicPPL.convert_model_argument(Float64, x) === x
+            # but if there's a missing in the array, it should
+            y = [1.0, missing]
+            y_converted = DynamicPPL.convert_model_argument(Float64, y)
+            @test y_converted !== y
+            @test isequal(y_converted, y)
+        end
+        @testset "type arguments" begin
+            # These tests with types / TypeWrap as the second argument also test
+            # `promote_model_type_argument`.
+            function test_type_conversion(
+                ::Type{input}, ::Type{target}
+            ) where {input,target}
+                converted_type = DynamicPPL.convert_model_argument(tdual, input)
+                @test converted_type == target
+                typewrap = DynamicPPL.TypeWrap{input}()
+                converted_typewrap = DynamicPPL.convert_model_argument(tdual, typewrap)
+                @test converted_typewrap == DynamicPPL.TypeWrap{target}()
+            end
+            test_type_conversion(Float64, tdual)
+            test_type_conversion(Real, ForwardDiff.Dual{Nothing,Real,0})
+            test_type_conversion(Vector{Float64}, Vector{tdual})
+            test_type_conversion(Vector{Real}, Vector{ForwardDiff.Dual{Nothing,Real,0}})
+            test_type_conversion(Matrix{Float64}, Matrix{tdual})
+            test_type_conversion(Matrix{Real}, Matrix{ForwardDiff.Dual{Nothing,Real,0}})
+            test_type_conversion(Vector{Vector{Float64}}, Vector{Vector{tdual}})
+            test_type_conversion(
+                Vector{Vector{Real}}, Vector{Vector{ForwardDiff.Dual{Nothing,Real,0}}}
+            )
+        end
+    end
 end
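Finally, a sketch of the end-user pattern the changelog calls out: a model taking `::Type{T}=Vector{Real}`. With this release, the abstract `Real` eltype survives promotion instead of being collapsed to a concrete float type:

```julia
using DynamicPPL, Distributions

@model function demo_type(::Type{T}=Vector{Real}) where {T}
    # T is promoted per evaluation, e.g. so x can hold ForwardDiff duals during AD.
    x = T(undef, 1)
    x[1] ~ Normal()
    return x
end

demo_type()()  # returns a 1-element vector drawn from Normal()
```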