Skip to content
Merged
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"

[extensions]
CategoricalArraysArrowExt = "Arrow"
CategoricalArraysJSONExt = "JSON"
CategoricalArraysRecipesBaseExt = "RecipesBase"
CategoricalArraysStatsBaseExt = "StatsBase"
CategoricalArraysSentinelArraysExt = "SentinelArrays"
CategoricalArraysStructTypesExt = "StructTypes"

Expand All @@ -37,6 +39,7 @@ RecipesBase = "1.1"
Requires = "1"
SentinelArrays = "1"
Statistics = "1"
StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
StructTypes = "1"
julia = "1.6"

Expand All @@ -49,8 +52,9 @@ PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecipesPipeline = "01d81517-befc-4cb6-b9ec-a95719d0359c"
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StructTypes", "Test"]
test = ["Arrow", "Dates", "JSON", "JSON3", "PooledArrays", "RecipesBase", "RecipesPipeline", "SentinelArrays", "StatsBase", "StructTypes", "Test"]
13 changes: 13 additions & 0 deletions ext/CategoricalArraysStatsBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module CategoricalArraysStatsBaseExt

# When `Base.get_extension` exists, both packages are loaded by name;
# otherwise they are reached through the enclosing module via relative paths.
if isdefined(Base, :get_extension)
    using StatsBase
    import CategoricalArrays: _wquantile
else
    using ..StatsBase
    import ..CategoricalArrays: _wquantile
end

# Method of `CategoricalArrays._wquantile` for `AbstractWeights` weights:
# forwards to the three-argument `quantile(x, w, p)`.
function _wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector)
    return quantile(x, w, p)
end

end
1 change: 1 addition & 0 deletions src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ module CategoricalArrays
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")
@require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl")
@require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl")
end
end
Expand Down
44 changes: 36 additions & 8 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
return breaks
end

# Fallback for weights that are not an `AbstractWeights` vector: always throws.
# The `AbstractWeights` method is defined in the StatsBase extension
# (there is no in-place weighted quantile method in StatsBase).
function _wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector)
    msg = "`weights` must be an `AbstractWeights` vector from StatsBase.jl"
    throw(ArgumentError(msg))
end

"""
cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function},
sigdigits::Integer=3,
allowempty::Bool=false)
allowempty::Bool=false,
weights::Union{AbstractWeights, Nothing}=nothing)

Cut a numeric array into `ngroups` quantiles.

Expand Down Expand Up @@ -373,19 +379,41 @@ quantiles.
other than the last one are equal, generating empty intervals;
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
unused levels (but duplicate labels are not allowed).
* `weights::Union{AbstractWeights, Nothing}=nothing`: observation weights to use when
computing quantiles (see `quantile` documentation in StatsBase).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
sigdigits::Integer=3,
allowempty::Bool=false)
allowempty::Bool=false,
weights::Union{AbstractVector, Nothing}=nothing)
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
if weights === nothing
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
else
if eltype(x) >: Missing
nm_inds = findall(!ismissing, x)
nm_x = view(x, nm_inds)
# TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723)
nm_weights = weights[nm_inds]
else
nm_x = x
nm_weights = weights
end
sorted_x = sort(nm_x)
min_x, max_x = first(sorted_x), last(sorted_x)
if (min_x isa Number && isnan(min_x)) ||
(max_x isa Number && isnan(max_x))
throw(ArgumentError("NaN values are not allowed in input vector"))
end
qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups)
end
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
if !allowempty && !allunique(@view breaks[1:end-1])
throw(ArgumentError("cannot compute $ngroups quantiles due to " *
Expand Down
24 changes: 24 additions & 0 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module TestExtras
using Test
using CategoricalArrays
using StatsBase
using Missings

const ≅ = isequal

Expand Down Expand Up @@ -423,4 +425,26 @@ end

end

@testset "cut with weighted quantiles" begin
    # A plain range passed as `weights` is rejected with an ArgumentError
    @test_throws ArgumentError cut(1:3, 3, weights=1:3)

    probs = (0:10)./10
    vals = collect(Float64, 1:100)
    wts = fweights(repeat(1:10, inner=10))

    # No missing values: result must match cutting at the weighted quantiles
    res = cut(vals, 10, weights=wts)
    ref = cut(vals, quantile(vals, wts, probs))
    @test levelcode.(res) == levelcode.(ref)
    @test levels(res) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
                          "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

    # With missing values: compare against quantiles computed on the
    # non-missing entries and their corresponding weights
    mvals = allowmissing(vals)
    mvals[10] = missing
    mvals[2] = missing
    keep = .!ismissing.(mvals)
    mres = cut(mvals, 10, weights=wts)
    mref = cut(mvals, quantile(vals[keep], wts[keep], probs))
    @test levelcode.(mres) ≅ levelcode.(mref)
    @test levels(mres) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
                           "[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]

    # NaN in the input is rejected on the weighted path
    vals[5] = NaN
    @test_throws ArgumentError cut(vals, 3, weights=wts)
end

end
Loading