@@ -31,30 +31,38 @@ check_ranges(::Any, ::Tuple{}, ::Tuple{}) = nothing
31
31
check_range (:: NoInterp , ax, r) = ax == r || throw (ArgumentError (" The range $r did not equal the corresponding axis of the interpolation object $ax " ))
32
32
check_range (:: Any , ax, r) = length (ax) == length (r) || throw (ArgumentError (" The range $r is incommensurate with the corresponding axis $ax " ))
33
33
34
+ # With regards to size and [], ScaledInterpolation behaves like the underlying interpolation object
34
35
size (sitp:: ScaledInterpolation ) = size (sitp. itp)
35
36
axes (sitp:: ScaledInterpolation ) = axes (sitp. itp)
36
37
38
+ @propagate_inbounds function Base. getindex (sitp:: ScaledInterpolation{T,N} , i:: Vararg{Int,N} ) where {T,N}
39
+ sitp. itp[i... ]
40
+ end
41
+
37
42
lbounds (sitp:: ScaledInterpolation ) = _lbounds (sitp. ranges, itpflag (sitp. itp))
38
43
ubounds (sitp:: ScaledInterpolation ) = _ubounds (sitp. ranges, itpflag (sitp. itp))
39
44
40
45
boundstep (r:: StepRange ) = r. step / 2
41
46
boundstep (r:: UnitRange ) = 1 // 2
42
-
43
- lbound (ax:: AbstractRange , :: DegreeBC , :: OnCell ) = first (ax) - boundstep (ax)
44
- ubound (ax:: AbstractRange , :: DegreeBC , :: OnCell ) = last (ax) + boundstep (ax)
45
- lbound (ax:: AbstractRange , :: DegreeBC , :: OnGrid ) = first (ax)
46
- ubound (ax:: AbstractRange , :: DegreeBC , :: OnGrid ) = last (ax)
47
-
48
47
"""
49
48
Returns *half* the width of one step of the range.
50
49
51
50
This function is used to calculate the upper and lower bounds of `OnCell` interpolation objects.
52
51
""" boundstep
53
52
53
+ lbound (ax:: AbstractRange , :: DegreeBC , :: OnCell ) = first (ax) - boundstep (ax)
54
+ ubound (ax:: AbstractRange , :: DegreeBC , :: OnCell ) = last (ax) + boundstep (ax)
55
+ lbound (ax:: AbstractRange , :: DegreeBC , :: OnGrid ) = first (ax)
56
+ ubound (ax:: AbstractRange , :: DegreeBC , :: OnGrid ) = last (ax)
57
+
58
+ # For (), we scale the evaluation point
54
59
function (sitp:: ScaledInterpolation{T,N} )(xs:: Vararg{Number,N} ) where {T,N}
55
60
xl = coordslookup (itpflag (sitp. itp), sitp. ranges, xs)
56
61
sitp. itp (xl... )
57
62
end
63
+ @inline function (sitp:: ScaledInterpolation )(x:: Vararg{UnexpandedIndexTypes} )
64
+ sitp (to_indices (sitp, x)... )
65
+ end
58
66
59
67
(sitp:: ScaledInterpolation{T,1} , x:: Number , y:: Int ) where {T} = y == 1 ? sitp (x) : Base. throw_boundserror (sitp, (x, y))
60
68
@@ -134,167 +142,85 @@ rescale_gradient(r::UnitRange, g) = g
134
142
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects.
135
143
""" rescale_gradient
136
144
145
+ # ## Iteration
137
146
138
- # ### Iteration
139
- # mutable struct ScaledIterator{CR<:CartesianIndices,SITPT,X1,Deg,T}
140
- # rng::CR
141
- # sitp::SITPT
142
- # dx_1::X1
143
- # nremaining::Int
144
- # fx_1::X1
145
- # itp_tail::NTuple{Deg,T}
146
- # end
147
-
148
- # nelements(::Union{Type{NoInterp},Type{Constant}}) = 1
149
- # nelements(::Type{Linear}) = 2
150
- # nelements(::Type{Q}) where {Q<:Quadratic} = 3
151
-
152
- # eachvalue_zero(::Type{R}, ::Type{BT}) where {R,BT<:Union{Type{NoInterp},Type{Constant}}} =
153
- # (zero(R),)
154
- # eachvalue_zero(::Type{R}, ::Type{Linear}) where {R} = (zero(R),zero(R))
155
- # eachvalue_zero(::Type{R}, ::Type{Q}) where {R,Q<:Quadratic} = (zero(R),zero(R),zero(R))
156
-
157
- # """
158
- # `eachvalue(sitp)` constructs an iterator for efficiently visiting each
159
- # grid point of a ScaledInterpolation object in which a small grid is
160
- # being "scaled up" to a larger one. For example, suppose you have a
161
- # core `BSpline` object defined on a 5x7x4 grid, and you are scaling it
162
- # to a 100x120x20 grid (via `linspace(1,5,100), linspace(1,7,120),
163
- # linspace(1,4,20)`). You can perform interpolation at each of these
164
- # grid points via
165
-
166
- # ```
167
- # function foo!(dest, sitp)
168
- # i = 0
169
- # for s in eachvalue(sitp)
170
- # dest[i+=1] = s
171
- # end
172
- # dest
173
- # end
174
- # ```
175
-
176
- # which should be more efficient than
177
-
178
- # ```
179
- # function bar!(dest, sitp)
180
- # for I in CartesianIndices(size(dest))
181
- # dest[I] = sitp[I]
182
- # end
183
- # dest
184
- # end
185
- # ```
186
- # """
187
- # function eachvalue(sitp::ScaledInterpolation{T,N}) where {T,N}
188
- # ITPT = basetype(sitp)
189
- # IT = itptype(ITPT)
190
- # R = getindex_return_type(ITPT, Int)
191
- # BT = bsplinetype(iextract(IT, 1))
192
- # itp_tail = eachvalue_zero(R, BT)
193
- # dx_1 = coordlookup(sitp.ranges[1], 2) - coordlookup(sitp.ranges[1], 1)
194
- # ScaledIterator(CartesianIndices(ssize(sitp)), sitp, dx_1, 0, zero(dx_1), itp_tail)
195
- # end
147
+ struct ScaledIterator{SITPT,CI,WIS}
148
+ sitp:: SITPT # ScaledInterpolation object
149
+ ci:: CI # the CartesianIndices object
150
+ wis:: WIS # WeightedIndex vectors
151
+ breaks1:: Vector{Int} # breaks along dimension 1 where new evaluations must occur
152
+ end
196
153
197
- # function index_gen1(::Union{Type{NoInterp}, Type{BSpline{Constant}}})
198
- # quote
199
- # value = iter.itp_tail[1]
200
- # end
201
- # end
154
+ Base. IteratorSize (:: Type{ScaledIterator{SITPT,CI,WIS}} ) where {SITPT,CI<: CartesianIndices{N} ,WIS} where N = Base. HasShape {N} ()
155
+ Base. axes (iter:: ScaledIterator ) = axes (iter. ci)
156
+ Base. size (iter:: ScaledIterator ) = size (iter. ci)
202
157
203
- # function index_gen1(::Type{BSpline{Linear}})
204
- # quote
205
- # p = iter.itp_tail
206
- # value = c_1*p[1] + cp_1*p[2]
207
- # end
208
- # end
158
+ struct ScaledIterState{N,V}
159
+ cistate:: CartesianIndex{N}
160
+ ibreak:: Int
161
+ cached_evaluations:: NTuple{N,V}
162
+ end
209
163
210
- # function index_gen1(::Type{BSpline{Q}}) where Q<:Quadratic
211
- # quote
212
- # p = iter.itp_tail
213
- # value = cm_1*p[1] + c_1*p[2] + cp_1*p[3]
214
- # end
215
- # end
216
- # function index_gen_tail(B::Union{Type{NoInterp}, Type{BSpline{Constant}}}, ::Type{IT}, N) where IT
217
- # [index_gen(B, IT, N, 0)]
218
- # end
164
+ function eachvalue (sitp:: ScaledInterpolation{T,N} ) where {T,N}
165
+ itps = tcollect (itpflag, sitp. itp)
166
+ newaxes = map (r-> Base. Slice (ceil (Int, first (r)): floor (Int, last (r))), sitp. ranges)
167
+ wis = dimension_wis (value_weights, itps, axes (sitp. itp), newaxes, sitp. ranges)
168
+ wis1 = wis[1 ]
169
+ i1 = first (axes (wis1, 1 ))
170
+ breaks1 = [i1]
171
+ for i in Iterators. drop (axes (wis1, 1 ), 1 )
172
+ if indexes (wis1[i]) != indexes (wis1[i- 1 ])
173
+ push! (breaks1, i)
174
+ end
175
+ end
176
+ push! (breaks1, last (axes (wis1, 1 ))+ 1 )
177
+ ScaledIterator (sitp, CartesianIndices (newaxes), wis, breaks1)
178
+ end
219
179
220
- # function index_gen_tail(::Type{BSpline{Linear}}, ::Type{IT}, N) where IT
221
- # [index_gen(BS1, IT, N, i) for i = 0:1]
222
- # end
180
+ function dimension_wis (f:: F , itps, axs, newaxes, ranges) where F
181
+ itpflag, ax, nax, r = itps[1 ], axs[1 ], newaxes[1 ], ranges[1 ]
182
+ function makewi (x)
183
+ pos, coefs = weightedindex_parts ((f,), itpflag, ax, coordlookup (r, x))
184
+ maybe_weightedindex (pos, coefs[1 ])
185
+ end
186
+ (makewi .(nax), dimension_wis (f, Base. tail (itps), Base. tail (axs), Base. tail (newaxes), Base. tail (ranges))... )
187
+ end
188
+ dimension_wis (f, :: Tuple{} , :: Tuple{} , :: Tuple{} , :: Tuple{} ) = ()
189
+
190
+ function Base. iterate (iter:: ScaledIterator )
191
+ ret = iterate (iter. ci)
192
+ ret === nothing && return nothing
193
+ item, cistate = ret
194
+ wis = getindex .(iter. wis, Tuple (item))
195
+ ces = cache_evaluations (iter. sitp. itp. coefs, indexes (wis[1 ]), weights (wis[1 ]), Base. tail (wis))
196
+ return _reduce (+ , weights (wis[1 ]).* ces), ScaledIterState (cistate, first (iter. breaks1), ces)
197
+ end
223
198
224
- # function index_gen_tail(::Type{BSpline{Q}}, ::Type{IT}, N) where {IT,Q<:Quadratic}
225
- # [index_gen(BSpline{Q}, IT, N, i) for i = -1:1]
226
- # end
227
- # function nremaining_gen(::Union{Type{BSpline{Constant}}, Type{BSpline{Q}}}) where Q<:Quadratic
228
- # quote
229
- # EPS = 0.001*iter.dx_1
230
- # floor(Int, iter.dx_1 >= 0 ?
231
- # (min(length(range1)+EPS, round(Int,x_1) + 0.5) - x_1)/iter.dx_1 :
232
- # (max(1-EPS, round(Int,x_1) - 0.5) - x_1)/iter.dx_1)
233
- # end
234
- # end
199
+ function Base. iterate (iter:: ScaledIterator , state)
200
+ ret = iterate (iter. ci, state. cistate)
201
+ ret === nothing && return nothing
202
+ item, cistate = ret
203
+ i1 = item[1 ]
204
+ isnext1 = i1 == state. cistate[1 ]+ 1
205
+ if isnext1 && i1 < iter. breaks1[state. ibreak+ 1 ]
206
+ # We can use the previously cached values
207
+ wis1 = iter. wis[1 ][i1]
208
+ return _reduce (+ , weights (wis1).* state. cached_evaluations), ScaledIterState (cistate, state. ibreak, state. cached_evaluations)
209
+ end
210
+ # Re-evaluate. We're being a bit lazy here: in some cases, some of the cached values could be reused
211
+ wis = getindex .(iter. wis, Tuple (item))
212
+ ces = cache_evaluations (iter. sitp. itp. coefs, indexes (wis[1 ]), weights (wis[1 ]), Base. tail (wis))
213
+ return _reduce (+ , weights (wis[1 ]).* ces), ScaledIterState (cistate, isnext1 ? state. ibreak+ 1 : first (iter. breaks1), ces)
214
+ end
235
215
236
- # function nremaining_gen(::Type{BSpline{Linear}})
237
- # quote
238
- # EPS = 0.001*iter.dx_1
239
- # floor(Int, iter.dx_1 >= 0 ?
240
- # (min(length(range1)+EPS, floor(Int,x_1) + 1) - x_1)/iter.dx_1 :
241
- # (max(1-EPS, floor(Int,x_1)) - x_1)/iter.dx_1)
242
- # end
243
- # end
244
- # function next_gen(::Type{ScaledIterator{CR,SITPT,X1,Deg,T}}) where {CR,SITPT,X1,Deg,T}
245
- # N = ndims(CR)
246
- # ITPT = basetype(SITPT)
247
- # IT = itptype(ITPT)
248
- # BS1 = iextract(IT, 1)
249
- # BS1 == NoInterp && error("eachvalue is not implemented (and does not make sense) for NoInterp along the first dimension")
250
- # pad = padding(ITPT)
251
- # x_syms = [Symbol("x_", i) for i = 1:N]
252
- # interp_index(IT, i) = iextract(IT, i) != NoInterp ?
253
- # :($(x_syms[i]) = coordlookup(sitp.ranges[$i], state[$i])) :
254
- # :($(x_syms[i]) = state[$i])
255
- # # Calculations for the first dimension
256
- # interp_index1 = interp_index(IT, 1)
257
- # indices1 = define_indices_d(BS1, 1, padextract(pad, 1))
258
- # coefexprs1 = coefficients(BS1, N, 1)
259
- # nremaining_expr = nremaining_gen(BS1)
260
- # # Calculations for the rest of the dimensions
261
- # interp_indices_tail = map(i -> interp_index(IT, i), 2:N)
262
- # indices_tail = [define_indices_d(iextract(IT, i), i, padextract(pad, i)) for i = 2:N]
263
- # coefexprs_tail = [coefficients(iextract(IT, i), N, i) for i = 2:N]
264
- # value_exprs_tail = index_gen_tail(BS1, IT, N)
265
- # quote
266
- # sitp = iter.sitp
267
- # itp = sitp.itp
268
- # inds_itp = axes(itp)
269
- # if iter.nremaining > 0
270
- # iter.nremaining -= 1
271
- # iter.fx_1 += iter.dx_1
272
- # else
273
- # range1 = sitp.ranges[1]
274
- # $interp_index1
275
- # $indices1
276
- # iter.nremaining = $nremaining_expr
277
- # iter.fx_1 = fx_1
278
- # $(interp_indices_tail...)
279
- # $(indices_tail...)
280
- # $(coefexprs_tail...)
281
- # @inbounds iter.itp_tail = ($(value_exprs_tail...),)
282
- # end
283
- # fx_1 = iter.fx_1
284
- # $coefexprs1
285
- # $(index_gen1(BS1))
286
- # end
287
- # end
216
+ _reduce (op, list) = op (list[1 ], _reduce (op, Base. tail (list)))
217
+ _reduce (op, list:: Tuple{Number} ) = list[1 ]
218
+ _reduce (op, list:: Tuple{} ) = error (" cannot reduce an empty list" )
288
219
289
- # @generated function iterate(iter::ScaledIterator{CR,ITPT}, state::Union{Nothing,CartesianIndex{N}} = nothing) where {CR,ITPT,N}
290
- # value_expr = next_gen(iter)
291
- # quote
292
- # rng_next = state ≡ nothing ? iterate(iter.rng) : iterate(iter.rng, state)
293
- # rng_next ≡ nothing && return nothing
294
- # state = rng_next[2]
295
- # $value_expr
296
- # (value, state)
297
- # end
298
- # end
220
+ # We use weights only as a ruler to determine when we are done
221
+ cache_evaluations (coefs, i:: Int , weights, rest) = (coefs[i, rest... ], cache_evaluations (coefs, i+ 1 , Base. tail (weights), rest)... )
222
+ cache_evaluations (coefs, indexes, weights, rest) = (coefs[indexes[1 ], rest... ], cache_evaluations (coefs, Base. tail (indexes), Base. tail (weights), rest)... )
223
+ cache_evaluations (coefs, :: Int , :: Tuple{} , rest) = ()
224
+ cache_evaluations (coefs, :: Any , :: Tuple{} , rest) = ()
299
225
300
- # ssize(sitp::ScaledInterpolation{T,N}) where {T,N} = map(r->round(Int, last(r)-first(r)+1), sitp.ranges)::NTuple{N,Int}
226
+ ssize (sitp:: ScaledInterpolation{T,N} ) where {T,N} = map (r-> round (Int, last (r)- first (r)+ 1 ), sitp. ranges):: NTuple{N,Int}
0 commit comments