Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
- Additional methods for pre-allocated result arrays and `*Config` instances have been added to the ForwardDiff extension. [#871].

MixedModels v5.2.0 Release Notes
==============================
- The use of the `wts` keyword argument has been deprecated in favor of the keyword argument `weights`, in line with the deprecation in GLM.jl v1.9.1. The usage (and subsequent interpretation) remains otherwise unchanged. [#873]
Expand Down Expand Up @@ -714,4 +716,5 @@ Package dependencies
[#864]: https://github.com/JuliaStats/MixedModels.jl/issues/864
[#865]: https://github.com/JuliaStats/MixedModels.jl/issues/865
[#867]: https://github.com/JuliaStats/MixedModels.jl/issues/867
[#871]: https://github.com/JuliaStats/MixedModels.jl/issues/871
[#873]: https://github.com/JuliaStats/MixedModels.jl/issues/873
56 changes: 50 additions & 6 deletions ext/MixedModelsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ using LinearAlgebra: LinearAlgebra,
using SparseArrays: SparseArrays, nzrange

# Stuff we're defining in this file
using ForwardDiff: ForwardDiff
using ForwardDiff: ForwardDiff,
Chunk,
GradientConfig,
HessianConfig
using MixedModels: fd_cholUnblocked!,
fd_deviance,
fd_logdet,
Expand Down Expand Up @@ -59,6 +62,16 @@ const FORWARDDIFF = """
should be included is currently still being decided.
"""

#####
##### Gradients
#####

function ForwardDiff.GradientConfig(
model::LinearMixedModel{T}, x::AbstractVector{T}=model.θ, chunk::Chunk=Chunk(x)
) where {T}
return GradientConfig(fd_deviance(model), x, chunk)
end

"""
ForwardDiff.gradient(model::LinearMixedModel)

Expand All @@ -68,9 +81,29 @@ values.
$(FORWARDDIFF)
"""
function ForwardDiff.gradient(
model::LinearMixedModel{T}, θ::Vector{T}=model.θ
model::LinearMixedModel{T}, θ::Vector{T}=model.θ,
cfg::GradientConfig=GradientConfig(model, θ),
check::Val{CHK}=Val(true),
) where {T,CHK}
return ForwardDiff.gradient!(similar(model.θ), model, θ, cfg, check)
end

function ForwardDiff.gradient!(result::AbstractArray,
model::LinearMixedModel{T}, θ::Vector{T}=model.θ,
cfg::GradientConfig=GradientConfig(model, θ),
check::Val{CHK}=Val(true),
) where {T,CHK}
return ForwardDiff.gradient!(result, fd_deviance(model), θ, cfg, check)
end

#####
##### Hessians
#####

function ForwardDiff.HessianConfig(
model::LinearMixedModel{T}, x::AbstractVector{T}=model.θ, chunk::Chunk=Chunk(x)
) where {T}
return ForwardDiff.gradient(fd_deviance(model), θ)
return HessianConfig(fd_deviance(model), x, chunk)
end

"""
Expand All @@ -82,9 +115,20 @@ values.
$(FORWARDDIFF)
"""
function ForwardDiff.hessian(
model::LinearMixedModel{T}, θ::Vector{T}=model.θ
) where {T}
return ForwardDiff.hessian(fd_deviance(model), θ)
model::LinearMixedModel{T}, θ::Vector{T}=model.θ,
cfg::HessianConfig=HessianConfig(model, θ),
check::Val{CHK}=Val(true),
) where {T,CHK}
n = length(θ)
return ForwardDiff.hessian!(Matrix{T}(undef, n, n), model, θ, cfg, check)
end

function ForwardDiff.hessian!(result::AbstractArray,
model::LinearMixedModel{T}, θ::Vector{T}=model.θ,
cfg::HessianConfig=HessianConfig(model, θ),
check::Val{CHK}=Val(true),
) where {T,CHK}
return ForwardDiff.hessian!(result, fd_deviance(model), θ, cfg, check)
end

#####
Expand Down
4 changes: 4 additions & 0 deletions gradients/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*.html
*\~
*.swp

Loading
Loading