Summary of this patch (free-form text before the first `diff --git` header is
skipped by `git apply` and `patch`):

  * src/univariate/discrete/discretenonparametric.jl
      - The inner constructor no longer sorts its input. When `check_args` is
        enabled it requires equal lengths, a valid probability vector and a
        support that is already sorted with unique elements.
      - The outer constructor sorts the support (permuting the probabilities
        accordingly) unless the support is an `AbstractUnitRange`, checks for
        duplicates on the sorted support, and then calls the inner constructor
        with `check_args=false`.
  * src/utils.jl
      - Adds the helper `issorted_allunique`.
  * test/univariate/discrete/{categorical,discretenonparametric}.jl
      - Adds tests for the constructor behaviour above and for
        `AbstractVector` probability inputs.

diff --git a/src/univariate/discrete/discretenonparametric.jl b/src/univariate/discrete/discretenonparametric.jl
index 8e1eefab6..cc339fc07 100644
--- a/src/univariate/discrete/discretenonparametric.jl
+++ b/src/univariate/discrete/discretenonparametric.jl
@@ -23,22 +23,47 @@ struct DiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractV
 
     function DiscreteNonParametric{T,P,Ts,Ps}(xs::Ts, ps::Ps; check_args::Bool=true) where {
             T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
-        check_args || return new{T,P,Ts,Ps}(xs, ps)
+        let xs = xs, ps = ps
+            @check_args(
+                DiscreteNonParametric,
+                (length(xs) == length(ps), "length of support and probability vector must be equal"),
+                (ps, isprobvec(ps), "vector is not a probability vector"),
+                (xs, issorted_allunique(xs), "support must be sorted and contain only unique elements"),
+            )
+        end
+        new{T,P,Ts,Ps}(xs, ps)
+    end
+end
+
+function DiscreteNonParametric(xs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {T<:Real,P<:Real}
+    # These checks are performed before sorting the support since we do not want to throw a `BoundsError` when the lengths do not match
+    let xs = xs, ps = ps
         @check_args(
             DiscreteNonParametric,
             (length(xs) == length(ps), "length of support and probability vector must be equal"),
             (ps, isprobvec(ps), "vector is not a probability vector"),
-            (xs, allunique(xs), "support must contain only unique elements"),
         )
+    end
+    # We always sort the support unless it can be deduced from the type of the support that it is sorted.
+    # Sorting can be skipped for all inputs by using the inner constructor.
+    if xs isa AbstractUnitRange
+        sortedxs = xs
+        sortedps = ps
+    else
         sort_order = sortperm(xs)
-        new{T,P,Ts,Ps}(xs[sort_order], ps[sort_order])
+        sortedxs = xs[sort_order]
+        sortedps = ps[sort_order]
+        # It is more efficient to perform this check once the array is sorted
+        let sortedxs = sortedxs
+            @check_args(
+                DiscreteNonParametric,
+                (sortedxs, issorted_allunique(sortedxs), "support must contain only unique elements"),
+            )
+        end
     end
+    return DiscreteNonParametric{T,P,typeof(sortedxs),typeof(sortedps)}(sortedxs, sortedps; check_args=false)
 end
 
-DiscreteNonParametric(vs::AbstractVector{T}, ps::AbstractVector{P}; check_args::Bool=true) where {
-        T<:Real,P<:Real} =
-    DiscreteNonParametric{T,P,typeof(vs),typeof(ps)}(vs, ps; check_args=check_args)
-
 Base.eltype(::Type{<:DiscreteNonParametric{T}}) where T = T
 
 # Conversion
diff --git a/src/utils.jl b/src/utils.jl
index cabe51dcd..038136cd9 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -97,6 +97,23 @@ isunitvec(v::AbstractVector) = (norm(v) - 1.0) < 1.0e-12
 isprobvec(p::AbstractVector{<:Real}) =
     all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p)))
 
+issorted_allunique(xs::AbstractUnitRange{<:Real}) = true
+function issorted_allunique(xs::AbstractVector{<:Real})
+    xi_state = iterate(xs)
+    if xi_state === nothing
+        return true
+    end
+    xi, state = xi_state
+    while (xj_state = iterate(xs, state)) !== nothing
+        xj, state = xj_state
+        if xj <= xi
+            return false
+        end
+        xi = xj
+    end
+    return true
+end
+
 sqrt!!(x::AbstractVector{<:Real}) = map(sqrt, x)
 function sqrt!!(x::Vector{<:Real})
     for i in eachindex(x)
diff --git a/test/univariate/discrete/categorical.jl b/test/univariate/discrete/categorical.jl
index 6d87d4dc8..96425a204 100644
--- a/test/univariate/discrete/categorical.jl
+++ b/test/univariate/discrete/categorical.jl
@@ -137,4 +137,21 @@ end
     @test count(==(1e8), priorities[iat]) >= 13
 end
 
+@testset "AbstractVector" begin
+    # issue #1084
+    P = abs.(randn(5,4,2))
+    p = view(P,:,1,1)
+    p ./= sum(p)
+    d = @inferred(Categorical(p))
+    @test d isa Categorical{Float64, typeof(p)}
+    @test d.p === p
+
+    # #1832
+    x = rand(3,5)
+    x ./= sum(x; dims=1)
+    c = Categorical.(eachcol(x))
+    @test c isa Vector{<:Categorical}
+    @test all(ci.p isa SubArray for ci in c)
+end
+
 end
diff --git a/test/univariate/discrete/discretenonparametric.jl b/test/univariate/discrete/discretenonparametric.jl
index 68354a064..128207d74 100644
--- a/test/univariate/discrete/discretenonparametric.jl
+++ b/test/univariate/discrete/discretenonparametric.jl
@@ -14,7 +14,8 @@ rng = MersenneTwister(123)
 
 d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2])
 
-@test !(d ≈ DiscreteNonParametric([40., 80, 120, -60], [.4, .3, .1, .2], check_args=false))
+# In the outer constructor, the support is always sorted, regardless of whether `check_args = false` or `check_args = true`
+@test d ≈ DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2], check_args=false)
 @test d ≈ DiscreteNonParametric([-60., 40., 80, 120], [.2, .4, .3, .1], check_args=false)
 
 # Invalid probability
@@ -23,6 +24,25 @@ d = DiscreteNonParametric([40., 80., 120., -60.], [.4, .3, .1, .2])
 # Invalid probability, but no arg check
 DiscreteNonParametric([40., 80, 120, -60], [.5, .3, .1, .2], check_args=false)
 
+# Invalid support
+@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2])
+@test_throws DomainError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1])
+@test_throws DomainError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1])
+
+# Invalid support but no arg check
+DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([40., 80, 120, -60], [.4, .3, .1, .2], check_args=false)
+DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false)
+DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3, .1], check_args=false)
+
+# Mismatch between support and probabilities
+@test_throws ArgumentError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3])
+@test_throws ArgumentError DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3])
+
+# Mismatch between support and probabilities but no arg check
+@test_throws BoundsError DiscreteNonParametric([-60., 40., 40., 120], [.2, .4, .3], check_args=false) # sorting errors
+DiscreteNonParametric(1:4, [.2, .4, .3], check_args=false) # no sorting, hence no `BoundsError`
+DiscreteNonParametric{Float64,Float64,Vector{Float64},Vector{Float64}}([-60., 40., 40., 120], [.2, .4, .3], check_args=false)
+
 test_range(d)
 vs = Distributions.get_evalsamples(d, 0.00001)
 test_evaluation(d, vs, true)
@@ -213,4 +233,25 @@ end
     # Different types
     @test DiscreteNonParametric(1:2, [0.5, 0.5]) == DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
     @test DiscreteNonParametric(1:2, [0.5, 0.5]) ≈ DiscreteNonParametric([1, 2], [0.5f0, 0.5f0])
-end
\ No newline at end of file
+end
+
+@testset "AbstractVector (issue #1084)" begin
+    P = abs.(randn(5,4,2))
+    p = view(P,:,1,1)
+    p ./= sum(p)
+
+    d = @inferred(DiscreteNonParametric(Base.OneTo(5), p))
+    @test d isa DiscreteNonParametric
+    @test d.p === p
+    d = @inferred(DiscreteNonParametric(1:5, p))
+    @test d isa DiscreteNonParametric
+    @test d.p === p
+    d = @inferred(DiscreteNonParametric(1:1:5, p))
+    @test d isa DiscreteNonParametric
+    @test d.p !== p
+    @test d.p == p
+    d = @inferred(DiscreteNonParametric([1, 2, 3, 4, 5], p))
+    @test d isa DiscreteNonParametric
+    @test d.p !== p
+    @test d.p == p
+end