Skip to content

Commit 69dc2c5

Browse files
committed
Re-implement GriddedInterpolation
1 parent 42fa671 commit 69dc2c5

File tree

10 files changed

+153
-218
lines changed

10 files changed

+153
-218
lines changed

src/Interpolations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ using LinearAlgebra, SparseArrays
3333
using StaticArrays, WoodburyMatrices, Ratios, AxisAlgorithms, OffsetArrays
3434

3535
using Base: @propagate_inbounds
36-
import Base: convert, size, axes, promote_rule, ndims, eltype, checkbounds
36+
import Base: convert, size, axes, promote_rule, ndims, eltype, checkbounds, axes1
3737

3838
abstract type Flag end
3939
abstract type InterpolationType <: Flag end
@@ -300,7 +300,7 @@ end
300300

301301
include("nointerp/nointerp.jl")
302302
include("b-splines/b-splines.jl")
303-
# include("gridded/gridded.jl")
303+
include("gridded/gridded.jl")
304304
include("extrapolation/extrapolation.jl")
305305
include("scaling/scaling.jl")
306306
include("utils.jl")

src/b-splines/indexing.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ end
8888
end
8989

9090

91-
function weightedindexes(fs::F, itpflags::NTuple{N,Flag}, axs::NTuple{N,AbstractUnitRange}, xs::NTuple{N,Number}) where {F,N}
92-
parts = map((flag, ax, x)->weightedindex_parts(fs, flag, ax, x), itpflags, axs, xs)
91+
function weightedindexes(fs::F, itpflags::NTuple{N,Flag}, knots::NTuple{N,AbstractVector}, xs::NTuple{N,Number}) where {F,N}
92+
parts = map((flag, knotvec, x)->weightedindex_parts(fs, flag, knotvec, x), itpflags, knots, xs)
9393
weightedindexes(parts...)
9494
end
9595

@@ -179,7 +179,7 @@ slot_substitute(kind1::Tuple{}, kind2::Tuple{}, p, v, g, h) = ()
179179
weightedindex_parts(fs::F, itpflag::BSpline, ax, x) where F =
180180
weightedindex_parts(fs, degree(itpflag), ax, x)
181181

182-
function weightedindex_parts(fs::F, deg::Degree, ax, x) where F
182+
function weightedindex_parts(fs::F, deg::Degree, ax::AbstractUnitRange{<:Integer}, x) where F
183183
pos, δx = positions(deg, ax, x)
184184
(position=pos, coefs=fmap(fs, deg, δx))
185185
end
@@ -198,16 +198,22 @@ function getindex_return_type(::Type{BSplineInterpolation{T,N,TCoefs,IT,Axs}}, :
198198
end
199199

200200
# This handles round-towards-the-middle for points on half-integer edges
201-
roundbounds(x::Integer, bounds) = x
202-
function roundbounds(x, bounds)
201+
roundbounds(x::Integer, bounds::Tuple{Real,Real}) = x
202+
roundbounds(x::Integer, bounds::AbstractUnitRange) = x
203+
roundbounds(x::Number, bounds::Tuple{Real,Real}) = _roundbounds(x, bounds)
204+
roundbounds(x::Number, bounds::AbstractUnitRange) = _roundbounds(x, bounds)
205+
function _roundbounds(x::Number, bounds::Union{Tuple{Real,Real}, AbstractUnitRange})
203206
l, u = first(bounds), last(bounds)
204207
h = half(x)
205208
xh = x+h
206209
ifelse(x < u+half(u), floor(xh), ceil(xh)-1)
207210
end
208211

209-
floorbounds(x::Integer, ax) = x
210-
function floorbounds(x, ax)
212+
floorbounds(x::Integer, ax::Tuple{Real,Real}) = x
213+
floorbounds(x::Integer, ax::AbstractUnitRange) = x
214+
floorbounds(x, ax::Tuple{Real,Real}) = _floorbounds(x, ax)
215+
floorbounds(x, ax::AbstractUnitRange) = _floorbounds(x, ax)
216+
function _floorbounds(x, ax::Union{Tuple{Real,Real}, AbstractUnitRange})
211217
l = first(ax)
212218
h = half(x)
213219
ifelse(x < l, floor(x+h), floor(x+zero(h)))

src/b-splines/linear.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ a piecewise linear function connecting each pair of neighboring data points.
2121
"""
2222
Linear
2323

24-
function positions(::Linear, ax, x)
24+
function positions(::Linear, ax::AbstractUnitRange{<:Integer}, x)
2525
f = floor(x)
2626
# When x == last(ax) we want to use the x-1, x pair
2727
f = ifelse(x == last(ax), f - oneunit(f), f)

src/deprecations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# deprecate getindex for non-integer numeric indices
22
@deprecate getindex(itp::AbstractInterpolation{T,N}, i::Vararg{Number,N}) where {T,N} itp(i...)
3+
@deprecate getindex(itp::AbstractInterpolation{T,N}, i::Vararg{ExpandedIndexTypes,N}) where {T,N} itp(i...)
34

45
for T in (:Throw, :Flat, :Line, :Free, :Periodic, :Reflect, :InPlace, :InPlaceQ)
56
@eval begin

src/gridded/constant.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
function base_rem(::Constant, knotv, ki, x)
2+
l, u = knotv[ki], knotv[ki+1]
3+
4+
xm = roundbounds(x, bounds)
5+
δx = x - xm
6+
fast_trunc(Int, xm), δx
7+
end
8+
19
function define_indices_d(::Type{Gridded{Constant}}, d, pad)
210
symix, symx = Symbol("ix_",d), Symbol("x_",d)
311
symk, symkix = Symbol("k_",d), Symbol("kix_",d)

src/gridded/gridded.jl

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,22 @@
11
export Gridded
22

3-
struct Gridded{D<:Degree} <: InterpolationType end
4-
Gridded(::D) where {D<:Degree} = Gridded{D}()
5-
6-
griddedtype(::Type{Gridded{D}}) where {D<:Degree} = D
3+
struct Gridded{D<:Degree} <: InterpolationType
4+
degree::D
5+
end
76

87
const GridIndex{T} = Union{AbstractVector{T}, Tuple}
98

10-
# Because Ranges check bounds on getindex, it's actually faster to convert the
11-
# knots to Vectors. It's also good to take a copy, so it doesn't get modified later.
12-
struct GriddedInterpolation{T,N,TCoefs,IT<:DimSpec{Gridded},K<:Tuple{Vararg{AbstractVector}}} <: AbstractInterpolation{T,N,IT,OnGrid}
9+
struct GriddedInterpolation{T,N,TCoefs,IT<:DimSpec{Gridded},K<:Tuple{Vararg{AbstractVector}}} <: AbstractInterpolation{T,N,IT}
1310
knots::K
1411
coefs::Array{TCoefs,N}
12+
it::IT
1513
end
1614
function GriddedInterpolation(::Type{TWeights}, knots::NTuple{N,GridIndex}, A::AbstractArray{TCoefs,N}, it::IT) where {N,TCoefs,TWeights<:Real,IT<:DimSpec{Gridded},pad}
1715
isconcretetype(IT) || error("The b-spline type must be a leaf type (was $IT)")
1816
isconcretetype(TCoefs) || warn("For performance reasons, consider using an array of a concrete type (eltype(A) == $(eltype(A)))")
1917

20-
knts = mapcollect(knots...)
21-
for (d,k) in enumerate(knts)
22-
length(k) == size(A, d) || throw(DimensionMismatch("knot vectors must have the same number of elements as the corresponding dimension of the array"))
23-
length(k) == 1 && error("dimensions of length 1 not yet supported") # FIXME
24-
issorted(k) || error("knot-vectors must be sorted in increasing order")
25-
iextract(IT, d) != NoInterp || k == collect(1:size(A, d)) || error("knot-vector should be the range 1:$(size(A,d)) for the method Gridded{NoInterp}")
26-
end
18+
check_gridded(it, knots, axes(A))
2719
c = zero(TWeights)
28-
for _ in 2:N
29-
c *= c
30-
end
3120
if isempty(A)
3221
T = Base.promote_op(*, typeof(c), eltype(A))
3322
else
@@ -36,39 +25,44 @@ function GriddedInterpolation(::Type{TWeights}, knots::NTuple{N,GridIndex}, A::A
3625
GriddedInterpolation{T,N,TCoefs,IT,typeof(knots)}(knots, A, it)
3726
end
3827

39-
Base.parent(A::GriddedInterpolation) = A.coefs
28+
@inline function check_gridded(itpflag, knots, axs)
29+
flag, ax1, k1 = getfirst(itpflag), axs[1], knots[1]
30+
if flag isa NoInterp
31+
k1 == ax1 || error("for NoInterp knot vector should be $ax1, got $k1")
32+
else
33+
axes(k1, 1) == ax1 || throw(DimensionMismatch("knot vectors must have the same axes as the corresponding dimension of the array"))
34+
end
35+
degree(flag) isa Union{NoInterp,Constant,Linear} || error("only Linear, Constant, and NoInterp supported, got $flag")
36+
length(k1) == 1 && error("dimensions of length 1 not yet supported") # FIXME
37+
issorted(k1) || error("knot-vectors must be sorted in increasing order")
38+
check_gridded(getrest(itpflag), Base.tail(knots), Base.tail(axs))
39+
end
40+
check_gridded(::Any, ::Tuple{}, ::Tuple{}) = nothing
41+
degree(flag::Gridded) = flag.degree
4042

41-
# A type-stable version of map(collect, knots)
42-
mapcollect() = ()
43-
@inline mapcollect(k::AbstractVector) = (collect(k),)
44-
@inline mapcollect(k1::AbstractVector, k2::AbstractVector...) = (collect(k1), mapcollect(k2...)...)
43+
Base.parent(A::GriddedInterpolation) = A.coefs
44+
coefficients(A::GriddedInterpolation) = A.coefs
4545

46-
# Utilities for working either with scalars or tuples/tuple-types
47-
iextract(::Type{T}, d) where {T<:Gridded} = T
48-
iextract(::Type{T}, d) where {T<:GridType} = T
46+
size(A::GriddedInterpolation) = size(A.coefs)
47+
axes(A::GriddedInterpolation) = axes(A.coefs)
4948

50-
@generated function size(itp::GriddedInterpolation{T,N,TCoefs,IT,K,pad}, d) where {T,N,TCoefs,IT,K,pad}
51-
quote
52-
d <= $N ? size(itp.coefs, d) - 2*padextract($pad, d) : 1
53-
end
54-
end
49+
itpflag(A::GriddedInterpolation) = A.it
5550

5651
function interpolate(::Type{TWeights}, ::Type{TCoefs}, knots::NTuple{N,GridIndex}, A::AbstractArray{Tel,N}, it::IT) where {TWeights,TCoefs,Tel,N,IT<:DimSpec{Gridded}}
57-
GriddedInterpolation(TWeights, knots, A, it, Val{0}())
52+
GriddedInterpolation(TWeights, knots, A, it)
5853
end
5954
function interpolate(knots::NTuple{N,GridIndex}, A::AbstractArray{Tel,N}, it::IT) where {Tel,N,IT<:DimSpec{Gridded}}
6055
interpolate(tweight(A), tcoef(A), knots, A, it)
6156
end
6257

63-
interpolate!(::Type{TWeights}, knots::NTuple{N,GridIndex}, A::AbstractArray{Tel,N}, it::IT) where {TWeights,Tel,N,IT<:DimSpec{Gridded}} = GriddedInterpolation(TWeights, knots, A, it, Val{0}())
58+
interpolate!(::Type{TWeights}, knots::NTuple{N,GridIndex}, A::AbstractArray{Tel,N}, it::IT) where {TWeights,Tel,N,IT<:DimSpec{Gridded}} =
59+
GriddedInterpolation(TWeights, knots, A, it)
6460
function interpolate!(knots::NTuple{N,GridIndex}, A::AbstractArray{Tel,N}, it::IT) where {Tel,N,IT<:DimSpec{Gridded}}
6561
interpolate!(tweight(A), tcoef(A), knots, A, it)
6662
end
6763

68-
lbound(itp::GriddedInterpolation, d) = itp.knots[d][1]
69-
ubound(itp::GriddedInterpolation, d) = itp.knots[d][end]
70-
lbound(itp::GriddedInterpolation, d, inds) = itp.knots[d][1]
71-
ubound(itp::GriddedInterpolation, d, inds) = itp.knots[d][end]
64+
lbounds(itp::GriddedInterpolation) = first.(itp.knots)
65+
ubounds(itp::GriddedInterpolation) = last.(itp.knots)
7266

7367
include("constant.jl")
7468
include("linear.jl")

0 commit comments

Comments
 (0)