Skip to content

Commit 6ec8f1f

Browse files
N5N3mkitti
andauthored
Add native GPU support. (#504)
* Native GPU support. * Add some devdoc Co-Authored-By: Mark Kittisopikul <[email protected]> * Simplify `update_eltype` and add more test This should omit the `isempty` check if inference give good result. fix test on 1.6, `range(2, 19.0, 101)` was introduced in 1.7 * Fully drop the old indexing overload for `WeightedIndex`. and remove unneeded `@propagate_inbounds`. * Reenable non-generated `interp_getindex`. And fix the inference problem. Co-authored-by: Mark Kittisopikul <[email protected]>
1 parent 4c29ad6 commit 6ec8f1f

File tree

13 files changed

+276
-96
lines changed

13 files changed

+276
-96
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
33
version = "0.14.0"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
AxisAlgorithms = "13072b0f-2c55-5437-9ae7-d433b7a33950"
78
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -31,11 +32,12 @@ DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
3132
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3233
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3334
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
35+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3436
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3537
SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
3638
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3739
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3840
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3941

4042
[targets]
41-
test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Test", "Zygote", "ColorVectorSpace"]
43+
test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Pkg", "Test", "Zygote", "ColorVectorSpace"]

docs/src/devdocs.md

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ concrete subtypes:
5555
General `AbstractArray`s may be indexed with `WeightedIndex` indices,
5656
and the result produces the interpolated value. In other words, the end result
5757
is just `itp.coefs[wis...]`, where `wis` is a tuple of `WeightedIndex` indices.
58+
To make sure this overloading is effective, we wrap the `coefs` with `InterpGetindex`
59+
i.e. `InterpGetindex(itp.coefs)[wis...]`
5860

5961
Derivatives along a particular axis can be computed just by substituting a component of `wis` for one that has been designed to compute derivatives rather than values.
6062

@@ -107,7 +109,7 @@ Interpolations.WeightedAdjIndex{2, Float64}(1, (0.6000000000000001, 0.3999999999
107109
julia> wis[3]
108110
Interpolations.WeightedAdjIndex{2, Float64}(1, (0.30000000000000004, 0.7))
109111
110-
julia> A[wis...]
112+
julia> Interpolations.InterpGetindex(A)[wis...]
111113
8.7
112114
```
113115

@@ -125,7 +127,7 @@ This computed the value of `itp` at `x...` because we called `weightedindexes` w
125127

126128
!!! note
127129
Remember that prefiltering is not used for `Linear` interpolation.
128-
In a case where prefiltering is used, we would substitute `itp.coefs[wis...]` for `A[wis...]` above.
130+
In a case where prefiltering is used, we would substitute `InterpGetindex(itp.coefs)[wis...]` for `InterpGetindex(A)[wis...]` above.
129131

130132
To compute derivatives, we *also* pass additional functions like [`Interpolations.gradient_weights`](@ref):
131133

@@ -142,13 +144,13 @@ julia> wis[2]
142144
julia> wis[3]
143145
(Interpolations.WeightedAdjIndex{2,Float64}(1, (0.8, 0.19999999999999996)), Interpolations.WeightedAdjIndex{2,Float64}(1, (0.6000000000000001, 0.3999999999999999)), Interpolations.WeightedAdjIndex{2,Float64}(1, (-1.0, 1.0)))
144146
145-
julia> A[wis[1]...]
147+
julia> Interpolations.InterpGetindex(A)[wis[1]...]
146148
1.0
147149
148-
julia> A[wis[2]...]
150+
julia> Interpolations.InterpGetindex(A)[wis[2]...]
149151
3.000000000000001
150152
151-
julia> A[wis[3]...]
153+
julia> Interpolations.InterpGetindex(A)[wis[3]...]
152154
9.0
153155
```
154156
In this case you can see that `wis` is a 3-tuple-of-3-tuples. `A[wis[i]...]` can be used to compute the `i`th component of the gradient.
@@ -168,3 +170,29 @@ The code to do this replacement is a bit complicated due to the need to support
168170
It makes good use of *tuple* manipulations, sometimes called "lispy tuple programming."
169171
You can search Julia's discourse forum for more tips about how to program this way.
170172
It could alternatively be done using generated functions, but this would increase compile time considerably and can lead to world-age problems.
173+
174+
### GPU Support
175+
At present, `Interpolations.jl` supports interpolant usage on GPU via broadcasting.
176+
177+
A basic work flow looks like:
178+
```julia
179+
julia> using Interpolations, Adapt, CUDA # Or any other GPU package
180+
181+
julia > itp = Interpolations.interpolate([1, 2, 3], (BSpline(Linear()))); # construct the interpolant object on CPU
182+
183+
julia> cuitp = adapt(CuArray{Float32}, itp); # adapt it to GPU memory
184+
185+
julia > cuitp.(1:0.5:2) # call interpolant object via broadcast
186+
187+
julia> gradient.(Ref(cuitp), 1:0.5:2)
188+
```
189+
190+
To achieve this, an `ITP <: AbstractInterpolation` should define it's own `Adapt.adapt_structure(to, itp::ITP)`, which constructs a new `ITP` with the adapted
191+
fields (`adapt(to, itp.fieldname)`) of `itp`. The field adaption could be skipped
192+
if we know that it has been GPU-compatable, e.g. a `isbit` range.
193+
194+
!!! note
195+
Some adaptors may change the storage type. Please ensure that the adapted `itp`
196+
has the correct element type via the method `eltype`.
197+
198+
Also, all GPU-compatable `AbstractInterpolation`s should define their own `Interpolations.root_storage_type`. This function allows us to modify the broadcast mechanism by overloading the default `BroadcastStyle`. See [Customizing broadcasting](https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting) for more details.

src/Interpolations.jl

Lines changed: 29 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ using StaticArrays, WoodburyMatrices, Ratios, AxisAlgorithms, OffsetArrays
4141
using ChainRulesCore, Requires
4242

4343
using Base: @propagate_inbounds, HasEltype, EltypeUnknown, HasLength, IsInfinite,
44-
SizeUnknown
44+
SizeUnknown, Indices
4545
import Base: convert, size, axes, promote_rule, ndims, eltype, checkbounds, axes1,
4646
iterate, length, IteratorEltype, IteratorSize, firstindex, getindex, LogicalIndex
4747

@@ -269,76 +269,32 @@ Base.:(/)(wi::WeightedArbIndex, x::Number) = WeightedArbIndex(wi.indexes, wi.wei
269269

270270
### Indexing with WeightedIndex
271271

272-
# We inject indexing with `WeightedIndex` at a non-exported point in the dispatch heirarchy.
273-
# This is to avoid ambiguities with methods that specialize on the array type rather than
274-
# the index type.
275-
Base.to_indices(A, I::Tuple{Vararg{Union{Int,WeightedIndex}}}) = I
276-
if VERSION < v"1.6.0-DEV.104"
277-
@propagate_inbounds Base._getindex(::IndexLinear, A::AbstractVector, i::Int) = getindex(A, i) # ambiguity resolution
272+
# We inject `WeightedIndex` as a non-exported indexing point with a `InterpGetindex` wrapper.
273+
# `InterpGetindex` is not a subtype of `AbstractArray`. This ensures that the overload applies to all array types.
274+
struct InterpGetindex{N,A<:AbstractArray{<:Any,N}}
275+
coeffs::A
276+
InterpGetindex(itp::AbstractInterpolation) = InterpGetindex(coefficients(itp))
277+
InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A)
278278
end
279-
@inline function Base._getindex(::IndexStyle, A::AbstractArray{T,N}, I::Vararg{Union{Int,WeightedIndex},N}) where {T,N}
280-
interp_getindex(A, I, ntuple(d->0, Val(N))...)
281-
end
282-
283-
# The non-generated version is currently disabled due to https://github.com/JuliaLang/julia/issues/29117
284-
# # This follows a "move processed indexes to the back" strategy, so J contains the yet-to-be-processed
285-
# # indexes and I all the processed indexes.
286-
# interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M} =
287-
# interp_getindex(A, Base.tail(J), I..., J[1])
288-
# function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedIndex,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M}
289-
# wi = J[1]
290-
# interp_getindex1(A, indexes(wi), weights(wi), Base.tail(J), I...)
291-
# end
292-
# interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination
293-
# @inbounds A[I...] # all bounds-checks have already happened
294-
#
295-
# ## Handle expansion of a single dimension
296-
# # version for WeightedAdjIndex
297-
# @inline interp_getindex1(A, i::Int, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} =
298-
# weights[1] * interp_getindex(A, rest, I..., i) + interp_getindex1(A, i+1, Base.tail(weights), rest, I...)
299-
# @inline interp_getindex1(A, i::Int, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M =
300-
# weights[1] * interp_getindex(A, rest, I..., i)
301-
# interp_getindex1(A, i::Int, weights::Tuple{}, rest, I::Vararg{Int,M}) where M =
302-
# error("exhausted the weights, this should never happen")
303-
#
304-
# # version for WeightedArbIndex
305-
# @inline interp_getindex1(A, indexes::NTuple{K,Int}, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} =
306-
# weights[1] * interp_getindex(A, rest, I..., indexes[1]) + interp_getindex1(A, Base.tail(indexes), Base.tail(weights), rest, I...)
307-
# @inline interp_getindex1(A, indexes::Tuple{Int}, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M =
308-
# weights[1] * interp_getindex(A, rest, I..., indexes[1])
309-
# interp_getindex1(A, indexes::Tuple{}, weights::Tuple{}, rest, I::Vararg{Int,M}) where M =
310-
# error("exhausted the weights and indexes, this should never happen")
311-
312-
@inline interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K} =
313-
interp_getindex(A, Base.tail(J), Base.tail(I)..., J[1])
314-
@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedAdjIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W}
315-
ex = :(w[1]*interp_getindex(A, Jtail, Itail..., j))
316-
for l = 2:L
317-
ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., j+$(l-1)) + $ex)
318-
end
319-
quote
320-
$(Expr(:meta, :inline))
321-
Jtail = Base.tail(J)
322-
Itail = Base.tail(I)
323-
j, w = J[1].istart, J[1].weights
324-
$ex
325-
end
326-
end
327-
@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedArbIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W}
328-
ex = :(w[1]*interp_getindex(A, Jtail, Itail..., ij[1]))
329-
for l = 2:L
330-
ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., ij[$l]) + $ex)
331-
end
332-
quote
333-
$(Expr(:meta, :inline))
334-
Jtail = Base.tail(J)
335-
Itail = Base.tail(I)
336-
ij, w = J[1].indexes, J[1].weights
337-
$ex
338-
end
339-
end
340-
@inline interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination
341-
@inbounds A[I...] # all bounds-checks have already happened
279+
@inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} =
280+
interp_getindex(A.coeffs, ntuple(_ -> 0, Val(N)), map(indexflag, I)...)
281+
indexflag(I::Int) = I
282+
@inline indexflag(I::WeightedIndex) = indextuple(I), weights(I)
283+
284+
# A recursion-based `interp_getindex`, which follows a "move processed indexes to the back" strategy
285+
# `I` contains the processed index, and (wi1, wis...) contains the yet-to-be-processed indexes
286+
# Here we meet a no-interp dim, just append the index to `I`'s end.
287+
@inline interp_getindex(A, I, wi1::Int, wis...) =
288+
interp_getindex(A, (Base.tail(I)..., wi1), wis...)
289+
# Here we handle the expansion of a single dimension.
290+
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any,Vararg{Any,N}}}, wis...) where {N} =
291+
wi1[2][end] * interp_getindex(A, (Base.tail(I)..., wi1[1][end]), wis...) +
292+
interp_getindex(A, I, map(Base.front, wi1), wis...)
293+
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any}}, wis...) =
294+
wi1[2][1] * interp_getindex(A, (Base.tail(I)..., wi1[1][1]), wis...)
295+
# Termination
296+
@inline interp_getindex(A::AbstractArray{T,N}, I::Dims{N}) where {T,N} =
297+
@inbounds A[I...] # all bounds-checks have already happened
342298

343299
"""
344300
w = value_weights(degree, δx)
@@ -489,6 +445,9 @@ include("lanczos/lanczos_opencv.jl")
489445
include("iterate.jl")
490446
include("chainrules/chainrules.jl")
491447
include("hermite/cubic.jl")
448+
if VERSION >= v"1.6"
449+
include("gpu_support.jl")
450+
end
492451

493452
function __init__()
494453
@require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" include("requires/unitful.jl")

src/b-splines/b-splines.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ However, for customized control you may also construct them with
6666
where `T` gets computed from the product of `TWeights` and `eltype(coefs)`.
6767
(This is equivalent to indicating that you'll be evaluating at locations `itp(x::TWeights, y::TWeights, ...)`.)
6868
"""
69-
struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Tuple{Vararg{AbstractUnitRange,N}}} <: AbstractInterpolation{T,N,IT}
69+
struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Indices{N}} <: AbstractInterpolation{T,N,IT}
7070
coefs::TCoefs
7171
parentaxes::Axs
7272
it::IT
@@ -78,6 +78,9 @@ function Base.:(==)(o1::BSplineInterpolation, o2::BSplineInterpolation)
7878
o1.coefs == o2.coefs
7979
end
8080

81+
BSplineInterpolation{T,N}(A::AbstractArray, axs::Indices{N}, it::IT) where {T,N,IT} =
82+
BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, axs, it)
83+
8184
function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,TWeights<:Real,IT<:DimSpec{BSpline}}
8285
# String interpolation causes allocation, noinline avoids that unless they get called
8386
@noinline err_concrete(IT) = error("The b-spline type must be a concrete type (was $IT)")
@@ -98,7 +101,7 @@ function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT,
98101
else
99102
T = typeof(zero(TWeights) * first(A))
100103
end
101-
BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, fix_axis.(axs), it)
104+
BSplineInterpolation{T,N}(A, fix_axis.(axs), it)
102105
end
103106

104107
function BSplineInterpolation(A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,IT<:DimSpec{BSpline}}

src/b-splines/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ itpinfo(itp) = (tcollect(itpflag, itp), axes(itp))
55
@inline function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,N}) where {T,N}
66
@boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x))
77
wis = weightedindexes((value_weights,), itpinfo(itp)..., x)
8-
itp.coefs[wis...]
8+
InterpGetindex(itp)[wis...]
99
end
1010
@propagate_inbounds function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N}
1111
inds, trailing = split_trailing(itp, x)
@@ -17,15 +17,15 @@ end
1717
@boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x))
1818
itps = tcollect(itpflag, itp)
1919
wis = dimension_wis(value_weights, itps, axes(itp), x)
20-
coefs = coefficients(itp)
20+
coefs = InterpGetindex(itp)
2121
ret = [coefs[i...] for i in Iterators.product(wis...)]
2222
reshape(ret, shape(wis...))
2323
end
2424

2525
@propagate_inbounds function gradient(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
2626
@boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)
2727
wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x)
28-
return SVector(_gradient(itp.coefs, wis...)) # work around #311
28+
return SVector(_gradient(InterpGetindex(itp), wis...)) # work around #311
2929
end
3030
@inline _gradient(coefs, inds, moreinds...) = (coefs[inds...], _gradient(coefs, moreinds...)...)
3131
_gradient(coefs) = ()
@@ -37,7 +37,7 @@ end
3737
@propagate_inbounds function hessian(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
3838
@boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)
3939
wis = weightedindexes((value_weights, gradient_weights, hessian_weights), itpinfo(itp)..., x)
40-
symmatrix(map(inds->itp.coefs[inds...], wis))
40+
symmatrix(map(inds->InterpGetindex(itp)[inds...], wis))
4141
end
4242
@propagate_inbounds function hessian!(dest, itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
4343
dest .= hessian(itp, x...)

src/extrapolation/extrapolation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ end
7171
eflag = tcollect(etpflag, etp)
7272
xs = inbounds_position(eflag, bounds(itp), x, etp, x)
7373
g = @inbounds gradient(itp, xs...)
74-
skipni = t->skip_flagged_nointerp(itp, t)
75-
SVector(extrapolate_gradient.(skipni(eflag), skipni(x), skipni(xs), Tuple(g)))
74+
skipni = Base.Fix1(skip_flagged_nointerp, itp)
75+
SVector(map(extrapolate_gradient, skipni(eflag), skipni(x), skipni(xs), Tuple(g)))
7676
end
7777
end
7878

src/gpu_support.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import Adapt: adapt_structure
2+
using Adapt: adapt
3+
4+
function adapt_structure(to, itp::BSplineInterpolation{T,N}) where {T,N}
5+
coefs′ = adapt(to, itp.coefs)
6+
T′ = update_eltype(T, coefs′, itp.coefs)
7+
BSplineInterpolation{T′,N}(coefs′, itp.parentaxes, itp.it)
8+
end
9+
10+
function update_eltype(T, coefs′, coefs)
11+
ET = eltype(coefs′)
12+
ET === eltype(coefs) && return T
13+
WT = tweight(coefs′)
14+
T′ = Base.promote_op(*, WT, ET)
15+
(isconcretetype(T′) || isempty(coefs)) && return T′
16+
return typeof(zero(WT) * convert(ET, first(coefs)))
17+
end
18+
19+
function adapt_structure(to, itp::LanczosInterpolation{T,N}) where {T,N}
20+
coefs′ = adapt(to, itp.coefs)
21+
parentaxes′ = adapt(to, itp.parentaxes)
22+
LanczosInterpolation{eltype(coefs′),N}(coefs′, parentaxes′, itp.it)
23+
end
24+
25+
function adapt_structure(to, itp::GriddedInterpolation{T,N}) where {T,N}
26+
coefs′ = adapt(to, itp.coefs)
27+
knots′ = adapt(to, itp.knots)
28+
T′ = update_eltype(T, coefs′, itp.coefs)
29+
GriddedInterpolation{T′,N,typeof(coefs′),itptype(itp),typeof(knots′)}(knots′, coefs′, itp.it)
30+
end
31+
32+
function adapt_structure(to, itp::ScaledInterpolation{T,N,<:Any,IT,RT}) where {T,N,IT,RT<:NTuple{N,AbstractRange}}
33+
ranges = itp.ranges
34+
itp′ = adapt(to, itp.itp)
35+
ScaledInterpolation{eltype(itp′),N,typeof(itp′),IT,RT}(itp′, ranges)
36+
end
37+
38+
function adapt_structure(to, itp::Extrapolation{T,N}) where {T,N}
39+
et = itp.et
40+
itp′ = adapt(to, itp.itp)
41+
Extrapolation{eltype(itp′),N,typeof(itp′),itptype(itp),typeof(et)}(itp′, et)
42+
end
43+
44+
import Base.Broadcast: broadcasted, BroadcastStyle
45+
using Base.Broadcast: broadcastable, combine_styles, AbstractArrayStyle
46+
function broadcasted(itp::AbstractInterpolation, args...)
47+
args′ = map(broadcastable, args)
48+
# we overload BroadcastStyle here (try our best to do broadcast on GPU)
49+
style = combine_styles(Ref(itp), args′...)
50+
broadcasted(style, itp, args′...)
51+
end
52+
53+
"""
54+
Interpolations.root_storage_type(::Type{<:AbstractInterpolation}) -> Type{<:AbstractArray}
55+
56+
This function returns the type of the root cofficients array of an `AbstractInterpolation`.
57+
Some array wrappers, like `OffsetArray`, should be skipped.
58+
"""
59+
root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1))
60+
root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1))
61+
root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1))
62+
root_storage_type(::Type{T}) where {T<:LanczosInterpolation} = root_storage_type(fieldtype(T, 1))
63+
root_storage_type(::Type{T}) where {T<:GriddedInterpolation} = root_storage_type(fieldtype(T, 2))
64+
root_storage_type(::Type{T}) where {T<:OffsetArray} = root_storage_type(fieldtype(T, 1))
65+
root_storage_type(::Type{T}) where {T<:AbstractArray} = T
66+
67+
BroadcastStyle(::Type{<:Ref{T}}) where {T<:AbstractInterpolation} = _to_scalar_style(BroadcastStyle(T))
68+
BroadcastStyle(::Type{T}) where {T<:AbstractInterpolation} = BroadcastStyle(root_storage_type(T))
69+
70+
_to_scalar_style(::S) where {S<:AbstractArrayStyle} = S(Val(0))
71+
_to_scalar_style(S::AbstractArrayStyle{Any}) = S
72+
_to_scalar_style(S) = S

src/gridded/indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
@inline function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,N}) where {T,N}
33
@boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x))
44
wis = weightedindexes((value_weights,), itpinfo(itp)..., x)
5-
coefficients(itp)[wis...]
5+
InterpGetindex(itp)[wis...]
66
end
77
@propagate_inbounds function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N}
88
inds, trailing = split_trailing(itp, x)
@@ -19,7 +19,7 @@ end
1919
@inline function gradient(itp::GriddedInterpolation{T,N}, x::Vararg{Number,N}) where {T,N}
2020
@boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x))
2121
wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x)
22-
SVector(map(inds->coefficients(itp)[inds...], wis))
22+
SVector(map(inds->InterpGetindex(itp)[inds...], wis))
2323
end
2424

2525
itpinfo(itp::GriddedInterpolation) = (tcollect(itpflag, itp), itp.knots)
@@ -87,7 +87,7 @@ rescale_gridded(::typeof(hessian_weights), coefs, Δx) = coefs./Δx.^2
8787
else
8888
wis = dimension_wis(value_weights, itps, itp.knots, x)
8989
end
90-
coefs = coefficients(itp)
90+
coefs = InterpGetindex(itp)
9191
ret = [coefs[i...] for i in Iterators.product(wis...)]
9292
reshape(ret, shape(wis...))
9393
end

0 commit comments

Comments
 (0)