Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/terms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,13 +573,13 @@ Return the name(s) of column(s) generated by a term. Return value is either a
See also [`termnames`](@ref).
"""
StatsAPI.coefnames(t::FormulaTerm) = (coefnames(t.lhs), coefnames(t.rhs))
StatsAPI.coefnames(::InterceptTerm{H}) where {H} = H ? "(Intercept)" : []
StatsAPI.coefnames(::InterceptTerm{H}) where {H} = H ? ["(Intercept)"] : String[]
StatsAPI.coefnames(t::ContinuousTerm) = string(t.sym)
StatsAPI.coefnames(t::CategoricalTerm) =
["$(t.sym): $name" for name in t.contrasts.coefnames]
StatsAPI.coefnames(t::FunctionTerm) = string(t.exorig)
StatsAPI.coefnames(ts::TupleTerm) = reduce(vcat, coefnames.(ts))
StatsAPI.coefnames(t::MatrixTerm) = mapreduce(coefnames, vcat, t.terms)
StatsAPI.coefnames(t::MatrixTerm) = mapreduce(coefnames, vcat, t.terms; init = String[])
StatsAPI.coefnames(t::InteractionTerm) =
kron_insideout((args...) -> join(args, " & "), vectorize.(coefnames.(t.terms))...)

Expand Down
50 changes: 30 additions & 20 deletions test/modelmatrix.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testset "Model matrix" begin

using StatsBase: StatisticalModel

using SparseArrays, DataFrames, Tables
Expand All @@ -14,7 +14,7 @@
d.x1p = categorical(d.x1)

d_orig = deepcopy(d)

x1 = [5.:8;]
x2 = [9.:12;]
x3 = [13.:16;]
Expand Down Expand Up @@ -161,8 +161,8 @@
z = repeat([:e, :f], inner = 4))
cs = Dict([Symbol(name) => EffectsCoding() for name in names(d)])
d.n = 1.:8


## No intercept
mf = ModelFrame(@formula(n ~ 0 + x), d, contrasts=cs)
mm = ModelMatrix(mf)
Expand All @@ -182,8 +182,8 @@
mm = ModelMatrix(mf)
@test all(mm.m .== ifelse.(d.x .== :a, -1, 1))
@test coefnames(mf) == ["x: b"]


## No first-order term for interaction
mf = ModelFrame(@formula(n ~ 1 + x + x&y), d, contrasts=cs)
mm = ModelMatrix(mf)
Expand All @@ -197,7 +197,7 @@
1 0 1]
@test mm.m == ModelMatrix{sparsetype}(mf).m
@test coefnames(mf) == ["(Intercept)", "x: b", "x: a & y: d", "x: b & y: d"]

## When both terms of interaction are non-redundant:
mf = ModelFrame(@formula(n ~ 0 + x&y), d, contrasts=cs)
mm = ModelMatrix(mf)
Expand All @@ -218,7 +218,7 @@
mm = ModelMatrix(mf)
@test mm.m == Matrix(1.0I, 8, 8)
@test mm.m == ModelMatrix{sparsetype}(mf).m

# two two-way interactions, with no lower-order term. both are promoted in
# first (both x and y), but only the old term (x) in the second (because
# dropping x gives z which isn't found elsewhere, but dropping z gives x
Expand All @@ -237,7 +237,7 @@
@test coefnames(mf) == ["x: a & y: c", "x: b & y: c",
"x: a & y: d", "x: b & y: d",
"x: a & z: f", "x: b & z: f"]

# ...and adding a three-way interaction, only the shared term (x) is promoted.
# this is because dropping x gives y&z which isn't present, but dropping y or z
# gives x&z or x&z respectively, which are both present.
Expand All @@ -256,7 +256,7 @@
"x: a & y: d", "x: b & y: d",
"x: a & z: f", "x: b & z: f",
"x: a & y: d & z: f", "x: b & y: d & z: f"]

# two two-way interactions, with common lower-order term. the common term x is
# promoted in both (along with lower-order term), because in every case, when
# x is dropped, the remaining terms (1, y, and z) aren't present elsewhere.
Expand All @@ -274,8 +274,8 @@
@test coefnames(mf) == ["x: a", "x: b",
"x: a & y: d", "x: b & y: d",
"x: a & z: f", "x: b & z: f"]


## FAILS: When both terms are non-redundant and intercept is PRESENT
## (not fully redundant). Ideally, would drop last column. Might make sense
## to warn about this, and suggest recoding x and y into a single variable.
Expand All @@ -286,7 +286,7 @@
1 0 0 0]
@test_broken coefnames(mf) == ["x: a & y: c", "x: b & y: c",
"x: a & y: d", "x: b & y: d"]

## note that R also does not detect this automatically. it's left to glm et al.
## to detect numerically when the model matrix is rank deficient, which is hard
## to do correctly.
Expand Down Expand Up @@ -343,7 +343,7 @@
x = repeat([:a, :b], outer = 4),
y = repeat([:c, :d], inner = 2, outer = 2),
z = repeat([:e, :f], inner = 4))

f = apply_schema(@formula(r ~ 1 + w*x*y*z), schema(d))
modelmatrix(f, d)
@test reduce(vcat, last.(modelcols.(Ref(f), Tables.rowtable(d)))') == modelmatrix(f,d)
Expand All @@ -355,7 +355,7 @@
x = repeat([:a, :b], outer = 4),
y = repeat([:c, :d], inner = 2, outer = 2),
z = repeat([:e, :f], inner = 4))

f = @formula(r ~ 1 + w*x*y*z)

mm1 = modelmatrix(f, d)
Expand All @@ -375,19 +375,19 @@
C=repeat(['L','H'], inner=4))

contrasts = Dict(:A=>HelmertCoding(), :B=>HelmertCoding(), :C=>HelmertCoding())



mf = ModelFrame(@formula(Y ~ 1 + A*B*C), tbl)
mf_helm = ModelFrame(@formula(Y ~ 1 + A*B*C), tbl, contrasts = contrasts)

@test size(modelmatrix(mf)) == size(modelmatrix(mf_helm))

mf_helm2 = setcontrasts!(ModelFrame(@formula(Y ~ 1 + A*B*C), tbl), contrasts)

@test size(modelmatrix(mf)) == size(modelmatrix(mf_helm2))
@test modelmatrix(mf_helm) == modelmatrix(mf_helm2)

end
end

Expand All @@ -402,5 +402,15 @@
f = apply_schema(@formula(0 ~ a&b&c), schema(t))
@test vec(modelcols(f.rhs, t)) == modelcols.(Ref(f.rhs), Tables.rowtable(t))
end


@testset "#112. coefnames should return same type for all rhs: $(f)" for f in [
@formula(y ~ 1),
@formula(y ~ x1 + 0),
@formula(y ~ x1),
@formula(y ~ x1 + x2),
]
df = (y = [1.0, 1.0], x1 = [1, 2], x2 = ["A", "B"])
_f = apply_schema(f, schema(f, df))
@test coefnames(_f.rhs) isa Vector{String}
end
end
6 changes: 2 additions & 4 deletions test/temporal_terms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using DataStructures
resp, pred = modelcols(f, df)

@test isequal(pred[:, 1], [missing; 1.0:9])
@test coefnames(f)[2] == "x_lag1"
@test coefnames(f)[2] == ["x_lag1"]
end

@testset "Row Table" begin
Expand All @@ -53,7 +53,7 @@ using DataStructures
resp, pred = modelcols(neg_f, df);

@test isequal(pred[:, 1], [3.0:10; missing; missing])
@test coefnames(neg_f)[2] == "x_lag-2"
@test coefnames(neg_f)[2] == ["x_lag-2"]
end

@testset "Categorical Term use" begin
Expand Down Expand Up @@ -184,7 +184,5 @@ using DataStructures
@test coefnames(t1) == coefnames(t2) == coefnames(t3)
end
end


end
end
Loading