Commit 5cf78c3

Schema application to solitary term (#208)
* add broken cases to tests
* reduce(+,...) instead of unconditionally sum
* actually do some tests for singleton tuple term
* patch version bump
* add unary +operator for AbstractTerm
* make codecov happy, tuple everything
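
The two operator-related changes named above can be illustrated with a minimal, self-contained sketch. It uses toy stand-ins (`ToyTerm` and locally defined `AbstractTerm` and `TupleTerm`, all hypothetical here) rather than the package's real types, and only mirrors the `+` methods visible in this commit's diff. It is meant to show the behaviour the new "lonely term in a tuple" tests rely on, not the package's actual implementation.

```julia
# Toy stand-ins for illustration only; these are NOT the package's real definitions.
abstract type AbstractTerm end

struct ToyTerm <: AbstractTerm
    sym::Symbol
end

const TupleTerm = NTuple{N,AbstractTerm} where N

# `+` methods mirroring the ones shown in the src/terms.jl diff.
Base.:+(a::AbstractTerm) = a                                   # unary + added by this commit
Base.:+(a::AbstractTerm, b::AbstractTerm) = a == b ? a : (a, b)
Base.:+(as::TupleTerm, b::AbstractTerm) = b in as ? as : (as..., b)
Base.:+(a::AbstractTerm, bs::TupleTerm) = a in bs ? bs : (a, bs...)

a, b, c = ToyTerm(:a), ToyTerm(:b), ToyTerm(:c)

single = reduce(+, (a,))       # a one-element tuple reduces to its lone term
triple = reduce(+, (a, b, c))  # several terms are combined pairwise with `+`
plus_a = +a                    # unary + returns the term unchanged

println(single == a)
println(triple == (a, b, c))
println(plus_a == a)
```

Under these assumptions, reducing a singleton tuple yields the bare term, which is the equality the new tests in test/schema.jl check for `apply_schema`.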
1 parent 552e3ba commit 5cf78c3

5 files changed: +39 -29 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "StatsModels"
 uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
-version = "0.6.18"
+version = "0.6.19"

 [deps]
 DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"

src/schema.jl

Lines changed: 7 additions & 7 deletions
@@ -60,7 +60,7 @@ Base.haskey(schema::Schema, key) = haskey(schema.schema, key)
 Compute all the invariants necessary to fit a model with `terms`. A schema is a dict that
 maps `Term`s to their concrete instantiations (either `CategoricalTerm`s or
 `ContinuousTerm`s. "Hints" may optionally be supplied in the form of a `Dict` mapping term
-names (as `Symbol`s) to term or contrast types. If a hint is not provided for a variable,
+names (as `Symbol`s) to term or contrast types. If a hint is not provided for a variable,
 the appropriate term type will be guessed based on the data type from the data column: any
 numeric data is assumed to be continuous, and any non-numeric data is assumed to be
 categorical.
@@ -93,7 +93,7 @@ StatsModels.Schema with 1 entry:
 y => y
 ```

-Note that concrete `ContinuousTerm` and `CategoricalTerm` and un-typed `Term`s print the
+Note that concrete `ContinuousTerm` and `CategoricalTerm` and un-typed `Term`s print the
 same in a container, but when printed alone are different:

 ```jldoctest 1
@@ -203,9 +203,9 @@ end
 Return a new term that is the result of applying `schema` to term `t` with
 destination model (type) `Mod`. If `Mod` is omitted, `Nothing` will be used.

-When `t` is a `ContinuousTerm` or `CategoricalTerm` already, the term will be returned
-unchanged _unless_ a matching term is found in the schema. This allows
-selective re-setting of a schema to change the contrast coding or levels of a
+When `t` is a `ContinuousTerm` or `CategoricalTerm` already, the term will be returned
+unchanged _unless_ a matching term is found in the schema. This allows
+selective re-setting of a schema to change the contrast coding or levels of a
 categorical term, or to change a continuous term to categorical or vice versa.

 When defining behavior for custom term types, it's best to dispatch on
@@ -214,7 +214,7 @@ in _most_ cases, but cause method ambiguity in some.
 """
 apply_schema(t, schema) = apply_schema(t, schema, Nothing)
 apply_schema(t, schema, Mod::Type) = t
-apply_schema(terms::TupleTerm, schema, Mod::Type) = sum(apply_schema.(terms, Ref(schema), Mod))
+apply_schema(terms::TupleTerm, schema, Mod::Type) = reduce(+, apply_schema.(terms, Ref(schema), Mod))

 apply_schema(t::Term, schema::Schema, Mod::Type) = schema[t]
 apply_schema(ft::FormulaTerm, schema::Schema, Mod::Type) =
@@ -284,7 +284,7 @@ function apply_schema(t::FormulaTerm, schema::Schema, Mod::Type{<:StatisticalMod
 end

 # strategy is: apply schema, then "repair" if necessary (promote to full rank
-# contrasts).
+# contrasts).
 #
 # to know whether to repair, need to know context a term appears in. main
 # effects occur in "own" context.

src/terms.jl

Lines changed: 17 additions & 16 deletions
@@ -42,7 +42,7 @@ width(::ConstantTerm) = 1
 FormulaTerm{L,R} <: AbstractTerm

 Represents an entire formula, with a left- and right-hand side. These can be of
-any type (captured by the type parameters).
+any type (captured by the type parameters).

 # Fields

@@ -64,7 +64,7 @@ data table.

 The `FunctionTerm` _also_ captures the arguments of the original call and parses
 them _as if_ they were part of a special DSL call, applying the rules to expand
-`*`, distribute `&` over `+`, and wrap symbols in `Term`s.
+`*`, distribute `&` over `+`, and wrap symbols in `Term`s.

 By storing the original function as a type parameter _and_ pessimistically
 parsing the arguments as if they're part of a special DSL call, this allows
@@ -78,7 +78,7 @@ on `apply_schema(f::FunctionTerm{typeof(special_syntax)}, schema,
 * `forig::Forig`: the original function (e.g., `log`)
 * `fanon::Fanon`: the generated anonymous function (e.g., `(a, b) -> log(1+a+b)`)
 * `exorig::Expr`: the original expression passed to `@formula`
-* `args_parsed::Vector`: the arguments of the call passed to `@formula`, each
+* `args_parsed::Vector`: the arguments of the call passed to `@formula`, each
 parsed _as if_ the call was a "special" DSL call.

 # Type parameters
@@ -113,7 +113,7 @@ julia> modelcols(f.rhs, (a=3, b=4))
 julia> modelcols(f.rhs, (a=[3, 4], b=[4, 5]))
 2-element Array{Float64,1}:
 2.0794415416798357
-2.302585092994046
+2.302585092994046
 ```
 """
 struct FunctionTerm{Forig,Fanon,Names} <: AbstractTerm
@@ -132,7 +132,7 @@ Base.:(==)(a::FunctionTerm, b::FunctionTerm) = a.forig == b.forig && a.exorig ==
 """
 InteractionTerm{Ts} <: AbstractTerm

-Represents an _interaction_ between two or more individual terms.
+Represents an _interaction_ between two or more individual terms.

 Generated by combining multiple `AbstractTerm`s with `&` (which is what calls to
 `&` in a `@formula` lower to)
@@ -223,8 +223,8 @@ Represents a categorical term, with a name and [`ContrastsMatrix`](@ref)
 # Fields

 * `sym::Symbol`: The name of the variable
-* `contrasts::ContrastsMatrix`: A contrasts matrix that captures the unique
-values this variable takes on and how they are mapped onto numerical
+* `contrasts::ContrastsMatrix`: A contrasts matrix that captures the unique
+values this variable takes on and how they are mapped onto numerical
 predictors.
 """
 struct CategoricalTerm{C,T,N} <: AbstractTerm
@@ -242,9 +242,9 @@ CategoricalTerm(sym::Symbol, contrasts::ContrastsMatrix{C,T}) where {C,T} =

 A collection of terms that should be combined to produce a single numeric matrix.

-A matrix term is created by [`apply_schema`](@ref) from a tuple of terms using
+A matrix term is created by [`apply_schema`](@ref) from a tuple of terms using
 [`collect_matrix_terms`](@ref), which pulls out all the terms that are matrix
-terms as determined by the trait function [`is_matrix_term`](@ref), which is
+terms as determined by the trait function [`is_matrix_term`](@ref), which is
 true by default for all `AbstractTerm`s.
 """
 struct MatrixTerm{Ts<:TupleTerm} <: AbstractTerm
@@ -311,7 +311,7 @@ is_matrix_term(::Type{<:AbstractTerm}) = true


 """
-capture_call(f_orig::Function, f_anon::Function, argnames::NTuple{N,Symbol},
+capture_call(f_orig::Function, f_anon::Function, argnames::NTuple{N,Symbol},
 ex_orig::Expr, args_parsed::Vector{AbstractTerm})

 When the [`@formula`](@ref) macro encounters a call to a function that's not
@@ -404,6 +404,7 @@ Base.:&(terms::AbstractTerm...) = InteractionTerm(terms)
 Base.:&(term::AbstractTerm) = term
 Base.:&(it::InteractionTerm, terms::AbstractTerm...) = InteractionTerm((it.terms..., terms...))

+Base.:+(a::AbstractTerm) = a
 Base.:+(a::AbstractTerm, b::AbstractTerm) = a==b ? a : (a, b)
 Base.:+(as::TupleTerm, b::AbstractTerm) = b in as ? as : (as..., b)
 Base.:+(a::AbstractTerm, bs::TupleTerm) = a in bs ? bs : (a, bs...)
@@ -437,7 +438,7 @@ end
 """
 modelcols(ts::NTuple{N, AbstractTerm}, data) where N

-When a tuple of terms is provided, `modelcols` broadcasts over the individual
+When a tuple of terms is provided, `modelcols` broadcasts over the individual
 terms. To create a single matrix, wrap the tuple in a [`MatrixTerm`](@ref).

 # Example
@@ -448,7 +449,7 @@ julia> using StableRNGs; rng = StableRNG(1);
 julia> d = (a = [1:9;], b = rand(rng, 9), c = repeat(["d","e","f"], 3));

 julia> ts = apply_schema(term.((:a, :b, :c)), schema(d))
-a(continuous)
+a(continuous)
 b(continuous)
 c(DummyCoding:3→2)

@@ -498,8 +499,8 @@ modelcols(t::CategoricalTerm, d::NamedTuple) = t.contrasts[d[t.sym], :]
 """
 reshape_last_to_i(i::Int, a)

-Reshape `a` so that its last dimension moves to dimension `i` (+1 if `a` is an
-`AbstractMatrix`).
+Reshape `a` so that its last dimension moves to dimension `i` (+1 if `a` is an
+`AbstractMatrix`).
 """
 reshape_last_to_i(i, a) = a
 reshape_last_to_i(i, a::AbstractVector) = reshape(a, ones(Int, i-1)..., :)
@@ -557,7 +558,7 @@ Return the name(s) of column(s) generated by a term. Return value is either a
 StatsBase.coefnames(t::FormulaTerm) = (coefnames(t.lhs), coefnames(t.rhs))
 StatsBase.coefnames(::InterceptTerm{H}) where {H} = H ? "(Intercept)" : []
 StatsBase.coefnames(t::ContinuousTerm) = string(t.sym)
-StatsBase.coefnames(t::CategoricalTerm) =
+StatsBase.coefnames(t::CategoricalTerm) =
 ["$(t.sym): $name" for name in t.contrasts.termnames]
 StatsBase.coefnames(t::FunctionTerm) = string(t.exorig)
 StatsBase.coefnames(ts::TupleTerm) = reduce(vcat, coefnames.(ts))
@@ -580,7 +581,7 @@ omitsintercept(t::TermOrTerms) =

 hasresponse(t) = false
 hasresponse(t::FormulaTerm) =
-t.lhs !== nothing &&
+t.lhs !== nothing &&
 t.lhs !== ConstantTerm(0) &&
 t.lhs !== InterceptTerm{false}()

test/schema.jl

Lines changed: 9 additions & 3 deletions
@@ -9,6 +9,12 @@
 @test f == apply_schema(f, schema(f, df))
 end

+@testset "lonely term in a tuple" begin
+d = (a = [1,1],)
+@test apply_schema(ConstantTerm(1), schema(d)) == apply_schema((ConstantTerm(1),), schema(d))
+@test apply_schema(Term(:a), schema(d)) == apply_schema((Term(:a),), schema(d))
+end
+
 @testset "hints" begin
 f = @formula(y ~ 1 + a)
 d = (y = rand(10), a = repeat([1,2], outer=2))
@@ -20,7 +26,7 @@
 @test sch1[term(:a)] isa CategoricalTerm{DummyCoding}
 f1 = apply_schema(f, sch1)
 @test f1.rhs.terms[end] == sch1[term(:a)]
-
+
 sch2 = schema(f, d, Dict(:a => DummyCoding()))
 @test sch2[term(:a)] isa CategoricalTerm{DummyCoding}
 f2 = apply_schema(f, sch2)
@@ -39,7 +45,7 @@
 using StatsModels: has_schema

 d = (y = rand(10), a = rand(10), b = repeat([:a, :b], 5))
-
+
 f = @formula(y ~ a*b)
 @test !has_schema(f)
 @test !has_schema(f.rhs)
@@ -63,5 +69,5 @@
 @test has_schema(sch[a] & sch[b])

 end
-
+
 end

test/terms.jl

Lines changed: 5 additions & 2 deletions
@@ -51,7 +51,7 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
 @test mimestring(t2full) == "aaa(StatsModels.FullDummyCoding:2→2)"
 @test string(t2full) == "aaa"
 end
-
+
 @testset "term operators" begin
 a = term(:a)
 b = term(:b)
@@ -84,6 +84,9 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
 @test b+ab == ab
 @test ab+ab == ab
 @test ab+bc == abc
+@test sum((a,b,c)) == abc
+@test sum((a,)) == a
+@test +a == a
 end

 @testset "uniqueness of FunctionTerms" begin
@@ -169,5 +172,5 @@ StatsModels.apply_schema(mt::MultiTerm, sch::StatsModels.Schema, Mod::Type) =
 end

 end
-
+
 end
