Skip to content

Commit 7b17e0e

Browse files
authored
allow specifying type argument in GLMM predict on original data (#856)
* allow specifying `type` argument in GLMM predict on original data * NEWS * slightly more efficient
1 parent 6408a34 commit 7b17e0e

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ MixedModels v5.0.0 Release Notes
99
- Internal code around optimization in profiling has been restructuring so that fitting done during calls to `profile` respect the `backend` and `optimizer` settings. [#853]
1010
- The `prfit!` convenience function has been removed. [#853]
1111
- The `dataset` and `datasets` functions have been removed. They are now housed in `MixedModelsDatasets`.[#854]
12+
- One argument `predict(::GeneralizedLinearMixedModel)`, i.e. prediction on the original data, now supports the `type` keyword argument. [#856]
1213

1314
MixedModels v4.38.0 Release Notes
1415
==============================

src/predict.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ function StatsAPI.predict(
8686
return type == :linpred ? y : broadcast!(Base.Fix1(linkinv, Link(m)), y, y)
8787
end
8888

89+
function StatsAPI.predict(m::GeneralizedLinearMixedModel; type=:response)
90+
type in (:linpred, :response) || throw(ArgumentError("Invalid value for type: $(type)"))
91+
return type == :response ? fitted(m) : m.resp.eta
92+
end
93+
8994
# β is separated out here because m.β != m.LMM.β depending on how β is estimated for GLMM
9095
# also β should already be pivoted but NOT truncated in the rank deficient case
9196
function _predict(m::MixedModel{T}, newdata, β; new_re_levels) where {T}

test/fit.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,10 @@ end
7777
@testset "generalized" begin
7878
gm1 = fit(MixedModel, @formula(use ~ 1 + urban + livch + age + abs2(age) + (1 | dist)),
7979
MixedModels.dataset(:contra), Bernoulli(); progress=false)
80-
@test deviance(gm1) 2372.7286 atol = 1.0e-3
81-
8280
gm2 = glmm(@formula(use ~ 1 + urban + livch + age + abs2(age) + (1 | dist)),
8381
MixedModels.dataset(:contra), Bernoulli(); progress=false)
84-
@test deviance(gm2) 2372.7286 atol = 1.0e-3
82+
83+
@test deviance(gm1) deviance(gm2)
8584
end
8685

8786
@testset "Normal-IdentityLink" begin

test/predict.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ end
197197
# we can skip a lot of testing if the broad strokes work because
198198
# internally this is punted off to the same machinery as LMM
199199
@test predict(gm0) fitted(gm0)
200+
@test predict(gm0; type=:response) fitted(gm0)
201+
@test predict(gm0; type=:linpred) GLM.linkfun.(Ref(Link(gm0)), fitted(gm0))
200202
# XXX these tolerances aren't great but are required for fast=false fits
201203
@test predict(gm0, contra; type=:linpred) gm0.resp.eta rtol = 0.1
202204
@test predict(gm0, contra; type=:response) gm0.resp.mu rtol = 0.01

0 commit comments

Comments
 (0)