Skip to content

Commit eda9f10

Browse files
committed
Implement and test hessian computation
1 parent 7522e4d commit eda9f10

File tree

7 files changed

+196
-27
lines changed

7 files changed

+196
-27
lines changed

src/Interpolations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,16 +187,16 @@ Base.to_index(::AbstractInterpolation, x::Number) = x
187187
# itp(to_indices(itp, x))
188188
# end
189189
function gradient(itp::AbstractInterpolation, x::Vararg{UnexpandedIndexTypes})
190-
gradient(itp, to_indices(itp, x))
190+
gradient(itp, to_indices(itp, x)...)
191191
end
192192
function gradient!(dest, itp::AbstractInterpolation, x::Vararg{UnexpandedIndexTypes})
193-
gradient!(dest, itp, to_indices(itp, x))
193+
gradient!(dest, itp, to_indices(itp, x)...)
194194
end
195195
function hessian(itp::AbstractInterpolation, x::Vararg{UnexpandedIndexTypes})
196-
hessian(itp, to_indices(itp, x))
196+
hessian(itp, to_indices(itp, x)...)
197197
end
198198
function hessian!(dest, itp::AbstractInterpolation, x::Vararg{UnexpandedIndexTypes})
199-
hessian!(dest, itp, to_indices(itp, x))
199+
hessian!(dest, itp, to_indices(itp, x)...)
200200
end
201201

202202
# @inline function (itp::AbstractInterpolation)(x::Vararg{ExpandedIndexTypes})

src/b-splines/indexing.jl

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ end
1717
@boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)
1818
expand_gradient!(dest, itp, x)
1919
end
20+
2021
@inline function hessian(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
2122
@boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)
2223
expand_hessian(itp, x)
2324
end
25+
@inline function hessian!(dest, itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
26+
@boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)
27+
expand_hessian(itp, x)
28+
end
2429

2530
checkbounds(::Type{Bool}, itp::AbstractInterpolation, x::Vararg{Number,N}) where N =
2631
checklubounds(lbounds(itp), ubounds(itp), x)
@@ -77,7 +82,7 @@ Calculate the interpolated hessian of `itp` at `x`.
7782
function expand_hessian(itp::AbstractInterpolation, x::Tuple)
7883
coefs = coefficients(itp)
7984
degree = interpdegree(itp)
80-
ixs, rxs = expand_indices_resid(degree, axes(coefs), x)
85+
ixs, rxs = expand_indices_resid(degree, bounds(itp), x)
8186
cxs = expand_weights(value_weights, degree, rxs)
8287
gxs = expand_weights(gradient_weights, degree, rxs)
8388
hxs = expand_weights(hessian_weights, degree, rxs)
@@ -194,14 +199,76 @@ function expand!(dest, coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoIn
194199
i = 0
195200
for d = 1:N
196201
w = substitute(vweights, d, gweights)
197-
w isa Weights || continue # must have a NoInterp in it
202+
w isa Weights || continue # if this isn't true, it must have a NoInterp in it
198203
dest[i+=1] = expand(coefs, w, ixs)
199204
end
200205
dest
201206
end
202207

203-
function expand(coefs, (vweights, gweights, hweights)::NTuple{3,Weights{N}}, ixs::Indexes{N}) where N
204-
error("not yet implemented")
208+
# Expansion of the hessian
209+
# To handle the immutability of SMatrix we build static methods that visit just the entries we need,
210+
# which due to symmetry is just the upper triangular part
211+
ntuple_sym(f, ::Val{0}) = ()
212+
ntuple_sym(f, ::Val{1}) = (f(1,1),)
213+
ntuple_sym(f, ::Val{2}) = (f(1,1), f(1,2), f(2,2))
214+
ntuple_sym(f, ::Val{3}) = (f(1,1), f(1,2), f(2,2), f(1,3), f(2,3), f(3,3))
215+
ntuple_sym(f, ::Val{4}) = (f(1,1), f(1,2), f(2,2), f(1,3), f(2,3), f(3,3), f(1,4), f(2,4), f(3,4), f(4,4))
216+
@inline function ntuple_sym(f, ::Val{N}) where N
217+
(ntuple_sym(f, Val(N-1))..., ntuple(i->f(i,N), Val(N))...)
218+
end
219+
220+
sym2dense(t::Tuple{}) = t
221+
sym2dense(t::NTuple{1,T}) where T = t
222+
sym2dense(t::NTuple{3,T}) where T = (t[1], t[2], t[2], t[3])
223+
sym2dense(t::NTuple{6,T}) where T = (t[1], t[2], t[4], t[2], t[3], t[5], t[4], t[5], t[6])
224+
sym2dense(t::NTuple{10,T}) where T = (t[1], t[2], t[4], t[7], t[2], t[3], t[5], t[8], t[4], t[5], t[6], t[9], t[7], t[8], t[9], t[10])
225+
function sym2dense(t::NTuple{L,T}) where {L,T}
226+
# Warning: non-inferrable unless we make this @generated.
227+
# Above 4 dims one might anyway prefer an Array, and use hessian!
228+
N = ceil(Int, sqrt(2*L))
229+
@assert (N*(N+1))÷2 == L
230+
a = Vector{T}(undef, N*N)
231+
idx = 0
232+
for j = 1:N, i=1:N
233+
iu, ju = ifelse(i>=j, (j, i), (i, j)) # index in the upper triangular
234+
k = (ju*(ju+1))÷2 + iu
235+
a[idx+=1] = t[k]
236+
end
237+
tuple(a...)
238+
end
239+
240+
squarematrix(t::NTuple{1,T}) where T = SMatrix{1,1,T}(t)
241+
squarematrix(t::NTuple{4,T}) where T = SMatrix{2,2,T}(t)
242+
squarematrix(t::NTuple{9,T}) where T = SMatrix{3,3,T}(t)
243+
squarematrix(t::NTuple{16,T}) where T = SMatrix{4,4,T}(t)
244+
function squarematrix(t::NTuple{L,T}) where {L,T}
245+
# Warning: non-inferrable unless we make this @generated.
246+
# Above 4 dims one might anyway prefer an Array, and use hessian!
247+
N = floor(Int, sqrt(L))
248+
@assert N*N == L
249+
SMatrix{N,N,T}(t)
250+
end
251+
252+
cumweights(w) = _cumweights(0, w...)
253+
_cumweights(c, w1, w...) = (c+1, _cumweights(c+1, w...)...)
254+
_cumweights(c, ::NoInterp, w...) = (c, _cumweights(c, w...)...)
255+
_cumweights(c) = ()
256+
257+
function expand(coefs, (vweights, gweights, hweights)::NTuple{3,HasNoInterp{N}}, ixs::Indexes{N}) where N
258+
coefs = ntuple_sym((i,j)->expand(coefs, substitute(vweights, i, j, gweights, hweights), ixs), Val(N))
259+
squarematrix(sym2dense(skip_nointerp(coefs...)))
260+
end
261+
262+
function expand!(dest, coefs, (vweights, gweights, hweights)::NTuple{3,HasNoInterp{N}}, ixs::Indexes{N}) where N
263+
# The Hessian is nominally N × N, but if there are K NoInterp dims then it's N-K × N-K
264+
indlookup = cumweights(hweights) # for d in 1:N, indlookup[d] returns the appropriate index in 1:N-K
265+
for d2 = 1:N, d1 = 1:d2
266+
w = substitute(vweights, d1, d2, gweights, hweights)
267+
w isa Weights || continue # if this isn't true, it must have a NoInterp in it
268+
i, j = indlookup[d1], indlookup[d2]
269+
dest[i, j] = dest[j, i] = expand(coefs, w, ixs)
270+
end
271+
dest
205272
end
206273

207274
function expand_indices_resid(degree, bounds, x)
@@ -255,7 +322,7 @@ roundbounds(x, bounds::Tuple{Integer,Integer}) = round(x)
255322
roundbounds(x, (l, u)) = ifelse(x == l, ceil(l), ifelse(x == u, floor(u), round(x)))
256323

257324
floorbounds(x, bounds::Tuple{Integer,Integer}) = floor(x)
258-
function floorbounds(x, (l, u))
325+
function floorbounds(x, (l, u)::Tuple{Real,Real})
259326
ceill = ceil(l)
260327
ifelse(l <= x <= ceill, ceill, floor(x))
261328
end

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@ fast_trunc(::Type{Int}, x::Rational) = x.num ÷ x.den
1414
iextract(f::Flag, d) = f
1515
iextract(t::Tuple, d) = t[d]
1616

17+
# Substitution for gradient components
1718
function substitute(default::NTuple{N,Any}, d::Integer, subst::NTuple{N,Any}) where N
1819
ntuple(i->ifelse(i==d, subst[i], default[i]), Val(N))
1920
end
2021
function substitute(default::NTuple{N,Any}, d::Integer, val) where N
2122
ntuple(i->ifelse(i==d, val, default[i]), Val(N))
2223
end
2324

25+
# Substitution for hessian components
26+
function substitute(default::NTuple{N,Any}, d1::Integer, d2::Integer, subst1::NTuple{N,Any}, subst2::NTuple{N,Any}) where N
27+
ntuple(i->ifelse(i==d1==d2, subst2[i], ifelse(i==d1, subst1[i], ifelse(i==d2, subst1[i], default[i]))), Val(N))
28+
end
29+
2430
@inline skip_nointerp(x, rest...) = (x, skip_nointerp(rest...)...)
2531
@inline skip_nointerp(::NoInterp, rest...) = skip_nointerp(rest...)
2632
skip_nointerp() = ()

test/InterpolationTestUtils.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module InterpolationTestUtils
22

3-
using Test, Interpolations
3+
using Test, Interpolations, ForwardDiff, StaticArrays
44
using Interpolations: degree, itpflag, bounds, lbounds, ubounds
55
using Interpolations: substitute
66

7-
export check_axes, check_inbounds_values, check_oob, can_eval_near_boundaries
7+
export check_axes, check_inbounds_values, check_oob, can_eval_near_boundaries,
8+
check_gradient, check_hessian
89
export MyPair
910

1011
const failstore = Ref{Any}(nothing) # stash the inputs to failing tests here
@@ -105,7 +106,34 @@ function can_eval_near_boundaries(itp::AbstractInterpolation)
105106
end
106107
end
107108

108-
# Used for multi-valued tests
109+
# Generate a grid of points [1.0, 1.3333, 1.6667, 2.0, 2.3333, ...] along each coordinate
110+
thirds(axs) = Iterators.product(_thirds(axs...)...)
111+
112+
_thirds(a, axs...) =
113+
(sort(Float64[a; (first(a):last(a)-1) .+ 1/3; (first(a)+1:last(a)) .- 1/3]), _thirds(axs...)...)
114+
_thirds() = ()
115+
116+
function check_gradient(itp::AbstractInterpolation, gtmp)
117+
val(x) = itp(Tuple(x)...)
118+
g!(gstore, x) = ForwardDiff.gradient!(gstore, val, x)
119+
gtmp2 = similar(gtmp)
120+
for i in thirds(axes(itp))
121+
@test Interpolations.gradient(itp, i...) g!(gtmp, SVector(i))
122+
@test Interpolations.gradient!(gtmp2, itp, i...) gtmp
123+
end
124+
end
125+
126+
function check_hessian(itp::AbstractInterpolation, htmp)
127+
val(x) = itp(Tuple(x)...)
128+
h!(hstore, x) = ForwardDiff.hessian!(hstore, val, x)
129+
htmp2 = similar(htmp)
130+
for i in thirds(axes(itp))
131+
@test Interpolations.hessian(itp, i...) h!(htmp, SVector(i))
132+
@test Interpolations.hessian!(htmp2, itp, i...) htmp
133+
end
134+
end
135+
136+
## A type used for multi-valued tests
109137
import Base: +, -, *, /,
110138

111139
struct MyPair{T}

test/gradient.jl

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,40 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
33
@testset "Gradients" begin
44
nx = 10
55
f1(x) = sin((x-3)*2pi/(nx-1) - 1)
6-
g1(x) = 2pi/(nx-1) * cos((x-3)*2pi/(nx-1) - 1)
6+
g1gt(x) = 2pi/(nx-1) * cos((x-3)*2pi/(nx-1) - 1)
7+
A1 = Float64[f1(x) for x in 1:nx]
8+
g1 = Array{Float64}(undef, 1)
9+
A2 = rand(Float64, nx, nx) * 100
10+
g2 = Array{Float64}(undef, 2)
11+
12+
for (A, g) in ((A1, g1), (A2, g2))
13+
# Gradient of Constant should always be 0
14+
itp = interpolate(A, BSpline(Constant()), OnGrid())
15+
for x in InterpolationTestUtils.thirds(axes(A))
16+
@test all(iszero, Interpolations.gradient(itp, x...))
17+
@test all(iszero, Interpolations.gradient!(g, itp, x...))
18+
end
719

8-
# Gradient of Constant should always be 0
9-
itp1 = interpolate(Float64[f1(x) for x in 1:nx],
10-
BSpline(Constant()), OnGrid())
20+
for GT in (OnGrid, OnCell)
21+
itp = interpolate(A, BSpline(Linear()), GT())
22+
check_gradient(itp, g)
23+
I = first(eachindex(itp))
24+
@test Interpolations.gradient(itp, I) == Interpolations.gradient(itp, Tuple(I)...)
25+
end
1126

12-
g = Array{Float64}(undef, 1)
27+
for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell)
28+
itp = interpolate(A, BSpline(Quadratic(BC())), GT())
29+
check_gradient(itp, g)
30+
I = first(eachindex(itp))
31+
@test Interpolations.gradient(itp, I) == Interpolations.gradient(itp, Tuple(I)...)
32+
end
1333

14-
for x in 1:nx
15-
@test Interpolations.gradient(itp1, x)[1] == 0
16-
@test Interpolations.gradient!(g, itp1, x)[1] == 0
17-
@test g[1] == 0
34+
for BC in (Line, Flat, Free, Periodic), GT in (OnGrid, OnCell)
35+
itp = interpolate(A, BSpline(Cubic(BC())), GT())
36+
check_gradient(itp, g)
37+
I = first(eachindex(itp))
38+
@test Interpolations.gradient(itp, I) == Interpolations.gradient(itp, Tuple(I)...)
39+
end
1840
end
1941

2042
# Since Linear is OnGrid in the domain, check the gradients between grid points
@@ -24,9 +46,9 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
2446
# Gridded(Linear()))
2547
for itp in (itp1, )#itp2)
2648
for x in 2.5:nx-1.5
27-
@test (g1(x),(Interpolations.gradient(itp,x))[1],atol=abs(0.1 * g1(x)))
28-
@test (g1(x),(Interpolations.gradient!(g,itp,x))[1],atol=abs(0.1 * g1(x)))
29-
@test (g1(x),g[1],atol=abs(0.1 * g1(x)))
49+
@test (g1gt(x),(Interpolations.gradient(itp,x))[1],atol=abs(0.1 * g1gt(x)))
50+
@test (g1gt(x),(Interpolations.gradient!(g1,itp,x))[1],atol=abs(0.1 * g1gt(x)))
51+
@test (g1gt(x),g1[1],atol=abs(0.1 * g1gt(x)))
3052
end
3153

3254
for i = 1:10
@@ -52,9 +74,9 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
5274
itp1 = interpolate(Float64[f1(x) for x in 1:nx-1],
5375
BSpline(Quadratic(Periodic())), OnCell())
5476
for x in 2:nx-1
55-
@test (g1(x),(Interpolations.gradient(itp1,x))[1],atol=abs(0.05 * g1(x)))
56-
@test (g1(x),(Interpolations.gradient!(g,itp1,x))[1],atol=abs(0.05 * g1(x)))
57-
@test (g1(x),g[1],atol=abs(0.1 * g1(x)))
77+
@test (g1gt(x),(Interpolations.gradient(itp1,x))[1],atol=abs(0.05 * g1gt(x)))
78+
@test (g1gt(x),(Interpolations.gradient!(g1,itp1,x))[1],atol=abs(0.05 * g1gt(x)))
79+
@test (g1gt(x),g1[1],atol=abs(0.1 * g1gt(x)))
5880
end
5981

6082
for i = 1:10

test/hessian.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Test, Interpolations, LinearAlgebra
2+
3+
@testset "Hessians" begin
4+
nx = 5
5+
k = 2pi/(nx-1)
6+
f1(x) = sin(k*(x-3) - 1)
7+
A1 = Float64[f1(x) for x in 1:nx]
8+
h1 = Array{Float64}(undef, 1, 1)
9+
A2 = rand(Float64, nx, nx) * 100
10+
h2 = Array{Float64}(undef, 2, 2)
11+
12+
for (A, h) in ((A1, h1), (A2, h2))
13+
for GT in (OnGrid, OnCell)
14+
for itp in (interpolate(A, BSpline(Constant()), GT()),
15+
interpolate(A, BSpline(Linear()), GT()))
16+
if ndims(A) == 1
17+
# Hessian of Constant and Linear should always be 0 in 1d
18+
for x in InterpolationTestUtils.thirds(axes(A))
19+
@test all(iszero, Interpolations.hessian(itp, x...))
20+
@test all(iszero, Interpolations.hessian!(h, itp, x...))
21+
end
22+
else
23+
for x in InterpolationTestUtils.thirds(axes(A))
24+
check_hessian(itp, h)
25+
end
26+
end
27+
end
28+
end
29+
30+
for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell)
31+
itp = interpolate(A, BSpline(Quadratic(BC())), GT())
32+
check_hessian(itp, h)
33+
I = first(eachindex(itp))
34+
@test Interpolations.hessian(itp, I) == Interpolations.hessian(itp, Tuple(I)...)
35+
end
36+
37+
for BC in (Line, Flat, Free, Periodic), GT in (OnGrid, OnCell)
38+
itp = interpolate(A, BSpline(Cubic(BC())), GT())
39+
check_hessian(itp, h)
40+
end
41+
end
42+
43+
# TODO: mixed interpolation (see gradient.jl)
44+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ using Interpolations
2323

2424
# # test gradient evaluation
2525
include("gradient.jl")
26+
# # test hessian evaluation
27+
include("hessian.jl")
2628

2729
# # gridded interpolation tests
2830
# include("gridded/runtests.jl")

0 commit comments

Comments
 (0)