Skip to content

Commit f101d32

Browse files
authored
Make coeftable(::MatrixTerm) always return Vector{String} (#334)
1 parent 143a2d2 commit f101d32

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

src/terms.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,13 +573,13 @@ Return the name(s) of column(s) generated by a term. Return value is either a
573573
See also [`termnames`](@ref).
574574
"""
575575
StatsAPI.coefnames(t::FormulaTerm) = (coefnames(t.lhs), coefnames(t.rhs))
576-
StatsAPI.coefnames(::InterceptTerm{H}) where {H} = H ? "(Intercept)" : []
576+
StatsAPI.coefnames(::InterceptTerm{H}) where {H} = H ? ["(Intercept)"] : String[]
577577
StatsAPI.coefnames(t::ContinuousTerm) = string(t.sym)
578578
StatsAPI.coefnames(t::CategoricalTerm) =
579579
["$(t.sym): $name" for name in t.contrasts.coefnames]
580580
StatsAPI.coefnames(t::FunctionTerm) = string(t.exorig)
581581
StatsAPI.coefnames(ts::TupleTerm) = reduce(vcat, coefnames.(ts))
582-
StatsAPI.coefnames(t::MatrixTerm) = mapreduce(coefnames, vcat, t.terms)
582+
StatsAPI.coefnames(t::MatrixTerm) = mapreduce(coefnames, vcat, t.terms; init = String[])
583583
StatsAPI.coefnames(t::InteractionTerm) =
584584
kron_insideout((args...) -> join(args, " & "), vectorize.(coefnames.(t.terms))...)
585585

test/modelmatrix.jl

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "Model matrix" begin
2-
2+
33
using StatsBase: StatisticalModel
44

55
using SparseArrays, DataFrames, Tables
@@ -14,7 +14,7 @@
1414
d.x1p = categorical(d.x1)
1515

1616
d_orig = deepcopy(d)
17-
17+
1818
x1 = [5.:8;]
1919
x2 = [9.:12;]
2020
x3 = [13.:16;]
@@ -161,8 +161,8 @@
161161
z = repeat([:e, :f], inner = 4))
162162
cs = Dict([Symbol(name) => EffectsCoding() for name in names(d)])
163163
d.n = 1.:8
164-
165-
164+
165+
166166
## No intercept
167167
mf = ModelFrame(@formula(n ~ 0 + x), d, contrasts=cs)
168168
mm = ModelMatrix(mf)
@@ -182,8 +182,8 @@
182182
mm = ModelMatrix(mf)
183183
@test all(mm.m .== ifelse.(d.x .== :a, -1, 1))
184184
@test coefnames(mf) == ["x: b"]
185-
186-
185+
186+
187187
## No first-order term for interaction
188188
mf = ModelFrame(@formula(n ~ 1 + x + x&y), d, contrasts=cs)
189189
mm = ModelMatrix(mf)
@@ -197,7 +197,7 @@
197197
1 0 1]
198198
@test mm.m == ModelMatrix{sparsetype}(mf).m
199199
@test coefnames(mf) == ["(Intercept)", "x: b", "x: a & y: d", "x: b & y: d"]
200-
200+
201201
## When both terms of interaction are non-redundant:
202202
mf = ModelFrame(@formula(n ~ 0 + x&y), d, contrasts=cs)
203203
mm = ModelMatrix(mf)
@@ -218,7 +218,7 @@
218218
mm = ModelMatrix(mf)
219219
@test mm.m == Matrix(1.0I, 8, 8)
220220
@test mm.m == ModelMatrix{sparsetype}(mf).m
221-
221+
222222
# two two-way interactions, with no lower-order term. both are promoted in
223223
# first (both x and y), but only the old term (x) in the second (because
224224
# dropping x gives z which isn't found elsewhere, but dropping z gives x
@@ -237,7 +237,7 @@
237237
@test coefnames(mf) == ["x: a & y: c", "x: b & y: c",
238238
"x: a & y: d", "x: b & y: d",
239239
"x: a & z: f", "x: b & z: f"]
240-
240+
241241
# ...and adding a three-way interaction, only the shared term (x) is promoted.
242242
# this is because dropping x gives y&z which isn't present, but dropping y or z
243243
# gives x&z or x&z respectively, which are both present.
@@ -256,7 +256,7 @@
256256
"x: a & y: d", "x: b & y: d",
257257
"x: a & z: f", "x: b & z: f",
258258
"x: a & y: d & z: f", "x: b & y: d & z: f"]
259-
259+
260260
# two two-way interactions, with common lower-order term. the common term x is
261261
# promoted in both (along with lower-order term), because in every case, when
262262
# x is dropped, the remaining terms (1, y, and z) aren't present elsewhere.
@@ -274,8 +274,8 @@
274274
@test coefnames(mf) == ["x: a", "x: b",
275275
"x: a & y: d", "x: b & y: d",
276276
"x: a & z: f", "x: b & z: f"]
277-
278-
277+
278+
279279
## FAILS: When both terms are non-redundant and intercept is PRESENT
280280
## (not fully redundant). Ideally, would drop last column. Might make sense
281281
## to warn about this, and suggest recoding x and y into a single variable.
@@ -286,7 +286,7 @@
286286
1 0 0 0]
287287
@test_broken coefnames(mf) == ["x: a & y: c", "x: b & y: c",
288288
"x: a & y: d", "x: b & y: d"]
289-
289+
290290
## note that R also does not detect this automatically. it's left to glm et al.
291291
## to detect numerically when the model matrix is rank deficient, which is hard
292292
## to do correctly.
@@ -343,7 +343,7 @@
343343
x = repeat([:a, :b], outer = 4),
344344
y = repeat([:c, :d], inner = 2, outer = 2),
345345
z = repeat([:e, :f], inner = 4))
346-
346+
347347
f = apply_schema(@formula(r ~ 1 + w*x*y*z), schema(d))
348348
modelmatrix(f, d)
349349
@test reduce(vcat, last.(modelcols.(Ref(f), Tables.rowtable(d)))') == modelmatrix(f,d)
@@ -355,7 +355,7 @@
355355
x = repeat([:a, :b], outer = 4),
356356
y = repeat([:c, :d], inner = 2, outer = 2),
357357
z = repeat([:e, :f], inner = 4))
358-
358+
359359
f = @formula(r ~ 1 + w*x*y*z)
360360

361361
mm1 = modelmatrix(f, d)
@@ -375,19 +375,19 @@
375375
C=repeat(['L','H'], inner=4))
376376

377377
contrasts = Dict(:A=>HelmertCoding(), :B=>HelmertCoding(), :C=>HelmertCoding())
378-
379-
378+
379+
380380

381381
mf = ModelFrame(@formula(Y ~ 1 + A*B*C), tbl)
382382
mf_helm = ModelFrame(@formula(Y ~ 1 + A*B*C), tbl, contrasts = contrasts)
383383

384384
@test size(modelmatrix(mf)) == size(modelmatrix(mf_helm))
385-
385+
386386
mf_helm2 = setcontrasts!(ModelFrame(@formula(Y ~ 1 + A*B*C), tbl), contrasts)
387387

388388
@test size(modelmatrix(mf)) == size(modelmatrix(mf_helm2))
389389
@test modelmatrix(mf_helm) == modelmatrix(mf_helm2)
390-
390+
391391
end
392392
end
393393

@@ -402,5 +402,15 @@
402402
f = apply_schema(@formula(0 ~ a&b&c), schema(t))
403403
@test vec(modelcols(f.rhs, t)) == modelcols.(Ref(f.rhs), Tables.rowtable(t))
404404
end
405-
405+
406+
@testset "#112. coefnames should return same type for all rhs: $(f)" for f in [
407+
@formula(y ~ 1),
408+
@formula(y ~ x1 + 0),
409+
@formula(y ~ x1),
410+
@formula(y ~ x1 + x2),
411+
]
412+
df = (y = [1.0, 1.0], x1 = [1, 2], x2 = ["A", "B"])
413+
_f = apply_schema(f, schema(f, df))
414+
@test coefnames(_f.rhs) isa Vector{String}
415+
end
406416
end

test/temporal_terms.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using DataStructures
2626
resp, pred = modelcols(f, df)
2727

2828
@test isequal(pred[:, 1], [missing; 1.0:9])
29-
@test coefnames(f)[2] == "x_lag1"
29+
@test coefnames(f)[2] == ["x_lag1"]
3030
end
3131

3232
@testset "Row Table" begin
@@ -53,7 +53,7 @@ using DataStructures
5353
resp, pred = modelcols(neg_f, df);
5454

5555
@test isequal(pred[:, 1], [3.0:10; missing; missing])
56-
@test coefnames(neg_f)[2] == "x_lag-2"
56+
@test coefnames(neg_f)[2] == ["x_lag-2"]
5757
end
5858

5959
@testset "Categorical Term use" begin
@@ -184,7 +184,5 @@ using DataStructures
184184
@test coefnames(t1) == coefnames(t2) == coefnames(t3)
185185
end
186186
end
187-
188-
189187
end
190188
end

0 commit comments

Comments
 (0)