Skip to content

Commit 5fb6f65

Browse files
authored
correct coeftable for GLMMs, add tests (#308)
* correct coeftable for GLMMs, add tests * Move methods for abstract MixedModel struct to separate file.
1 parent 872bba1 commit 5fb6f65

File tree

5 files changed

+116
-73
lines changed

5 files changed

+116
-73
lines changed

src/MixedModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ include("randomeffectsterm.jl")
144144
include("linearmixedmodel.jl")
145145
include("gausshermite.jl")
146146
include("generalizedlinearmixedmodel.jl")
147+
include("mixedmodel.jl")
147148
include("likelihoodratiotest.jl")
148149
include("linalg/statschol.jl")
149150
include("linalg/cholUnblocked.jl")

src/generalizedlinearmixedmodel.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,19 @@ struct GeneralizedLinearMixedModel{T<:AbstractFloat} <: MixedModel{T}
5353
mult::Vector{T}
5454
end
5555

56-
StatsBase.coefnames(m::GeneralizedLinearMixedModel) = coefnames(m.LMM)
57-
58-
StatsBase.coeftable(m::GeneralizedLinearMixedModel) = coeftable(m.LMM)
59-
56+
function StatsBase.coeftable(m::GeneralizedLinearMixedModel)
57+
co = fixef(m)
58+
se = stderror(m)
59+
z = co ./ se
60+
pvalue = ccdf.(Chisq(1), abs2.(z))
61+
CoefTable(
62+
hcat(co, se, z, pvalue),
63+
["Estimate", "Std.Error", "z value", "P(>|z|)"],
64+
coefnames(m),
65+
4, # pvalcol
66+
3, # teststatcol
67+
)
68+
end
6069

6170
"""
6271
deviance(m::GeneralizedLinearMixedModel{T}, nAGQ=1)::T where {T}
@@ -73,7 +82,7 @@ function StatsBase.deviance(m::GeneralizedLinearMixedModel{T}, nAGQ = 1) where {
7382
u = vec(first(m.u))
7483
u₀ = vec(first(m.u₀))
7584
copyto!(u₀, u)
76-
ra = RaggedArray(m.resp.devresid, first(m.LMM.reterms).refs)
85+
ra = RaggedArray(m.resp.devresid, first(m.LMM.allterms).refs)
7786
devc0 = sum!(map!(abs2, m.devc0, u), ra) # the deviance components at z = 0
7887
sd = map!(inv, m.sd, m.LMM.L[Block(1, 1)].diag)
7988
mult = fill!(m.mult, 0)
@@ -112,16 +121,15 @@ function deviance!(m::GeneralizedLinearMixedModel, nAGQ = 1)
112121
deviance(m, nAGQ)
113122
end
114123

115-
function GLM.dispersion(m::GeneralizedLinearMixedModel, sqr::Bool = false)
124+
function GLM.dispersion(m::GeneralizedLinearMixedModel{T}, sqr::Bool = false) where {T}
116125
# adapted from GLM.dispersion(::AbstractGLM, ::Bool)
117126
# TODO: PR for a GLM.dispersion(resp::GLM.GlmResp, dof_residual::Int, sqr::Bool)
118127
r = m.resp
119128
if dispersion_parameter(r.d)
120-
wrkwt, wrkresid = r.wrkwt, r.wrkresid
121-
s = sum(i -> wrkwt[i] * abs2(wrkresid[i]), eachindex(wrkwt, wrkresid)) / dof_residual(m)
129+
s = sum(wt * abs2(re) for (wt, re) in zip(r.wrkwt, r.wrkresid)) / dof_residual(m)
122130
sqr ? s : sqrt(s)
123131
else
124-
one(eltype(r.mu))
132+
one(T)
125133
end
126134
end
127135

@@ -391,7 +399,11 @@ function Base.getproperty(m::GeneralizedLinearMixedModel, s::Symbol)
391399
m.β
392400
elseif s (, :sigma)
393401
sdest(m)
394-
elseif s (:A, :L, , :lowerbd, :corr, :vcov, :PCA, :rePCA, :optsum, :X, :reterms, :feterms, :formula, :σs, :σρs)
402+
elseif s == :σs
403+
σs(m)
404+
elseif s == :σρs
405+
σρs(m)
406+
elseif s (:A, :L, , :lowerbd, :corr, :PCA, :rePCA, :optsum, :X, :reterms, :feterms, :formula)
395407
getproperty(m.LMM, s)
396408
elseif s == :y
397409
m.resp.y
@@ -401,18 +413,17 @@ function Base.getproperty(m::GeneralizedLinearMixedModel, s::Symbol)
401413
end
402414

403415
function StatsBase.loglikelihood(m::GeneralizedLinearMixedModel{T}) where {T}
404-
accum = zero(T)
416+
r = m.resp
405417
D = Distribution(m.resp)
406-
if D <: Binomial
407-
for (μ, y, n) in zip(m.resp.mu, m.resp.y, m.wt)
408-
accum += logpdf(D(round(Int, n), μ), round(Int, y * n))
418+
accum = (
419+
if D <: Binomial
420+
sum(logpdf(D(round(Int, n), μ), round(Int, y * n))
421+
for (μ, y, n) in zip(r.mu, r.y, m.wt))
422+
else
423+
sum(logpdf(D(μ), y) for (μ, y) in zip(r.mu, r.y))
409424
end
410-
else
411-
for (μ, y) in zip(m.resp.mu, m.resp.y)
412-
accum += logpdf(D(μ), y)
413-
end
414-
end
415-
accum - (mapreduce(u -> sum(abs2, u), +, m.u) + logdet(m)) / 2
425+
)
426+
accum - (sum(sum(abs2, u) for u in m.u) + logdet(m)) / 2
416427
end
417428

418429
StatsBase.nobs(m::GeneralizedLinearMixedModel) = length(m.η)
@@ -589,16 +600,14 @@ varest(m::GeneralizedLinearMixedModel{T}) where {T} = one(T)
589600

590601
# delegate GLMM method to LMM field
591602
for f in (
592-
:describeblocks,
593603
:feL,
604+
:fetrm,
594605
:(LinearAlgebra.logdet),
595606
:lowerbd,
596607
:PCA,
597608
:rePCA,
609+
:(StatsBase.coefnames),
598610
:(StatsModels.modelmatrix),
599-
:(StatsBase.vcov),
600-
:σs,
601-
:σρs,
602611
)
603612
@eval begin
604613
$f(m::GeneralizedLinearMixedModel) = $f(m.LMM)

src/linearmixedmodel.jl

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,6 @@ function StatsBase.coeftable(m::LinearMixedModel)
187187
)
188188
end
189189

190-
"""
191-
cond(m::MixedModel)
192-
193-
Return a vector of condition numbers of the λ matrices for the random-effects terms
194-
"""
195-
LinearAlgebra.cond(m::MixedModel) = cond.(m.λ)
196-
197190
"""
198191
condVar(m::LinearMixedModel)
199192
@@ -251,8 +244,7 @@ function createAL(allterms::Vector{Union{ReMat{T},FeMat{T}}}) where {T}
251244
A, L
252245
end
253246

254-
255-
StatsBase.deviance(m::MixedModel) = objective(m)
247+
StatsBase.deviance(m::LinearMixedModel) = objective(m)
256248

257249
GLM.dispersion(m::LinearMixedModel, sqr::Bool = false) = sqr ? varest(m) : sdest(m)
258250

@@ -270,14 +262,14 @@ function StatsBase.dof_residual(m::LinearMixedModel)::Int
270262
end
271263

272264
"""
273-
feind(m::MixedModel)
265+
feind(m::LinearMixedModel)
274266
275267
An internal utility to return the index in `m.allterms` of the fixed-effects term.
276268
"""
277-
feind(m::MixedModel) = findfirst(Base.Fix2(isa, FeMat), m.allterms)
269+
feind(m::LinearMixedModel) = findfirst(Base.Fix2(isa, FeMat), m.allterms)
278270

279271
"""
280-
feL(m::MixedModel)
272+
feL(m::LinearMixedModel)
281273
282274
Return the lower Cholesky factor for the fixed-effects parameters, as an `LowerTriangular`
283275
`p × p` matrix.
@@ -292,7 +284,7 @@ end
292284
293285
Return the fixed-effects term from `m.allterms`
294286
"""
295-
fetrm(m) = m.allterms[feind(m)]
287+
fetrm(m::LinearMixedModel) = m.allterms[feind(m)]
296288

297289
"""
298290
fit!(m::LinearMixedModel[; verbose::Bool=false, REML::Bool=false])
@@ -765,16 +757,6 @@ function Base.show(io::IO, m::LinearMixedModel)
765757
show(io, coeftable(m))
766758
end
767759

768-
function σs(m::LinearMixedModel)
769-
σ = sdest(m)
770-
NamedTuple{fnames(m)}(((σs(t, σ) for t in m.reterms)...,))
771-
end
772-
773-
function σρs(m::LinearMixedModel)
774-
σ = sdest(m)
775-
NamedTuple{fnames(m)}(((σρs(t, σ) for t in m.reterms)...,))
776-
end
777-
778760
"""
779761
size(m::LinearMixedModel)
780762
@@ -815,6 +797,8 @@ end
815797
std(m::MixedModel)
816798
817799
Return the estimated standard deviations of the random effects as a `Vector{Vector{T}}`.
800+
801+
FIXME: This uses an old convention of isfinite(sdest(m)). Probably drop in favor of m.σs
818802
"""
819803
function Statistics.std(m::LinearMixedModel)
820804
rl = rowlengths.(m.reterms)
@@ -919,32 +903,6 @@ Returns the estimate of σ², the variance of the conditional distribution of Y
919903
"""
920904
varest(m::LinearMixedModel) = pwrss(m) / ssqdenom(m)
921905

922-
"""
923-
vcov(m::LinearMixedModel)
924-
925-
Returns the variance-covariance matrix of the fixed effects.
926-
If `corr=true`, then correlation of fixed effects is returned instead.
927-
"""
928-
function StatsBase.vcov(m::LinearMixedModel{T}; corr=false) where {T}
929-
Xtrm = fetrm(m)
930-
iperm = invperm(Xtrm.piv)
931-
p = length(iperm)
932-
r = Xtrm.rank
933-
Linv = inv(feL(m))
934-
permvcov = varest(m) * (Linv'Linv)
935-
if p == Xtrm.rank
936-
vv = permvcov[iperm, iperm]
937-
else
938-
covmat = fill(zero(T) / zero(T), (p, p))
939-
for j = 1:r, i = 1:r
940-
covmat[i, j] = permvcov[i, j]
941-
end
942-
vv = covmat[iperm, iperm]
943-
end
944-
945-
corr ? StatsBase.cov2cor!(vv, stderror(m)) : vv
946-
end
947-
948906
"""
949907
zerocorr!(m::LinearMixedModel[, trmnms::Vector{Symbol}])
950908

src/mixedmodel.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
"""
3+
cond(m::MixedModel)
4+
5+
Return a vector of condition numbers of the λ matrices for the random-effects terms
6+
"""
7+
LinearAlgebra.cond(m::MixedModel) = cond.(m.λ)
8+
9+
function σs(m::MixedModel)
10+
σ = dispersion(m)
11+
NamedTuple{fnames(m)}(((σs(t, σ) for t in m.reterms)...,))
12+
end
13+
14+
function σρs(m::MixedModel)
15+
σ = dispersion(m)
16+
NamedTuple{fnames(m)}(((σρs(t, σ) for t in m.reterms)...,))
17+
end
18+
19+
"""
20+
vcov(m::LinearMixedModel)
21+
22+
Returns the variance-covariance matrix of the fixed effects.
23+
If `corr=true`, then correlation of fixed effects is returned instead.
24+
"""
25+
function StatsBase.vcov(m::MixedModel; corr=false)
26+
Xtrm = fetrm(m)
27+
iperm = invperm(Xtrm.piv)
28+
p = length(iperm)
29+
r = Xtrm.rank
30+
Linv = inv(feL(m))
31+
T = eltype(Linv)
32+
permvcov = dispersion(m, true) * (Linv'Linv)
33+
if p == Xtrm.rank
34+
vv = permvcov[iperm, iperm]
35+
else
36+
covmat = fill(zero(T) / zero(T), (p, p))
37+
for j = 1:r, i = 1:r
38+
covmat[i, j] = permvcov[i, j]
39+
end
40+
vv = covmat[iperm, iperm]
41+
end
42+
43+
corr ? StatsBase.cov2cor!(vv, stderror(m)) : vv
44+
end

test/pirls.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,34 @@ end
7777
#@test isapprox(sum(x -> sum(abs2, x), gm4.u), 196.8695297987013, atol=0.1)
7878
#@test isapprox(sum(gm4.resp.devresid), 220.92685781326136, atol=0.1)
7979
end
80+
81+
@testset "goldstein" begin # from a 2020-04-22 msg by Ben Goldstein to R-SIG-Mixed-Models
82+
goldstein =
83+
categorical!(
84+
DataFrame(
85+
group = repeat(1:10, outer=10),
86+
y = [
87+
83, 3, 8, 78, 901, 21, 4, 1, 1, 39,
88+
82, 3, 2, 82, 874, 18, 5, 1, 3, 50,
89+
87, 7, 3, 67, 914, 18, 0, 1, 1, 38,
90+
86, 13, 5, 65, 913, 13, 2, 0, 0, 48,
91+
90, 5, 5, 71, 886, 19, 3, 0, 2, 32,
92+
96, 1, 1, 87, 860, 21, 3, 0, 1, 54,
93+
83, 2, 4, 70, 874, 19, 5, 0, 4, 36,
94+
100, 11, 3, 71, 950, 21, 6, 0, 1, 40,
95+
89, 5, 5, 73, 859, 29, 3, 0, 2, 38,
96+
78, 13, 6, 100, 852, 24, 5, 0, 1, 39
97+
],
98+
),
99+
:group,
100+
)
101+
gform = @formula(y ~ 1 + (1|group))
102+
m1 = fit(MixedModel, gform, goldstein, Poisson())
103+
@test deviance(m1) 193.5587302384811 rtol=1.e-5
104+
@test only(m1.β) 4.192196439077657 atol=1.e-5
105+
@test only(m1.θ) 1.838245201739852 atol=1.e-5
106+
m11 = fit(MixedModel, gform, goldstein, Poisson(), nAGQ=11)
107+
@test deviance(m11) 193.51028088736842 rtol=1.e-5
108+
@test only(m11.β) 4.192196439077657 atol=1.e-5
109+
@test only(m11.θ) 1.838245201739852 atol=1.e-5
110+
end

0 commit comments

Comments
 (0)