Skip to content

Commit a7d1534

Browse files
authored
Add lrtest (#162)
Heavily Inspired by `ftest` in GLM. Also introduce an `isnested` function which can be overloaded by modeling packages to protect users from comparing non-nested models: a warning is printed if no method has been defined by a model type; if it's defined, an error is thrown for non-nested models.
1 parent 0c2361a commit a7d1534

File tree

4 files changed

+249
-1
lines changed

4 files changed

+249
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ version = "0.6.11"
55
[deps]
66
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
911
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
1012
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1113
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -16,6 +18,7 @@ CategoricalArrays = "0.7"
1618
DataAPI = "1.1"
1719
DataFrames = "0.20, 0.21"
1820
DataStructures = "0.17.0"
21+
Distributions = "0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23"
1922
ShiftedArrays = "1.0.0"
2023
StatsBase = "0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33"
2124
Tables = "0.2, 1"

src/StatsModels.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using ShiftedArrays
66
using ShiftedArrays: lag, lead
77
using DataStructures
88
using DataAPI: levels
9+
using Printf: @sprintf
10+
using Distributions: Chisq, ccdf
911

1012
using SparseArrays
1113
using LinearAlgebra
@@ -56,7 +58,9 @@ export
5658
width,
5759
modelcols,
5860
modelmatrix,
59-
response
61+
response,
62+
63+
lrtest
6064

6165
include("traits.jl")
6266
include("contrasts.jl")
@@ -66,5 +70,6 @@ include("temporal_terms.jl")
6670
include("formula.jl")
6771
include("modelframe.jl")
6872
include("statsmodel.jl")
73+
include("lrtest.jl")
6974

7075
end # module StatsModels

src/lrtest.jl

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
struct LRTestResult{N}
2+
nobs::Int
3+
deviance::NTuple{N, Float64}
4+
dof::NTuple{N, Int}
5+
pval::NTuple{N, Float64}
6+
end
7+
8+
_diff(t::NTuple{N}) where {N} = ntuple(i->t[i+1]-t[i], N-1)
9+
10+
"""
11+
isnested(m1::StatisticalModel, m2::StatisticalModel; atol::Real=0.0)
12+
13+
Indicate whether model `m1` is nested in model `m2`, i.e. whether
14+
`m1` can be obtained by constraining some parameters in `m2`.
15+
Both models must have been fitted on the same data.
16+
"""
17+
function isnested end
18+
19+
"""
20+
lrtest(mods::StatisticalModel...; atol::Real=0.0)
21+
22+
For each sequential pair of statistical models in `mods...`, perform a likelihood ratio
23+
test to determine if the first one fits significantly better than the next.
24+
25+
A table is returned containing degrees of freedom (DOF),
26+
difference in DOF from the preceding model, deviance, difference in deviance
27+
from the preceding model, and likelihood ratio and p-value for the comparison
28+
between the two models.
29+
30+
Optional keyword argument `atol` controls the numerical tolerance when testing whether
31+
the models are nested.
32+
33+
# Examples
34+
35+
Suppose we want to compare the effects of two or more treatments on some result.
36+
Our null hypothesis is that `Result ~ 1` fits the data as well as
37+
`Result ~ 1 + Treatment`.
38+
39+
```jldoctest
40+
julia> using DataFrames, GLM
41+
42+
julia> dat = DataFrame(Result=[1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1],
43+
Treatment=[1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2],
44+
Other=categorical([1, 1, 2, 1, 2, 1, 3, 1, 1, 2, 2, 1]));
45+
46+
julia> nullmodel = glm(@formula(Result ~ 1), dat, Binomial(), LogitLink());
47+
48+
julia> model = glm(@formula(Result ~ 1 + Treatment), dat, Binomial(), LogitLink());
49+
50+
julia> bigmodel = glm(@formula(Result ~ 1 + Treatment + Other), dat, Binomial(), LogitLink());
51+
52+
julia> lrtest(nullmodel, model, bigmodel)
53+
Likelihood-ratio test: 3 models fitted on 12 observations
54+
──────────────────────────────────────────────
55+
DOF ΔDOF Deviance ΔDeviance p(>Chisq)
56+
──────────────────────────────────────────────
57+
[1] 1 16.3006
58+
[2] 2 1 15.9559 -0.3447 0.5571
59+
[3] 4 2 14.0571 -1.8988 0.3870
60+
──────────────────────────────────────────────
61+
62+
julia> lrtest(bigmodel, model, nullmodel)
63+
Likelihood-ratio test: 3 models fitted on 12 observations
64+
──────────────────────────────────────────────
65+
DOF ΔDOF Deviance ΔDeviance p(>Chisq)
66+
──────────────────────────────────────────────
67+
[1] 4 14.0571
68+
[2] 2 -2 15.9559 1.8988 0.3870
69+
[3] 1 -1 16.3006 0.3447 0.5571
70+
──────────────────────────────────────────────
71+
```
72+
"""
73+
function lrtest(mods::StatisticalModel...; atol::Real=0.0)
74+
if length(mods) < 2
75+
throw(ArgumentError("At least two models are needed to perform LR test"))
76+
end
77+
T = typeof(mods[1])
78+
df = dof.(mods)
79+
forward = df[1] <= df[2]
80+
if !all(m -> typeof(m) == T, mods)
81+
throw(ArgumentError("LR test is only valid for models of the same type"))
82+
end
83+
if !all(==(nobs(mods[1])), nobs.(mods))
84+
throw(ArgumentError("LR test is only valid for models fitted on the same data, " *
85+
"but number of observations differ"))
86+
end
87+
checknested = hasmethod(isnested, Tuple{T, T})
88+
if forward
89+
for i in 2:length(mods)
90+
if df[i-1] >= df[i] ||
91+
(checknested && !isnested(mods[i-1], mods[i], atol=atol))
92+
throw(ArgumentError("LR test is only valid for nested models"))
93+
end
94+
end
95+
else
96+
for i in 2:length(mods)
97+
if df[i] >= df[i-1] ||
98+
(checknested && !isnested(mods[i], mods[i-1], atol=atol))
99+
throw(ArgumentError("LR test is only valid for nested models"))
100+
end
101+
end
102+
end
103+
if !checknested
104+
@warn "Could not check whether models are nested as model type " *
105+
"$(nameof(T)) does not implement isnested: results may not be meaningful"
106+
end
107+
108+
dev = deviance.(mods)
109+
Δdev = _diff(dev)
110+
111+
Δdf = _diff(df)
112+
dfr = Int.(dof_residual.(mods))
113+
114+
if (forward && any(x -> x > 0, Δdev)) || (!forward && any(x -> x < 0, Δdev))
115+
throw(ArgumentError("Residual deviance must be strictly lower " *
116+
"in models with more degrees of freedom"))
117+
end
118+
119+
pval = (NaN, ccdf.(Chisq.(abs.(Δdf)), abs.(Δdev))...)
120+
return LRTestResult(Int(nobs(mods[1])), dev, df, pval)
121+
end
122+
123+
function Base.show(io::IO, lrr::LRTestResult{N}) where N
124+
Δdf = _diff(lrr.dof)
125+
Δdev = _diff(lrr.deviance)
126+
127+
nc = 6
128+
nr = N
129+
outrows = Matrix{String}(undef, nr+1, nc)
130+
131+
outrows[1, :] = ["", "DOF", "ΔDOF", "Deviance", "ΔDeviance", "p(>Chisq)"]
132+
133+
outrows[2, :] = ["[1]", @sprintf("%.0d", lrr.dof[1]), " ",
134+
@sprintf("%.4f", lrr.deviance[1]), " ", " "]
135+
136+
for i in 2:nr
137+
outrows[i+1, :] = ["[$i]", @sprintf("%.0d", lrr.dof[i]),
138+
@sprintf("%.0d", Δdf[i-1]),
139+
@sprintf("%.4f", lrr.deviance[i]), @sprintf("%.4f", Δdev[i-1]),
140+
string(StatsBase.PValue(lrr.pval[i])) ]
141+
end
142+
colwidths = length.(outrows)
143+
max_colwidths = [maximum(view(colwidths, :, i)) for i in 1:nc]
144+
totwidth = sum(max_colwidths) + 2*5
145+
146+
println(io, "Likelihood-ratio test: $N models fitted on $(lrr.nobs) observations")
147+
println(io, ''^totwidth)
148+
149+
for r in 1:nr+1
150+
for c in 1:nc
151+
cur_cell = outrows[r, c]
152+
cur_cell_len = length(cur_cell)
153+
154+
padding = " "^(max_colwidths[c]-cur_cell_len)
155+
if c > 1
156+
padding = " "*padding
157+
end
158+
159+
print(io, padding)
160+
print(io, cur_cell)
161+
end
162+
print(io, "\n")
163+
r == 1 && println(io, ''^totwidth)
164+
end
165+
print(io, ''^totwidth)
166+
end

test/statsmodel.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ function StatsBase.predict(mod::DummyMod, newX::Matrix;
4040
throw(ArgumentError("value not allowed for interval"))
4141
end
4242
end
43+
StatsBase.dof(mod::DummyMod) = length(mod.beta)
44+
StatsBase.dof_residual(mod::DummyMod) = length(mod.y) - length(mod.beta)
45+
StatsBase.nobs(mod::DummyMod) = length(mod.y)
46+
StatsBase.deviance(mod::DummyMod) = sum((response(mod) .- predict(mod)).^2)
47+
# Incorrect but simple definition
48+
StatsModels.isnested(mod1::DummyMod, mod2::DummyMod; atol::Real=0.0) =
49+
dof(mod1) <= dof(mod2)
4350

4451
# A dummy RegressionModel type that does not support intercept
4552
struct DummyModNoIntercept <: RegressionModel
@@ -83,6 +90,11 @@ function StatsBase.predict(mod::DummyModNoIntercept, newX::Matrix;
8390
throw(ArgumentError("value not allowed for interval"))
8491
end
8592
end
93+
StatsBase.dof(mod::DummyModNoIntercept) = length(mod.beta)
94+
StatsBase.dof_residual(mod::DummyModNoIntercept) = length(mod.y) - length(mod.beta)
95+
StatsBase.nobs(mod::DummyModNoIntercept) = length(mod.y)
96+
StatsBase.deviance(mod::DummyModNoIntercept) = sum((response(mod) .- predict(mod)).^2)
97+
# isnested not implemented to test fallback
8698

8799
## Another dummy model type to test fall-through show method
88100
struct DummyModTwo <: RegressionModel
@@ -212,3 +224,65 @@ Base.show(io::IO, m::DummyModTwo) = println(io, m.msg)
212224
show(io, m2)
213225

214226
end
227+
228+
@testset "lrtest" begin
229+
230+
y = collect(1:4)
231+
x1 = 2:5
232+
x2 = [1, 5, 3, 1]
233+
234+
m0 = DummyMod([1], ones(4, 1), y)
235+
m1 = DummyMod([1, 0.3], [ones(4, 1) x1], y)
236+
m2 = DummyMod([1, 0.25, 0.05, 0.04], [ones(4, 1) x1 x2 x1.*x2], y)
237+
238+
@test_throws ArgumentError lrtest(m0)
239+
@test_throws ArgumentError lrtest(m0, m0)
240+
@test_throws ArgumentError lrtest(m0, m2, m1)
241+
@test_throws ArgumentError lrtest(m1, m0, m2)
242+
@test_throws ArgumentError lrtest(m2, m0, m1)
243+
244+
m1b = DummyMod([1, 0.3], [ones(3, 1) x1[2:end]], y[2:end])
245+
@test_throws ArgumentError lrtest(m0, m1b)
246+
247+
lr1 = lrtest(m0, m1)
248+
@test isnan(lr1.pval[1])
249+
@test lr1.pval[2] 0.0010484433450981662
250+
@test sprint(show, lr1) == """
251+
Likelihood-ratio test: 2 models fitted on 4 observations
252+
──────────────────────────────────────────────
253+
DOF ΔDOF Deviance ΔDeviance p(>Chisq)
254+
──────────────────────────────────────────────
255+
[1] 1 14.0000
256+
[2] 2 1 3.2600 -10.7400 0.0010
257+
──────────────────────────────────────────────"""
258+
259+
m0 = DummyModNoIntercept(Float64[], ones(4, 0), y)
260+
m1 = DummyModNoIntercept([0.3], reshape(x1, :, 1), y)
261+
m2 = DummyModNoIntercept([0.25, 0.05, 0.04], [x1 x2 x1.*x2], y)
262+
263+
@test_throws ArgumentError lrtest(m0)
264+
@test_throws ArgumentError lrtest(m0, m0)
265+
@test_throws ArgumentError lrtest(m0, m2, m1)
266+
@test_throws ArgumentError lrtest(m1, m0, m2)
267+
@test_throws ArgumentError lrtest(m2, m0, m1)
268+
269+
m1b = DummyModNoIntercept([0.3], reshape(x1[2:end], :, 1), y[2:end])
270+
@test_throws ArgumentError lrtest(m0, m1b)
271+
272+
# Incorrect, but check that it doesn't throw an error
273+
lr2 = @test_logs((:warn, "Could not check whether models are nested " *
274+
"as model type DummyModNoIntercept does not implement isnested: " *
275+
"results may not be meaningful"),
276+
lrtest(m0, m1))
277+
@test isnan(lr2.pval[1])
278+
@test lr2.pval[2] 1.2147224767092312e-5
279+
@test sprint(show, lr2) == """
280+
Likelihood-ratio test: 2 models fitted on 4 observations
281+
──────────────────────────────────────────────
282+
DOF ΔDOF Deviance ΔDeviance p(>Chisq)
283+
──────────────────────────────────────────────
284+
[1] 0 30.0000
285+
[2] 1 1 10.8600 -19.1400 <1e-4
286+
──────────────────────────────────────────────"""
287+
288+
end

0 commit comments

Comments
 (0)