Skip to content

Commit d0d6a0a

Browse files
improve performance of hessians with static arrays (#751)
* improve performance of hessians with static arrays * ForwardDiffStaticArraysExt.jl: import Partials * Update HessianTest.jl * Update HessianTest.jl * Simplify code and fix test --------- Co-authored-by: David Widmann <[email protected]>
1 parent 42e0aa6 commit d0d6a0a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt
33
using ForwardDiff, StaticArrays
44
using ForwardDiff.LinearAlgebra
55
using ForwardDiff.DiffResults
6-
using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
6+
using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
77
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
88
extract_gradient!, extract_jacobian!, extract_value!,
99
vector_mode_gradient, vector_mode_gradient!,
@@ -71,8 +71,9 @@ end
7171
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x)
7272
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x)
7373

74-
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
75-
M, N = length(ydual), length(x)
74+
@generated function extract_jacobian(::Type{T}, ydual::Union{StaticArray,Partials}, x::S) where {T,S<:StaticArray}
75+
M = ydual <: Partials ? npartials(ydual) : length(ydual)
76+
N = length(x)
7677
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
7778
return quote
7879
$(Expr(:meta, :inline))

test/HessianTest.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,15 @@ end
163163
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) [2 6 10; 6 10 14; 10 14 18]
164164
end
165165

166+
#https://github.com/JuliaDiff/ForwardDiff.jl/issues/720
167+
@testset "allocation-free hessian with StaticArrays" begin
168+
function hessian_allocs()
169+
g = r -> (r[1]^2 - 3) * (r[2]^2 - 2)
170+
x = SVector(0.5, 2.8)
171+
hres = DiffResults.HessianResult(x)
172+
return @allocated(ForwardDiff.hessian!(hres, g, x))
173+
end
174+
@test iszero(hessian_allocs())
175+
end
176+
166177
end # module

0 commit comments

Comments
 (0)