Skip to content

Commit 9b6b428

Browse files
Merge pull request #173 from longemen3000/sa_ext
move StaticArrays to extension
2 parents b41a859 + e1cf6bf commit 9b6b428

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1313
[weakdeps]
1414
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1515
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
16+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617

1718
[extensions]
1819
FiniteDiffBandedMatricesExt = "BandedMatrices"
1920
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
21+
FiniteDiffStaticArraysExt = "StaticArrays"
2022

2123
[compat]
2224
ArrayInterface = "7"
@@ -30,7 +32,8 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
3032
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
3133
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3234
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
35+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3336
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3437

3538
[targets]
36-
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets"]
39+
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "StaticArrays"]

ext/FiniteDiffStaticArraysExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module FiniteDiffStaticArraysExt
2+
3+
if isdefined(Base, :get_extension)
4+
using FiniteDiff: FiniteDiff, ArrayInterface
5+
using StaticArrays
6+
else
7+
using ..FiniteDiff: FiniteDiff, ArrayInterface
8+
using ..StaticArrays
9+
end
10+
FiniteDiff._mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
11+
FiniteDiff.setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
12+
FiniteDiff.__Symmetric(x::SMatrix) = Symmetric(SArray(H))
13+
14+
end #module

src/FiniteDiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
module FiniteDiff
22

3-
using LinearAlgebra, SparseArrays, StaticArrays, ArrayInterface, Requires
3+
using LinearAlgebra, SparseArrays, ArrayInterface, Requires
44

55
import Base: resize!
66

77
_vec(x) = vec(x)
88
_vec(x::Number) = x
99

1010
_mat(x::AbstractMatrix) = x
11-
_mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
1211
_mat(x::AbstractVector) = reshape(x, (axes(x, 1), Base.OneTo(1)))
1312

1413
# Setindex overloads without piracy
1514
setindex(x...) = Base.setindex(x...)
16-
setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
1715

1816
function setindex(x::AbstractArray, v, i...)
1917
_x = Base.copymutable(x)
@@ -39,6 +37,8 @@ include("jacobians.jl")
3937
include("hessians.jl")
4038

4139
if !isdefined(Base,:get_extension)
40+
using StaticArrays
41+
include("../ext/FiniteDiffStaticArraysExt.jl")
4242
using Requires
4343
function __init__()
4444
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin

src/hessians.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,27 @@ struct HessianCache{T,fdtype,inplace}
55
xmm::T
66
end
77

8+
#used to dispatch on StaticArrays
9+
_hessian_inplace(::Type{T}) where T = Val(ArrayInterface.ismutable(T))
10+
_hessian_inplace(x) = _hessian_inplace(typeof(x))
11+
__Symmetric(x) = Symmetric(x)
12+
13+
function mutable_zeromatrix(x)
14+
A = ArrayInterface.zeromatrix(x)
15+
ArrayInterface.ismutable(A) ? A : Base.copymutable(A)
16+
end
17+
18+
819
function HessianCache(xpp,xpm,xmp,xmm,
920
fdtype=Val(:hcentral),
10-
inplace = x isa StaticArray ? Val(false) : Val(true))
21+
inplace = _hessian_inplace(x))
1122
fdtype isa Type && (fdtype = fdtype())
1223
inplace isa Type && (inplace = inplace())
1324
HessianCache{typeof(xpp),fdtype,inplace}(xpp,xpm,xmp,xmm)
1425
end
1526

1627
function HessianCache(x, fdtype=Val(:hcentral),
17-
inplace = x isa StaticArray ? Val(false) : Val(true))
28+
inplace = _hessian_inplace(x))
1829
cx = copy(x)
1930
fdtype isa Type && (fdtype = fdtype())
2031
inplace isa Type && (inplace = inplace())
@@ -23,7 +34,7 @@ end
2334

2435
function finite_difference_hessian(f, x,
2536
fdtype = Val(:hcentral),
26-
inplace = x isa StaticArray ? Val(false) : Val(true);
37+
inplace = _hessian_inplace(x);
2738
relstep = default_relstep(fdtype, eltype(x)),
2839
absstep = relstep)
2940

@@ -36,16 +47,15 @@ function finite_difference_hessian(
3647
cache::HessianCache{T,fdtype,inplace};
3748
relstep=default_relstep(fdtype, eltype(x)),
3849
absstep=relstep) where {T,fdtype,inplace}
39-
_H = false .* x .* x'
40-
_H isa SMatrix ? H = MArray(_H) : H = _H
50+
H = mutable_zeromatrix(x)
4151
finite_difference_hessian!(H, f, x, cache; relstep=relstep, absstep=absstep)
42-
Symmetric(_H isa SMatrix ? SArray(H) : H)
52+
__Symmetric(H)
4353
end
4454

4555
function finite_difference_hessian!(H,f,
4656
x,
4757
fdtype = Val(:hcentral),
48-
inplace = x isa StaticArray ? Val(false) : Val(true);
58+
inplace = _hessian_inplace(x);
4959
relstep=default_relstep(fdtype, eltype(x)),
5060
absstep=relstep)
5161

0 commit comments

Comments
 (0)