Skip to content

Commit 54552ad

Browse files
committed
Introduce and implement WeightedIndex
1 parent adf5571 commit 54552ad

File tree

14 files changed

+333
-354
lines changed

14 files changed

+333
-354
lines changed

src/Interpolations.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,80 @@ count_interp_dims(it::Type{IT}, n) where IT<:Tuple{Vararg{InterpolationType,N}}
127127
_count_interp_dims(c + count_interp_dims(IT1), args...)
128128
_count_interp_dims(c) = c
129129

130+
131+
"""
132+
wi = WeightedIndex(indexes, weights)
133+
134+
Construct a weighted index `wi`, which can be thought of as a generalization of an
135+
ordinary array index to the context of interpolation.
136+
For an ordinary vector `a`, `a[i]` extracts the element at index `i`.
137+
When interpolating, one is typically interested in a range of indexes and the output is
138+
some weighted combination of array values at these indexes.
139+
For example, for linear interpolation between `i` and `i+1` we have
140+
141+
ret = (1-f)*a[i] + f*a[i]
142+
143+
This can be represented `a[wi]`, where
144+
145+
wi = WeightedIndex(i:i+1, (1-f, f))
146+
147+
i.e.,
148+
149+
ret = sum(a[indexes] .* weights)
150+
151+
Linear interpolation thus constructs weighted indices using a 2-tuple for `weights` and
152+
a length-2 `indexes` range.
153+
Higher-order interpolation would involve more positions and weights (e.g., 3-tuples for
154+
quadratic interpolation, 4-tuples for cubic).
155+
156+
In multiple dimensions, separable interpolation schemes are implemented in terms
157+
of multiple weighted indices, accessing `A[wi1, wi2, ...]` where each `wi` is the
158+
`WeightedIndex` along the corresponding dimension.
159+
160+
For value interpolation, `weights` will typically sum to 1.
161+
However, for gradient and Hessian computation this will not necessarily be true.
162+
For example, the gradient of one-dimensional linear interpolation can be represented as
163+
164+
gwi = WeightedIndex(i:i+1, (-1, 1))
165+
g1 = a[gwi]
166+
167+
For a three-dimensional array `A`, one might compute `∂A/∂x₂` (the second component
168+
of the gradient) as `A[wi1, gwi2, wi3]`, where `wi1` and `wi3` are "value" weights
169+
and `gwi2` "gradient" weights.
170+
171+
`indexes` may be supplied as a range or as a tuple of the same length as `weights`.
172+
The latter is applicable, e.g., for periodic boundary conditions.
173+
"""
174+
abstract type WeightedIndex{L,W} end
175+
176+
# Type to use when array locations are adjacent. This may offer more opportunities
177+
# for compiler optimizations (e.g., SIMD).
178+
struct WeightedAdjIndex{L,W} <: WeightedIndex{L,W}
179+
istart::Int
180+
weights::NTuple{L,W}
181+
end
182+
# Type to use with non-adjacent locations. E.g., periodic boundary conditions.
183+
struct WeightedArbIndex{L,W} <: WeightedIndex{L,W}
184+
indexes::NTuple{L,Int}
185+
weights::NTuple{L,W}
186+
end
187+
188+
function WeightedIndex(indexes::AbstractUnitRange{<:Integer}, weights::NTuple{L,Any}) where L
189+
@noinline mismatch(indexes, weights) = throw(ArgumentError("the length of indexes must match weights, got $indexes vs $weights"))
190+
length(indexes) == L || mismatch(indexes, weights)
191+
WeightedAdjIndex(first(indexes), promote(weights...))
192+
end
193+
WeightedIndex(istart::Integer, weights::NTuple{L,Any}) where L =
194+
WeightedAdjIndex(istart, promote(weights...))
195+
WeightedIndex(indexes::NTuple{L,Integer}, weights::NTuple{L,Any}) where L =
196+
WeightedArbIndex(indexes, promote(weights...))
197+
198+
weights(wi::WeightedIndex) = wi.weights
199+
indexes(wi::WeightedAdjIndex) = wi.istart
200+
indexes(wi::WeightedArbIndex) = wi.indexes
201+
202+
203+
130204
"""
131205
w = value_weights(degree, δx)
132206

src/b-splines/b-splines.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ _ubounds(axs, itp) = (ubound(axs[1], getfirst(itp)), _ubounds(Base.tail(axs), ge
8080
_lbounds(::Tuple{}, itp) = ()
8181
_ubounds(::Tuple{}, itp) = ()
8282

83-
lbound(ax::AbstractUnitRange, bs::BSpline) = lbound(ax, degree(bs))
84-
lbound(ax::AbstractUnitRange, deg::Degree) = first(ax)
85-
lbound(ax::AbstractUnitRange, deg::DegreeBC) = lbound(ax, deg, deg.bc.gt)
86-
ubound(ax::AbstractUnitRange, bs::BSpline) = ubound(ax, degree(bs))
87-
ubound(ax::AbstractUnitRange, deg::Degree) = last(ax)
88-
ubound(ax::AbstractUnitRange, deg::DegreeBC) = ubound(ax, deg, deg.bc.gt)
83+
lbound(ax::AbstractRange, bs::BSpline) = lbound(ax, degree(bs))
84+
lbound(ax::AbstractRange, deg::Degree) = first(ax)
85+
lbound(ax::AbstractRange, deg::DegreeBC) = lbound(ax, deg, deg.bc.gt)
86+
ubound(ax::AbstractRange, bs::BSpline) = ubound(ax, degree(bs))
87+
ubound(ax::AbstractRange, deg::Degree) = last(ax)
88+
ubound(ax::AbstractRange, deg::DegreeBC) = ubound(ax, deg, deg.bc.gt)
8989

9090
lbound(ax::AbstractUnitRange, ::DegreeBC, ::OnCell) = first(ax) - 0.5
9191
ubound(ax::AbstractUnitRange, ::DegreeBC, ::OnCell) = last(ax) + 0.5

src/b-splines/constant.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,18 @@ struct Constant <: Degree{0} end
22

33
"""
44
Constant b-splines are *nearest-neighbor* interpolations, and effectively
5-
return `A[round(Int,x)]` when interpolating
5+
return `A[round(Int,x)]` when interpolating.
66
"""
77
Constant
88

9-
function base_rem(::Constant, bounds, x)
10-
xm = roundbounds(x, bounds)
9+
function positions(::Constant, ax, x) # discontinuity occurs at half-integer locations
10+
xm = roundbounds(x, ax)
1111
δx = x - xm
1212
fast_trunc(Int, xm), δx
1313
end
1414

15-
expand_index(::Constant, xi::Number, ax::AbstractUnitRange, δx) = (xi,)
16-
17-
value_weights(::Constant, δx) = (oneunit(δx),)
18-
gradient_weights(::Constant, δx) = (zero(δx),)
19-
hessian_weights(::Constant, δx) = (zero(δx),)
15+
value_weights(::Constant, δx) = (1,)
16+
gradient_weights(::Constant, δx) = (0,)
17+
hessian_weights(::Constant, δx) = (0,)
2018

2119
padded_axis(ax::AbstractUnitRange, ::BSpline{Constant}) = ax

src/b-splines/cubic.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,44 +27,34 @@ When we derive boundary conditions we will use derivatives `y_0'(x)` and
2727
"""
2828
Cubic
2929

30-
function base_rem(::Cubic, bounds, x)
31-
xf = floorbounds(x, bounds)
32-
xf -= ifelse(xf > last(bounds)-1, oneunit(xf), zero(xf))
30+
function positions(deg::Cubic, ax, x)
31+
xf = floorbounds(x, ax)
32+
xf -= ifelse(xf > last(ax)-1, oneunit(xf), zero(xf))
3333
δx = x - xf
34-
fast_trunc(Int, xf), δx
34+
expand_index(deg, fast_trunc(Int, xf), ax, δx), δx
3535
end
3636

37-
expand_index(::Cubic{BC}, xi::Number, ax::AbstractUnitRange, δx) where BC = (xi-1, xi, xi+1, xi+2)
37+
expand_index(::Cubic{BC}, xi::Number, ax::AbstractUnitRange, δx) where BC = xi-1
3838
expand_index(::Cubic{Periodic{GT}}, xi::Number, ax::AbstractUnitRange, δx) where GT<:GridType =
3939
(modrange(xi-1, ax), modrange(xi, ax), modrange(xi+1, ax), modrange(xi+2, ax))
4040

41-
# expand_coefs(::Type{BSpline{Cubic{BC}}}, δx) = cvcoefs(δx)
42-
# expand_coefs(::Type{BSpline{Cubic{BC}}}, dref, d, δx) = ifelse(d==dref, cgcoefs(δx), cvcoefs(δx))
43-
# function expand_coefs(::Type{BSpline{Cubic{BC}}}, dref1, dref2, d, δx)
44-
# if dref1 == dref2
45-
# d == dref1 ? chcoefs(δx) : cvcoefs(δx)
46-
# else
47-
# d == dref1 | d == dref2 ? cgcoefs(δx) : cvcoefs(δx)
48-
# end
49-
# end
50-
51-
function value_weights(::BSpline{<:Cubic}, δx)
41+
function value_weights(::Cubic, δx)
5242
x3, xcomp3 = cub(δx), cub(1-δx)
5343
(SimpleRatio(1,6) * xcomp3,
5444
SimpleRatio(2,3) - sqr(δx) + SimpleRatio(1,2)*x3,
5545
SimpleRatio(2,3) - sqr(1-δx) + SimpleRatio(1,2)*xcomp3,
5646
SimpleRatio(1,6) * x3)
5747
end
5848

59-
function gradient_weights(::BSpline{<:Cubic}, δx)
49+
function gradient_weights(::Cubic, δx)
6050
x2, xcomp2 = sqr(δx), sqr(1-δx)
6151
(-SimpleRatio(1,2) * xcomp2,
6252
-2*δx + SimpleRatio(3,2)*x2,
6353
+2*(1-δx) - SimpleRatio(3,2)*xcomp2,
6454
SimpleRatio(1,2) * x2)
6555
end
6656

67-
hessian_weights(::BSpline{<:Cubic}, δx) = (1-δx, 3*δx-2, 3*(1-δx)-2, δx)
57+
hessian_weights(::Cubic, δx) = (1-δx, 3*δx-2, 3*(1-δx)-2, δx)
6858

6959

7060
# ------------ #

0 commit comments

Comments
 (0)