Skip to content

Commit d8c882c

Browse files
authored
Add a formula(model) accessor, loglikelihood(model, :), and update tests (#226)
* add formula accessor method * bump patch * bump dataframes compat to 1 * fix tests * trying to fix version problems on 1.0 * how about this one * add tests for loglikelihood method
1 parent 4af81b9 commit d8c882c

File tree

8 files changed

+28
-7
lines changed

8 files changed

+28
-7
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StatsModels"
22
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3-
version = "0.6.22"
3+
version = "0.6.23"
44

55
[deps]
66
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -14,11 +14,11 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1414
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1515

1616
[compat]
17-
CategoricalArrays = "0.8"
17+
CategoricalArrays = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10"
1818
DataAPI = "1.1"
19-
DataFrames = "0.21"
20-
DataStructures = "0.17.0, 0.18"
21-
ShiftedArrays = "1.0.0"
19+
DataFrames = "1"
20+
DataStructures = "0.17, 0.18"
21+
ShiftedArrays = "1"
2222
StatsBase = "0.33.5"
2323
StatsFuns = "0.9"
2424
Tables = "0.2, 1"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ fit
6464
response
6565
modelmatrix
6666
lrtest
67+
formula
6768
```
6869

6970
### Traits

src/StatsModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ export
3535
coefnames,
3636
dropterm,
3737
setcontrasts!,
38+
formula,
3839

3940
AbstractTerm,
4041
ConstantTerm,

src/statsmodel.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,16 @@ for (modeltype, dfmodeltype) in ((:StatisticalModel, TableStatisticalModel),
9696
end
9797
end
9898

99+
"""
100+
formula(model)
101+
102+
Retrieve formula from a fitted or specified model
103+
"""
104+
function formula end
105+
106+
formula(m::TableStatisticalModel) = m.mf.f
107+
formula(m::TableRegressionModel) = m.mf.f
108+
99109
@doc """
100110
fit(Mod::Type{<:StatisticalModel}, f::FormulaTerm, data, args...;
101111
contrasts::Dict{Symbol}, kwargs...)
@@ -132,6 +142,7 @@ StatsBase.r2(mm::TableRegressionModel) = r2(mm.model)
132142
StatsBase.adjr2(mm::TableRegressionModel) = adjr2(mm.model)
133143
StatsBase.r2(mm::TableRegressionModel, variant::Symbol) = r2(mm.model, variant)
134144
StatsBase.adjr2(mm::TableRegressionModel, variant::Symbol) = adjr2(mm.model, variant)
145+
StatsBase.loglikelihood(mm::TableModels, c::Colon) = loglikelihood(mm.model, c)
135146

136147
function _return_predictions(T, yp::AbstractVector, nonmissings, len)
137148
out = Vector{Union{eltype(yp),Missing}}(missing, len)

test/contrasts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@
227227
f_sdiff = apply_schema(f, schema(d2, Dict(:x => sdiff_hyp)))
228228
f_effects = apply_schema(f, schema(d2, Dict(:x => effects_hyp)))
229229

230-
y_means = by(d2, :x, :y => mean).y_mean
230+
y_means = combine(groupby(d2, :x), :y => mean).y_mean
231231

232232
y, X_sdiff = modelcols(f_sdiff, d2)
233233
@test X_sdiff \ y [mean(y_means); diff(y_means)]

test/modelmatrix.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@
159159
d = DataFrame(x = repeat([:a, :b], outer = 4),
160160
y = repeat([:c, :d], inner = 2, outer = 2),
161161
z = repeat([:e, :f], inner = 4))
162-
categorical!(d)
163162
cs = Dict([Symbol(name) => EffectsCoding() for name in names(d)])
164163
d.n = 1.:8
165164

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using SparseArrays
44

55
using StatsModels
66
using DataFrames
7+
using CategoricalArrays
78
using StatsBase
89

910
using StatsModels: ContrastsMatrix

test/statsmodel.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ StatsBase.deviance(mod::DummyMod) = sum((response(mod) .- predict(mod)).^2)
4747
# Incorrect but simple definition
4848
StatsModels.isnested(mod1::DummyMod, mod2::DummyMod; atol::Real=0.0) =
4949
dof(mod1) <= dof(mod2)
50+
StatsBase.loglikelihood(mod::DummyMod) = -sum((response(mod) .- predict(mod)).^2)
51+
StatsBase.loglikelihood(mod::DummyMod, ::Colon) = -(response(mod) .- predict(mod)).^2
5052

5153
# A dummy RegressionModel type that does not support intercept
5254
struct DummyModNoIntercept <: RegressionModel
@@ -122,6 +124,12 @@ Base.show(io::IO, m::DummyModTwo) = println(io, m.msg)
122124
## coefnames delegated to model frame by default
123125
@test coefnames(m) == coefnames(ModelFrame(f, d)) == ["(Intercept)", "x1", "x2", "x1 & x2"]
124126

127+
@test formula(m) == m.mf.f
128+
129+
## loglikelihood methods from StatsBase
130+
@test length(loglikelihood(m, :)) == nrow(d)
131+
@test sum(loglikelihood(m, :)) == loglikelihood(m) == -deviance(m)
132+
125133
## test prediction method
126134
## vanilla
127135
@test predict(m) == [ ones(size(d,1)) Array(d.x1) Array(d.x2) Array(d.x1).*Array(d.x2) ] * collect(1:4)

0 commit comments

Comments
 (0)