Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/contrasts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ ContrastsMatrix(contrasts_matrix::ContrastsMatrix, levels::AbstractVector)
constructing a model matrix from a `ModelFrame` using different data.

"""
function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:AbstractContrasts, T}
function ContrastsMatrix(contrasts::C, levels::AbstractVector) where {C<:AbstractContrasts}

u_levels = DataAPI.unwrap.(levels)

# if levels are defined on contrasts, use those, validating that they line up.
# what does that mean? either:
Expand All @@ -167,9 +169,9 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
# better to filter data frame first
# 3. contrast levels missing from data: would have empty columns, generate a
# rank-deficient model matrix.
c_levels = something(DataAPI.levels(contrasts), levels)
c_levels = something(DataAPI.levels(contrasts), u_levels)

mismatched_levels = symdiff(c_levels, levels)
mismatched_levels = symdiff(c_levels, u_levels)
if !isempty(mismatched_levels)
throw(ArgumentError("contrasts levels not found in data or vice-versa: " *
"$mismatched_levels." *
Expand All @@ -179,7 +181,7 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst

# do conversion AFTER checking for levels so users get a nice error message
# when they've made a mistake with the level types
c_levels = convert(Vector{T}, c_levels)
c_levels = convert(Vector{eltype(u_levels)}, c_levels)

n = length(c_levels)
if n == 0
Expand Down Expand Up @@ -228,7 +230,7 @@ end

function StatsAPI.coefnames(C::AbstractContrasts, levels::AbstractVector, baseind::Integer)
not_base = [1:(baseind-1); (baseind+1):length(levels)]
levels[not_base]
DataAPI.unwrap.(levels[not_base])
end

Base.getindex(contrasts::ContrastsMatrix, rowinds, colinds) =
Expand Down Expand Up @@ -594,7 +596,7 @@ function contrasts_matrix(C::HypothesisCoding, baseind, n)
end

StatsAPI.coefnames(C::HypothesisCoding, levels::AbstractVector, baseind::Int) =
something(C.labels, levels[1:length(levels) .!= baseind])
something(C.labels, DataAPI.unwrap.(levels[1:length(levels) .!= baseind]))

DataAPI.levels(c::HypothesisCoding) = c.levels

Expand Down
Loading