Skip to content

Commit 5574963

Browse files
use DataAPI.levels instead of levels field on contrasts, terms as hints (#181)
* use AbstractTerm hints directly into schema * introduce levels(::AbstractContrasts) instead of property access * fix ambiguity error and also ambiguity in levels/levels() * add placeholder Grouping "contrasts" * Apply suggestions from code review Co-authored-by: Milan Bouchet-Valat <[email protected]> * Revert "add placeholder Grouping "contrasts"" This reverts commit d01aff9. * tests for hints (and fix hint bug I missed...) * tests for base= and levels= constructors and levels()/baselevel() Co-authored-by: Milan Bouchet-Valat <[email protected]>
1 parent a7d1534 commit 5574963

File tree

5 files changed

+96
-8
lines changed

5 files changed

+96
-8
lines changed

src/StatsModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using StatsBase
55
using ShiftedArrays
66
using ShiftedArrays: lag, lead
77
using DataStructures
8+
using DataAPI
89
using DataAPI: levels
910
using Printf: @sprintf
1011
using Distributions: Chisq, ccdf

src/contrasts.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,12 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
159159
# if levels are defined on contrasts, use those, validating that they line up.
160160
# what does that mean? either:
161161
#
162-
# 1. contrasts.levels == levels (best case)
162+
# 1. DataAPI.levels(contrasts) == levels (best case)
163163
# 2. data levels missing from contrast: would generate empty/undefined rows.
164164
# better to filter data frame first
165165
# 3. contrast levels missing from data: would have empty columns, generate a
166166
# rank-deficient model matrix.
167-
c_levels = something(contrasts.levels, levels)
167+
c_levels = something(DataAPI.levels(contrasts), levels)
168168
if eltype(c_levels) != eltype(levels)
169169
throw(ArgumentError("mismatching levels types: got $(eltype(levels)), expected " *
170170
"$(eltype(c_levels)) based on contrasts levels."))
@@ -242,11 +242,13 @@ for contrastType in [:DummyCoding, :EffectsCoding, :HelmertCoding, :SeqDiffCodin
242242
## constructor with optional keyword arguments, defaulting to nothing
243243
$contrastType(; base=nothing, levels::Union{AbstractVector,Nothing}=nothing) = $contrastType(base, levels)
244244
baselevel(c::$contrastType) = c.base
245+
DataAPI.levels(c::$contrastType) = c.levels
245246
end
246247
end
247248

248-
# fallback method for other types that might not have base field
249+
# fallback method for other types that might not have base or level fields
249250
baselevel(c::AbstractContrasts) = nothing
251+
DataAPI.levels(c::AbstractContrasts) = nothing
250252

251253
"""
252254
FullDummyCoding()
@@ -586,6 +588,8 @@ end
586588
termnames(C::HypothesisCoding, levels::AbstractVector, baseind::Int) =
587589
something(C.labels, levels[1:length(levels) .!= baseind])
588590

591+
DataAPI.levels(c::HypothesisCoding) = c.levels
592+
589593
"""
590594
StatsModels.ContrastsCoding(mat::AbstractMatrix[, levels]])
591595
StatsModels.ContrastsCoding(mat::AbstractMatrix[; levels=nothing])
@@ -628,6 +632,8 @@ function contrasts_matrix(C::ContrastsCoding, baseind, n)
628632
C.mat
629633
end
630634

635+
DataAPI.levels(c::ContrastsCoding) = c.levels
636+
631637
## hypothesis matrix
632638
"""
633639
needs_intercept(mat::AbstractMatrix)
@@ -721,4 +727,3 @@ function pretty_mat(mat::AbstractMatrix; tol::Real=10*eps(eltype(mat)))
721727
return fracs
722728
end
723729
end
724-

src/schema.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ concrete_term(t::Term, dt::ColumnTable, hints::Dict{Symbol}) =
172172
concrete_term(t, getproperty(dt, t.sym), get(hints, t.sym, nothing))
173173
concrete_term(t::Term, d) = concrete_term(t, d, nothing)
174174

175+
# if the "hint" is already an AbstractTerm, use that
176+
# need this specified to avoid ambiguity
177+
concrete_term(t::Term, d::ColumnTable, hint::AbstractTerm) = hint
178+
concrete_term(t::Term, x, hint::AbstractTerm) = hint
179+
175180
# second possible fix for #97
176181
concrete_term(t, d, hint) = t
177182

test/contrasts.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,4 +283,53 @@
283283
cmat3p[1] += 1e-3
284284
@test needs_intercept(cmat3p) == true
285285
end
286+
287+
@testset "levels and baselevel" begin
288+
using DataAPI: levels
289+
using StatsModels: baselevel, FullDummyCoding, ContrastsCoding
290+
291+
levs = [:a, :b, :c, :d]
292+
base = [:a]
293+
for C in [DummyCoding, EffectsCoding, SeqDiffCoding, HelmertCoding]
294+
c = C()
295+
@test levels(c) == nothing
296+
@test baselevel(c) == nothing
297+
298+
c = C(levels=levs)
299+
@test levels(c) == levs
300+
@test baselevel(c) == nothing
301+
302+
c = C(base=base)
303+
@test levels(c) == nothing
304+
@test baselevel(c) == base
305+
306+
c = C(levels=levs, base=base)
307+
@test levels(c) == levs
308+
@test baselevel(c) == base
309+
end
310+
311+
c = FullDummyCoding()
312+
@test baselevel(c) == nothing
313+
@test levels(c) == nothing
314+
315+
@test_throws MethodError FullDummyCoding(levels=levs)
316+
@test_throws MethodError FullDummyCoding(base=base)
317+
318+
c = HypothesisCoding(rand(3,4))
319+
@test baselevel(c) == levels(c) == nothing
320+
c = HypothesisCoding(rand(3,4), levels=levs)
321+
@test baselevel(c) == nothing
322+
@test levels(c) == levs
323+
# no notion of base level for HypothesisCoding
324+
@test_throws MethodError HypothesisCoding(rand(3,4), base=base)
325+
326+
c = ContrastsCoding(rand(4,3))
327+
@test baselevel(c) == levels(c) == nothing
328+
c = ContrastsCoding(rand(4,3), levels=levs)
329+
@test baselevel(c) == nothing
330+
@test levels(c) == levs
331+
# no notion of base level for ContrastsCoding
332+
@test_throws MethodError ContrastsCoding(rand(4,3), base=base)
333+
334+
end
286335
end

test/schema.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,37 @@
22

33
using StatsModels: schema, apply_schema, FullRank
44

5-
f = @formula(y ~ 1 + a + b + c + b&c)
6-
df = (y = rand(9), a = 1:9, b = rand(9), c = repeat(["d","e","f"], 3))
7-
f = apply_schema(f, schema(f, df))
8-
@test f == apply_schema(f, schema(f, df))
5+
@testset "no-op apply_schema" begin
6+
f = @formula(y ~ 1 + a + b + c + b&c)
7+
df = (y = rand(9), a = 1:9, b = rand(9), c = repeat(["d","e","f"], 3))
8+
f = apply_schema(f, schema(f, df))
9+
@test f == apply_schema(f, schema(f, df))
10+
end
911

12+
@testset "hints" begin
13+
f = @formula(y ~ 1 + a)
14+
d = (y = rand(10), a = repeat([1,2], outer=2))
15+
16+
sch = schema(f, d)
17+
@test sch[term(:a)] isa ContinuousTerm
18+
19+
sch1 = schema(f, d, Dict(:a => CategoricalTerm))
20+
@test sch1[term(:a)] isa CategoricalTerm{DummyCoding}
21+
f1 = apply_schema(f, sch1)
22+
@test f1.rhs.terms[end] == sch1[term(:a)]
23+
24+
sch2 = schema(f, d, Dict(:a => DummyCoding()))
25+
@test sch2[term(:a)] isa CategoricalTerm{DummyCoding}
26+
f2 = apply_schema(f, sch2)
27+
@test f2.rhs.terms[end] == sch2[term(:a)]
28+
29+
hint = deepcopy(sch2[term(:a)])
30+
sch3 = schema(f, d, Dict(:a => hint))
31+
# if an <:AbstractTerm is supplied as hint, it's included as is
32+
@test sch3[term(:a)] === hint !== sch2[term(:a)]
33+
f3 = apply_schema(f, sch3)
34+
@test f3.rhs.terms[end] === hint
35+
36+
end
37+
1038
end

0 commit comments

Comments
 (0)