From eefe907ace0636ac7b0ca8f4eabbfa35a6b2ef64 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 1 Aug 2025 09:49:30 +0200 Subject: [PATCH] Support CategoricalArrays 1 Since `levels(::CategoricalArray)` now returns a `CategoricalArray`, we need to unwrap the result before storing it as an `Array` field. This also works on CategoricalArrays 0.10. --- src/contrasts.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/contrasts.jl b/src/contrasts.jl index ca9befdd..b069c7f3 100644 --- a/src/contrasts.jl +++ b/src/contrasts.jl @@ -157,7 +157,9 @@ ContrastsMatrix(contrasts_matrix::ContrastsMatrix, levels::AbstractVector) constructing a model matrix from a `ModelFrame` using different data. """ -function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:AbstractContrasts, T} +function ContrastsMatrix(contrasts::C, levels::AbstractVector) where {C<:AbstractContrasts} + + u_levels = DataAPI.unwrap.(levels) # if levels are defined on contrasts, use those, validating that they line up. # what does that mean? either: @@ -167,9 +169,9 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst # better to filter data frame first # 3. contrast levels missing from data: would have empty columns, generate a # rank-deficient model matrix. - c_levels = something(DataAPI.levels(contrasts), levels) + c_levels = something(DataAPI.levels(contrasts), u_levels) - mismatched_levels = symdiff(c_levels, levels) + mismatched_levels = symdiff(c_levels, u_levels) if !isempty(mismatched_levels) throw(ArgumentError("contrasts levels not found in data or vice-versa: " * "$mismatched_levels." * @@ -179,7 +181,7 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst # do conversion AFTER checking for levels so users get a nice error message # when they've made a mistake with the level types - c_levels = convert(Vector{T}, c_levels) + c_levels = convert(Vector{eltype(u_levels)}, c_levels) n = length(c_levels) if n == 0 @@ -228,7 +230,7 @@ end function StatsAPI.coefnames(C::AbstractContrasts, levels::AbstractVector, baseind::Integer) not_base = [1:(baseind-1); (baseind+1):length(levels)] - levels[not_base] + DataAPI.unwrap.(levels[not_base]) end Base.getindex(contrasts::ContrastsMatrix, rowinds, colinds) = @@ -594,7 +596,7 @@ function contrasts_matrix(C::HypothesisCoding, baseind, n) end StatsAPI.coefnames(C::HypothesisCoding, levels::AbstractVector, baseind::Int) = - something(C.labels, levels[1:length(levels) .!= baseind]) + something(C.labels, DataAPI.unwrap.(levels[1:length(levels) .!= baseind])) DataAPI.levels(c::HypothesisCoding) = c.levels