Skip to content

Commit 42fa671

Browse files
committed
Re-implement fast iteration for ScaledInterpolation
1 parent e6c8a54 commit 42fa671

File tree

3 files changed

+123
-213
lines changed

3 files changed

+123
-213
lines changed

src/scaling/scaling.jl

Lines changed: 87 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,38 @@ check_ranges(::Any, ::Tuple{}, ::Tuple{}) = nothing
3131
check_range(::NoInterp, ax, r) = ax == r || throw(ArgumentError("The range $r did not equal the corresponding axis of the interpolation object $ax"))
3232
check_range(::Any, ax, r) = length(ax) == length(r) || throw(ArgumentError("The range $r is incommensurate with the corresponding axis $ax"))
3333

34+
# With regards to size and [], ScaledInterpolation behaves like the underlying interpolation object
3435
size(sitp::ScaledInterpolation) = size(sitp.itp)
3536
axes(sitp::ScaledInterpolation) = axes(sitp.itp)
3637

38+
@propagate_inbounds function Base.getindex(sitp::ScaledInterpolation{T,N}, i::Vararg{Int,N}) where {T,N}
39+
sitp.itp[i...]
40+
end
41+
3742
lbounds(sitp::ScaledInterpolation) = _lbounds(sitp.ranges, itpflag(sitp.itp))
3843
ubounds(sitp::ScaledInterpolation) = _ubounds(sitp.ranges, itpflag(sitp.itp))
3944

4045
boundstep(r::StepRange) = r.step / 2
4146
boundstep(r::UnitRange) = 1//2
42-
43-
lbound(ax::AbstractRange, ::DegreeBC, ::OnCell) = first(ax) - boundstep(ax)
44-
ubound(ax::AbstractRange, ::DegreeBC, ::OnCell) = last(ax) + boundstep(ax)
45-
lbound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = first(ax)
46-
ubound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = last(ax)
47-
4847
"""
4948
Returns *half* the width of one step of the range.
5049
5150
This function is used to calculate the upper and lower bounds of `OnCell` interpolation objects.
5251
""" boundstep
5352

53+
lbound(ax::AbstractRange, ::DegreeBC, ::OnCell) = first(ax) - boundstep(ax)
54+
ubound(ax::AbstractRange, ::DegreeBC, ::OnCell) = last(ax) + boundstep(ax)
55+
lbound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = first(ax)
56+
ubound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = last(ax)
57+
58+
# For (), we scale the evaluation point
5459
function (sitp::ScaledInterpolation{T,N})(xs::Vararg{Number,N}) where {T,N}
5560
xl = coordslookup(itpflag(sitp.itp), sitp.ranges, xs)
5661
sitp.itp(xl...)
5762
end
63+
@inline function (sitp::ScaledInterpolation)(x::Vararg{UnexpandedIndexTypes})
64+
sitp(to_indices(sitp, x)...)
65+
end
5866

5967
(sitp::ScaledInterpolation{T,1}, x::Number, y::Int) where {T} = y == 1 ? sitp(x) : Base.throw_boundserror(sitp, (x, y))
6068

@@ -134,167 +142,85 @@ rescale_gradient(r::UnitRange, g) = g
134142
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects.
135143
""" rescale_gradient
136144

145+
### Iteration
137146

138-
# ### Iteration
139-
# mutable struct ScaledIterator{CR<:CartesianIndices,SITPT,X1,Deg,T}
140-
# rng::CR
141-
# sitp::SITPT
142-
# dx_1::X1
143-
# nremaining::Int
144-
# fx_1::X1
145-
# itp_tail::NTuple{Deg,T}
146-
# end
147-
148-
# nelements(::Union{Type{NoInterp},Type{Constant}}) = 1
149-
# nelements(::Type{Linear}) = 2
150-
# nelements(::Type{Q}) where {Q<:Quadratic} = 3
151-
152-
# eachvalue_zero(::Type{R}, ::Type{BT}) where {R,BT<:Union{Type{NoInterp},Type{Constant}}} =
153-
# (zero(R),)
154-
# eachvalue_zero(::Type{R}, ::Type{Linear}) where {R} = (zero(R),zero(R))
155-
# eachvalue_zero(::Type{R}, ::Type{Q}) where {R,Q<:Quadratic} = (zero(R),zero(R),zero(R))
156-
157-
# """
158-
# `eachvalue(sitp)` constructs an iterator for efficiently visiting each
159-
# grid point of a ScaledInterpolation object in which a small grid is
160-
# being "scaled up" to a larger one. For example, suppose you have a
161-
# core `BSpline` object defined on a 5x7x4 grid, and you are scaling it
162-
# to a 100x120x20 grid (via `linspace(1,5,100), linspace(1,7,120),
163-
# linspace(1,4,20)`). You can perform interpolation at each of these
164-
# grid points via
165-
166-
# ```
167-
# function foo!(dest, sitp)
168-
# i = 0
169-
# for s in eachvalue(sitp)
170-
# dest[i+=1] = s
171-
# end
172-
# dest
173-
# end
174-
# ```
175-
176-
# which should be more efficient than
177-
178-
# ```
179-
# function bar!(dest, sitp)
180-
# for I in CartesianIndices(size(dest))
181-
# dest[I] = sitp[I]
182-
# end
183-
# dest
184-
# end
185-
# ```
186-
# """
187-
# function eachvalue(sitp::ScaledInterpolation{T,N}) where {T,N}
188-
# ITPT = basetype(sitp)
189-
# IT = itptype(ITPT)
190-
# R = getindex_return_type(ITPT, Int)
191-
# BT = bsplinetype(iextract(IT, 1))
192-
# itp_tail = eachvalue_zero(R, BT)
193-
# dx_1 = coordlookup(sitp.ranges[1], 2) - coordlookup(sitp.ranges[1], 1)
194-
# ScaledIterator(CartesianIndices(ssize(sitp)), sitp, dx_1, 0, zero(dx_1), itp_tail)
195-
# end
147+
struct ScaledIterator{SITPT,CI,WIS}
148+
sitp::SITPT # ScaledInterpolation object
149+
ci::CI # the CartesianIndices object
150+
wis::WIS # WeightedIndex vectors
151+
breaks1::Vector{Int} # breaks along dimension 1 where new evaluations must occur
152+
end
196153

197-
# function index_gen1(::Union{Type{NoInterp}, Type{BSpline{Constant}}})
198-
# quote
199-
# value = iter.itp_tail[1]
200-
# end
201-
# end
154+
Base.IteratorSize(::Type{ScaledIterator{SITPT,CI,WIS}}) where {SITPT,CI<:CartesianIndices{N},WIS} where N = Base.HasShape{N}()
155+
Base.axes(iter::ScaledIterator) = axes(iter.ci)
156+
Base.size(iter::ScaledIterator) = size(iter.ci)
202157

203-
# function index_gen1(::Type{BSpline{Linear}})
204-
# quote
205-
# p = iter.itp_tail
206-
# value = c_1*p[1] + cp_1*p[2]
207-
# end
208-
# end
158+
struct ScaledIterState{N,V}
159+
cistate::CartesianIndex{N}
160+
ibreak::Int
161+
cached_evaluations::NTuple{N,V}
162+
end
209163

210-
# function index_gen1(::Type{BSpline{Q}}) where Q<:Quadratic
211-
# quote
212-
# p = iter.itp_tail
213-
# value = cm_1*p[1] + c_1*p[2] + cp_1*p[3]
214-
# end
215-
# end
216-
# function index_gen_tail(B::Union{Type{NoInterp}, Type{BSpline{Constant}}}, ::Type{IT}, N) where IT
217-
# [index_gen(B, IT, N, 0)]
218-
# end
164+
function eachvalue(sitp::ScaledInterpolation{T,N}) where {T,N}
165+
itps = tcollect(itpflag, sitp.itp)
166+
newaxes = map(r->Base.Slice(ceil(Int, first(r)):floor(Int, last(r))), sitp.ranges)
167+
wis = dimension_wis(value_weights, itps, axes(sitp.itp), newaxes, sitp.ranges)
168+
wis1 = wis[1]
169+
i1 = first(axes(wis1, 1))
170+
breaks1 = [i1]
171+
for i in Iterators.drop(axes(wis1, 1), 1)
172+
if indexes(wis1[i]) != indexes(wis1[i-1])
173+
push!(breaks1, i)
174+
end
175+
end
176+
push!(breaks1, last(axes(wis1, 1))+1)
177+
ScaledIterator(sitp, CartesianIndices(newaxes), wis, breaks1)
178+
end
219179

220-
# function index_gen_tail(::Type{BSpline{Linear}}, ::Type{IT}, N) where IT
221-
# [index_gen(BS1, IT, N, i) for i = 0:1]
222-
# end
180+
function dimension_wis(f::F, itps, axs, newaxes, ranges) where F
181+
itpflag, ax, nax, r = itps[1], axs[1], newaxes[1], ranges[1]
182+
function makewi(x)
183+
pos, coefs = weightedindex_parts((f,), itpflag, ax, coordlookup(r, x))
184+
maybe_weightedindex(pos, coefs[1])
185+
end
186+
(makewi.(nax), dimension_wis(f, Base.tail(itps), Base.tail(axs), Base.tail(newaxes), Base.tail(ranges))...)
187+
end
188+
dimension_wis(f, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
189+
190+
function Base.iterate(iter::ScaledIterator)
191+
ret = iterate(iter.ci)
192+
ret === nothing && return nothing
193+
item, cistate = ret
194+
wis = getindex.(iter.wis, Tuple(item))
195+
ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis))
196+
return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, first(iter.breaks1), ces)
197+
end
223198

224-
# function index_gen_tail(::Type{BSpline{Q}}, ::Type{IT}, N) where {IT,Q<:Quadratic}
225-
# [index_gen(BSpline{Q}, IT, N, i) for i = -1:1]
226-
# end
227-
# function nremaining_gen(::Union{Type{BSpline{Constant}}, Type{BSpline{Q}}}) where Q<:Quadratic
228-
# quote
229-
# EPS = 0.001*iter.dx_1
230-
# floor(Int, iter.dx_1 >= 0 ?
231-
# (min(length(range1)+EPS, round(Int,x_1) + 0.5) - x_1)/iter.dx_1 :
232-
# (max(1-EPS, round(Int,x_1) - 0.5) - x_1)/iter.dx_1)
233-
# end
234-
# end
199+
function Base.iterate(iter::ScaledIterator, state)
200+
ret = iterate(iter.ci, state.cistate)
201+
ret === nothing && return nothing
202+
item, cistate = ret
203+
i1 = item[1]
204+
isnext1 = i1 == state.cistate[1]+1
205+
if isnext1 && i1 < iter.breaks1[state.ibreak+1]
206+
# We can use the previously cached values
207+
wis1 = iter.wis[1][i1]
208+
return _reduce(+, weights(wis1).*state.cached_evaluations), ScaledIterState(cistate, state.ibreak, state.cached_evaluations)
209+
end
210+
# Re-evaluate. We're being a bit lazy here: in some cases, some of the cached values could be reused
211+
wis = getindex.(iter.wis, Tuple(item))
212+
ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis))
213+
return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, isnext1 ? state.ibreak+1 : first(iter.breaks1), ces)
214+
end
235215

236-
# function nremaining_gen(::Type{BSpline{Linear}})
237-
# quote
238-
# EPS = 0.001*iter.dx_1
239-
# floor(Int, iter.dx_1 >= 0 ?
240-
# (min(length(range1)+EPS, floor(Int,x_1) + 1) - x_1)/iter.dx_1 :
241-
# (max(1-EPS, floor(Int,x_1)) - x_1)/iter.dx_1)
242-
# end
243-
# end
244-
# function next_gen(::Type{ScaledIterator{CR,SITPT,X1,Deg,T}}) where {CR,SITPT,X1,Deg,T}
245-
# N = ndims(CR)
246-
# ITPT = basetype(SITPT)
247-
# IT = itptype(ITPT)
248-
# BS1 = iextract(IT, 1)
249-
# BS1 == NoInterp && error("eachvalue is not implemented (and does not make sense) for NoInterp along the first dimension")
250-
# pad = padding(ITPT)
251-
# x_syms = [Symbol("x_", i) for i = 1:N]
252-
# interp_index(IT, i) = iextract(IT, i) != NoInterp ?
253-
# :($(x_syms[i]) = coordlookup(sitp.ranges[$i], state[$i])) :
254-
# :($(x_syms[i]) = state[$i])
255-
# # Calculations for the first dimension
256-
# interp_index1 = interp_index(IT, 1)
257-
# indices1 = define_indices_d(BS1, 1, padextract(pad, 1))
258-
# coefexprs1 = coefficients(BS1, N, 1)
259-
# nremaining_expr = nremaining_gen(BS1)
260-
# # Calculations for the rest of the dimensions
261-
# interp_indices_tail = map(i -> interp_index(IT, i), 2:N)
262-
# indices_tail = [define_indices_d(iextract(IT, i), i, padextract(pad, i)) for i = 2:N]
263-
# coefexprs_tail = [coefficients(iextract(IT, i), N, i) for i = 2:N]
264-
# value_exprs_tail = index_gen_tail(BS1, IT, N)
265-
# quote
266-
# sitp = iter.sitp
267-
# itp = sitp.itp
268-
# inds_itp = axes(itp)
269-
# if iter.nremaining > 0
270-
# iter.nremaining -= 1
271-
# iter.fx_1 += iter.dx_1
272-
# else
273-
# range1 = sitp.ranges[1]
274-
# $interp_index1
275-
# $indices1
276-
# iter.nremaining = $nremaining_expr
277-
# iter.fx_1 = fx_1
278-
# $(interp_indices_tail...)
279-
# $(indices_tail...)
280-
# $(coefexprs_tail...)
281-
# @inbounds iter.itp_tail = ($(value_exprs_tail...),)
282-
# end
283-
# fx_1 = iter.fx_1
284-
# $coefexprs1
285-
# $(index_gen1(BS1))
286-
# end
287-
# end
216+
_reduce(op, list) = op(list[1], _reduce(op, Base.tail(list)))
217+
_reduce(op, list::Tuple{Number}) = list[1]
218+
_reduce(op, list::Tuple{}) = error("cannot reduce an empty list")
288219

289-
# @generated function iterate(iter::ScaledIterator{CR,ITPT}, state::Union{Nothing,CartesianIndex{N}} = nothing) where {CR,ITPT,N}
290-
# value_expr = next_gen(iter)
291-
# quote
292-
# rng_next = state ≡ nothing ? iterate(iter.rng) : iterate(iter.rng, state)
293-
# rng_next ≡ nothing && return nothing
294-
# state = rng_next[2]
295-
# $value_expr
296-
# (value, state)
297-
# end
298-
# end
220+
# We use weights only as a ruler to determine when we are done
221+
cache_evaluations(coefs, i::Int, weights, rest) = (coefs[i, rest...], cache_evaluations(coefs, i+1, Base.tail(weights), rest)...)
222+
cache_evaluations(coefs, indexes, weights, rest) = (coefs[indexes[1], rest...], cache_evaluations(coefs, Base.tail(indexes), Base.tail(weights), rest)...)
223+
cache_evaluations(coefs, ::Int, ::Tuple{}, rest) = ()
224+
cache_evaluations(coefs, ::Any, ::Tuple{}, rest) = ()
299225

300-
# ssize(sitp::ScaledInterpolation{T,N}) where {T,N} = map(r->round(Int, last(r)-first(r)+1), sitp.ranges)::NTuple{N,Int}
226+
ssize(sitp::ScaledInterpolation{T,N}) where {T,N} = map(r->round(Int, last(r)-first(r)+1), sitp.ranges)::NTuple{N,Int}

test/readme-examples.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ using Interpolations, Test
1818
@test v (0.9*(0.8*A[3,4]+0.2*A[4,4]) + 0.1*(0.8*A[3,5]+0.2*A[4,5]))
1919

2020
# Quadratic interpolation with reflecting boundary conditions
21-
# Quadratic is the lowest order that has continuous gradien
21+
# Quadratic is the lowest order that has continuous gradient
2222
itp = interpolate(A, BSpline(Quadratic(Reflect(OnCell()))))
2323

2424
# Linear interpolation in the first dimension, and no interpolation (just lookup) in the second
@@ -27,23 +27,23 @@ using Interpolations, Test
2727
@test v (0.35*A[3,5] + 0.65*A[4,5])
2828

2929

30-
# ## Scaled Bsplines
31-
# A_x = 1.:2.:40.
32-
# A = [log(x) for x in A_x]
33-
# itp = interpolate(A, BSpline(Cubic(Line())), OnGrid())
34-
# sitp = scale(itp, A_x)
35-
# @test sitp(3.) ≈ log(3.) # exactly log(3.)
36-
# @test sitp(3.5) ≈ log(3.5) atol=.1 # approximately log(3.5)
30+
## Scaled Bsplines
31+
A_x = 1.:2.:40.
32+
A = [log(x) for x in A_x]
33+
itp = interpolate(A, BSpline(Cubic(Line())), OnGrid())
34+
sitp = scale(itp, A_x)
35+
@test sitp(3.) log(3.) # exactly log(3.)
36+
@test sitp(3.5) log(3.5) atol=.1 # approximately log(3.5)
3737

3838
# For multidimensional uniformly spaced grids
3939
A_x1 = 1:.1:10
4040
A_x2 = 1:.5:20
4141
f(x1, x2) = log(x1+x2)
4242
A = [f(x1,x2) for x1 in A_x1, x2 in A_x2]
4343
itp = interpolate(A, BSpline(Cubic(Line(OnGrid()))))
44-
# sitp = scale(itp, A_x1, A_x2)
45-
# @test sitp(5., 10.) ≈ log(5 + 10) # exactly log(5 + 10)
46-
# @test sitp(5.6, 7.1) ≈ log(5.6 + 7.1) atol=.1 # approximately log(5.6 + 7.1)
44+
sitp = scale(itp, A_x1, A_x2)
45+
@test sitp(5., 10.) log(5 + 10) # exactly log(5 + 10)
46+
@test sitp(5.6, 7.1) log(5.6 + 7.1) atol=.1 # approximately log(5.6 + 7.1)
4747

4848
# ## Gridded interpolation
4949
# A = rand(8,20)

test/scaling/scaling.jl

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -50,45 +50,29 @@ using Test, LinearAlgebra
5050
sitp32 = @inferred scale(interpolate(Float32[testfunction(x,y) for x in -5:.5:5, y in -4:.2:4], BSpline(Quadratic(Flat(OnGrid())))), -5f0:.5f0:5f0, -4f0:.2f0:4f0)
5151
@test typeof(@inferred(sitp32(-3.4f0, 1.2f0))) == Float32
5252

53-
# # Iteration
54-
# itp = interpolate(rand(3,3,3), BSpline(Quadratic(Flat())), OnCell())
55-
# knots = map(d->1:10:21, 1:3)
56-
# sitp = @inferred scale(itp, knots...)
57-
58-
# iter = @inferred(eachvalue(sitp))
59-
60-
# iter_next = iterate(iter)
61-
# @test iter_next isa Tuple
62-
# @test iter_next[1] isa Float64
63-
# state = iter_next[2]
64-
# inferred_next = Base.return_types(iterate, (typeof(iter),))
65-
# @test length(inferred_next) == 1
66-
# @test inferred_next[1] == Union{Nothing,Tuple{Float64,typeof(state)}}
67-
# iter_next = iterate(iter, state)
68-
# @test iter_next isa Tuple
69-
# @test iter_next[1] isa Float64
70-
# inferred_next = Base.return_types(iterate, (typeof(iter),typeof(state)))
71-
# state = iter_next[2]
72-
# @test length(inferred_next) == 1
73-
# @test inferred_next[1] == Union{Nothing,Tuple{Float64,typeof(state)}}
74-
75-
# function foo!(dest, sitp)
76-
# i = 0
77-
# for s in eachvalue(sitp)
78-
# dest[i+=1] = s
79-
# end
80-
# dest
81-
# end
82-
# function bar!(dest, sitp)
83-
# for I in CartesianIndices(size(dest))
84-
# dest[I] = sitp[I]
85-
# end
86-
# dest
87-
# end
88-
# rfoo = Array{Float64}(undef, Interpolations.ssize(sitp))
89-
# rbar = similar(rfoo)
90-
# foo!(rfoo, sitp)
91-
# bar!(rbar, sitp)
92-
# @test rfoo ≈ rbar
93-
53+
# Iteration
54+
itp = interpolate(rand(3,3,3), BSpline(Quadratic(Flat(OnCell()))))
55+
knots = map(d->1:10:21, 1:3)
56+
sitp = @inferred scale(itp, knots...)
57+
58+
iter = @inferred(eachvalue(sitp))
59+
60+
function foo!(dest, sitp)
61+
i = 0
62+
for s in eachvalue(sitp)
63+
dest[i+=1] = s
64+
end
65+
dest
66+
end
67+
function bar!(dest, sitp)
68+
for I in CartesianIndices(size(dest))
69+
dest[I] = sitp(I)
70+
end
71+
dest
72+
end
73+
rfoo = Array{Float64}(undef, Interpolations.ssize(sitp))
74+
rbar = similar(rfoo)
75+
foo!(rfoo, sitp)
76+
bar!(rbar, sitp)
77+
@test rfoo rbar
9478
end

0 commit comments

Comments
 (0)