From 2cdba1db246ab550295e6d27a9d9370ca3281ffe Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 27 Jul 2022 16:02:13 +0800 Subject: [PATCH 1/5] Native GPU support. --- Project.toml | 4 +- src/Interpolations.jl | 16 ++++- src/b-splines/b-splines.jl | 7 +- src/b-splines/indexing.jl | 8 +-- src/extrapolation/extrapolation.jl | 4 +- src/gpu_support.jl | 72 ++++++++++++++++++++ src/gridded/indexing.jl | 6 +- src/lanczos/lanczos.jl | 5 +- src/scaling/scaling.jl | 4 +- test/gpu_support.jl | 103 +++++++++++++++++++++++++++++ test/runtests.jl | 6 ++ 11 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 src/gpu_support.jl create mode 100644 test/gpu_support.jl diff --git a/Project.toml b/Project.toml index 496f350b..73d45bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" version = "0.14.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AxisAlgorithms = "13072b0f-2c55-5437-9ae7-d433b7a33950" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -31,6 +32,7 @@ DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -38,4 +40,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Test", "Zygote", "ColorVectorSpace"] +test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Pkg", "Test", "Zygote", "ColorVectorSpace"] diff --git a/src/Interpolations.jl b/src/Interpolations.jl index cb095f0f..e2662d3d 100644 --- a/src/Interpolations.jl +++ b/src/Interpolations.jl @@ -41,7 +41,7 @@ using StaticArrays, WoodburyMatrices, Ratios, AxisAlgorithms, OffsetArrays using ChainRulesCore, Requires using Base: @propagate_inbounds, HasEltype, EltypeUnknown, HasLength, IsInfinite, - SizeUnknown + SizeUnknown, Indices import Base: convert, size, axes, promote_rule, ndims, eltype, checkbounds, axes1, iterate, length, IteratorEltype, IteratorSize, firstindex, getindex, LogicalIndex @@ -269,6 +269,15 @@ Base.:(/)(wi::WeightedArbIndex, x::Number) = WeightedArbIndex(wi.indexes, wi.wei ### Indexing with WeightedIndex +struct InterpGetindex{N,A<:AbstractArray{<:Any,N}} + coeffs::A + InterpGetindex(itp::AbstractInterpolation) = InterpGetindex(coefficients(itp)) + InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A) +end +Base.@propagate_inbounds Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} = + interp_getindex(A.coeffs, I, ntuple(d->0, Val(N))...) + +# TODO: should we drop the following injection? # We inject indexing with `WeightedIndex` at a non-exported point in the dispatch heirarchy. # This is to avoid ambiguities with methods that specialize on the array type rather than # the index type. 
@@ -277,7 +286,7 @@ if VERSION < v"1.6.0-DEV.104" @propagate_inbounds Base._getindex(::IndexLinear, A::AbstractVector, i::Int) = getindex(A, i) # ambiguity resolution end @inline function Base._getindex(::IndexStyle, A::AbstractArray{T,N}, I::Vararg{Union{Int,WeightedIndex},N}) where {T,N} - interp_getindex(A, I, ntuple(d->0, Val(N))...) + InterpGetindex(A)[I...] end # The non-generated version is currently disabled due to https://github.com/JuliaLang/julia/issues/29117 @@ -489,6 +498,9 @@ include("lanczos/lanczos_opencv.jl") include("iterate.jl") include("chainrules/chainrules.jl") include("hermite/cubic.jl") +if VERSION >= v"1.6" + include("gpu_support.jl") +end function __init__() @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" include("requires/unitful.jl") diff --git a/src/b-splines/b-splines.jl b/src/b-splines/b-splines.jl index bdb0476d..209e4fb3 100644 --- a/src/b-splines/b-splines.jl +++ b/src/b-splines/b-splines.jl @@ -66,7 +66,7 @@ However, for customized control you may also construct them with where `T` gets computed from the product of `TWeights` and `eltype(coefs)`. (This is equivalent to indicating that you'll be evaluating at locations `itp(x::TWeights, y::TWeights, ...)`.) """ -struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Tuple{Vararg{AbstractUnitRange,N}}} <: AbstractInterpolation{T,N,IT} +struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Indices{N}} <: AbstractInterpolation{T,N,IT} coefs::TCoefs parentaxes::Axs it::IT @@ -78,6 +78,9 @@ function Base.:(==)(o1::BSplineInterpolation, o2::BSplineInterpolation) o1.coefs == o2.coefs end +BSplineInterpolation{T,N}(A::AbstractArray, axs::Indices{N}, it::IT) where {T,N,IT} = + BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, axs, it) + function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,TWeights<:Real,IT<:DimSpec{BSpline}} # String interpolation causes allocation, noinline avoids that unless they get called @noinline err_concrete(IT) = error("The b-spline type must be a concrete type (was $IT)") @@ -98,7 +101,7 @@ function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT, else T = typeof(zero(TWeights) * first(A)) end - BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, fix_axis.(axs), it) + BSplineInterpolation{T,N}(A, fix_axis.(axs), it) end function BSplineInterpolation(A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,IT<:DimSpec{BSpline}} diff --git a/src/b-splines/indexing.jl b/src/b-splines/indexing.jl index 4ab7a381..26f16c95 100644 --- a/src/b-splines/indexing.jl +++ b/src/b-splines/indexing.jl @@ -5,7 +5,7 @@ itpinfo(itp) = (tcollect(itpflag, itp), axes(itp)) @inline function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - itp.coefs[wis...] + InterpGetindex(itp)[wis...] end @propagate_inbounds function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N} inds, trailing = split_trailing(itp, x) @@ -17,7 +17,7 @@ end @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) itps = tcollect(itpflag, itp) wis = dimension_wis(value_weights, itps, axes(itp), x) - coefs = coefficients(itp) + coefs = InterpGetindex(itp) ret = [coefs[i...] 
for i in Iterators.product(wis...)] reshape(ret, shape(wis...)) end @@ -25,7 +25,7 @@ end @propagate_inbounds function gradient(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x) wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x) - return SVector(_gradient(itp.coefs, wis...)) # work around #311 + return SVector(_gradient(InterpGetindex(itp), wis...)) # work around #311 end @inline _gradient(coefs, inds, moreinds...) = (coefs[inds...], _gradient(coefs, moreinds...)...) _gradient(coefs) = () @@ -37,7 +37,7 @@ end @propagate_inbounds function hessian(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x) wis = weightedindexes((value_weights, gradient_weights, hessian_weights), itpinfo(itp)..., x) - symmatrix(map(inds->itp.coefs[inds...], wis)) + symmatrix(map(inds->InterpGetindex(itp)[inds...], wis)) end @propagate_inbounds function hessian!(dest, itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} dest .= hessian(itp, x...) diff --git a/src/extrapolation/extrapolation.jl b/src/extrapolation/extrapolation.jl index 43b5edc3..1eb772a4 100644 --- a/src/extrapolation/extrapolation.jl +++ b/src/extrapolation/extrapolation.jl @@ -71,8 +71,8 @@ end eflag = tcollect(etpflag, etp) xs = inbounds_position(eflag, bounds(itp), x, etp, x) g = @inbounds gradient(itp, xs...) - skipni = t->skip_flagged_nointerp(itp, t) - SVector(extrapolate_gradient.(skipni(eflag), skipni(x), skipni(xs), Tuple(g))) + skipni = Base.Fix1(skip_flagged_nointerp, itp) + SVector(map(extrapolate_gradient, skipni(eflag), skipni(x), skipni(xs), Tuple(g))) end end diff --git a/src/gpu_support.jl b/src/gpu_support.jl new file mode 100644 index 00000000..0cbc185d --- /dev/null +++ b/src/gpu_support.jl @@ -0,0 +1,72 @@ +import Adapt: adapt_structure +using Adapt: adapt + +function adapt_structure(to, itp::BSplineInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + T′ = update_eltype(T, coefs′, itp.coefs) + BSplineInterpolation{T′,N}(coefs′, itp.parentaxes, itp.it) +end + +function update_eltype(T, coefs′, coefs) + ET = eltype(coefs′) + ET === eltype(coefs) && return T + WT = tweight(coefs′) + if isempty(coefs) + T′ = Base.promote_op(*, WT, ET) + else + T′ = Base.promote_op(*, WT, ET) + if !isconcretetype(T′) + T′ = typeof(zero(WT) * convert(ET, first(coefs))) + end + end + return T′ +end + +function adapt_structure(to, itp::LanczosInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + parentaxes′ = adapt(to, itp.parentaxes) + LanczosInterpolation{eltype(coefs′),N}(coefs′, parentaxes′, itp.it) +end + +function adapt_structure(to, itp::GriddedInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + knots′ = adapt(to, itp.knots) + T′ = update_eltype(T, coefs′, itp.coefs) + GriddedInterpolation{T′,N,typeof(coefs′),itptype(itp),typeof(knots′)}(knots′, coefs′, itp.it) +end + +function adapt_structure(to, itp::ScaledInterpolation{T,N,<:Any,IT,RT}) where {T,N,IT,RT<:NTuple{N,AbstractRange}} + ranges = itp.ranges + itp′ = adapt(to, itp.itp) + ScaledInterpolation{eltype(itp′),N,typeof(itp′),IT,RT}(itp′, ranges) +end + +function adapt_structure(to, itp::Extrapolation{T,N}) where {T,N} + et = itp.et + itp′ = adapt(to, itp.itp) + Extrapolation{eltype(itp′),N,typeof(itp′),itptype(itp),typeof(et)}(itp′, et) +end + +import Base.Broadcast: broadcasted, BroadcastStyle +using Base.Broadcast: broadcastable, 
combine_styles, AbstractArrayStyle +function broadcasted(itp::AbstractInterpolation, args...) + args′ = map(broadcastable, args) + # we overload BroadcastStyle here (try our best to do broadcast on GPU) + style = combine_styles(Ref(itp), args′...) + broadcasted(style, itp, args′...) +end + +root_storage_type(::Type{T}) where {T<:AbstractArray} = T +root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:LanczosInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:GriddedInterpolation} = root_storage_type(fieldtype(T, 2)) +root_storage_type(::Type{T}) where {T<:OffsetArray} = root_storage_type(fieldtype(T, 1)) + +BroadcastStyle(::Type{<:Ref{T}}) where {T<:AbstractInterpolation} = _to_scalar_style(BroadcastStyle(T)) +BroadcastStyle(::Type{T}) where {T<:AbstractInterpolation} = BroadcastStyle(root_storage_type(T)) + +_to_scalar_style(::S) where {S<:AbstractArrayStyle} = S(Val(0)) +_to_scalar_style(S::AbstractArrayStyle{Any}) = S +_to_scalar_style(S) = S diff --git a/src/gridded/indexing.jl b/src/gridded/indexing.jl index aa685120..43664b72 100644 --- a/src/gridded/indexing.jl +++ b/src/gridded/indexing.jl @@ -2,7 +2,7 @@ @inline function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - coefficients(itp)[wis...] + InterpGetindex(itp)[wis...] end @propagate_inbounds function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N} inds, trailing = split_trailing(itp, x) @@ -19,7 +19,7 @@ end @inline function gradient(itp::GriddedInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x) - SVector(map(inds->coefficients(itp)[inds...], wis)) + SVector(map(inds->InterpGetindex(itp)[inds...], wis)) end itpinfo(itp::GriddedInterpolation) = (tcollect(itpflag, itp), itp.knots) @@ -87,7 +87,7 @@ rescale_gridded(::typeof(hessian_weights), coefs, Δx) = coefs./Δx.^2 else wis = dimension_wis(value_weights, itps, itp.knots, x) end - coefs = coefficients(itp) + coefs = InterpGetindex(itp) ret = [coefs[i...] for i in Iterators.product(wis...)] reshape(ret, shape(wis...)) end diff --git a/src/lanczos/lanczos.jl b/src/lanczos/lanczos.jl index 57f6940a..62d0ed55 100644 --- a/src/lanczos/lanczos.jl +++ b/src/lanczos/lanczos.jl @@ -29,6 +29,9 @@ struct LanczosInterpolation{T,N,IT <: DimSpec{AbstractLanczos},A <: AbstractArra it::IT end +LanczosInterpolation{T,N}(coefs::AbstractArray{T,N}, parentaxes::NTuple{N,AbstractArray}, it::IT) where {T,N,IT} = + LanczosInterpolation{T,N,IT,typeof(coefs),typeof(parentaxes)}(coefs, parentaxes, it) + @generated degree(::Lanczos{N}) where {N} = :($N) getknots(itp::LanczosInterpolation) = axes(itp) @@ -48,7 +51,7 @@ end @inline function (itp::LanczosInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - itp.coefs[wis...] + InterpGetindex(itp)[wis...] 
end function weightedindex_parts(fs, it::AbstractLanczos, ax::AbstractUnitRange{<:Integer}, x) diff --git a/src/scaling/scaling.jl b/src/scaling/scaling.jl index d94ffafb..8d0cfb49 100644 --- a/src/scaling/scaling.jl +++ b/src/scaling/scaling.jl @@ -209,7 +209,7 @@ function Base.iterate(iter::ScaledIterator) ret === nothing && return nothing item, cistate = ret wis = getindex.(iter.wis, Tuple(item)) - ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis)) + ces = cache_evaluations(InterpGetindex(iter.sitp), indexes(wis[1]), weights(wis[1]), Base.tail(wis)) return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, firstindex(iter.breaks1), ces) end @@ -226,7 +226,7 @@ function Base.iterate(iter::ScaledIterator, state) end # Re-evaluate. We're being a bit lazy here: in some cases, some of the cached values could be reused wis = getindex.(iter.wis, Tuple(item)) - ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis)) + ces = cache_evaluations(InterpGetindex(iter.sitp), indexes(wis[1]), weights(wis[1]), Base.tail(wis)) return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, isnext1 ? state.ibreak+1 : firstindex(iter.breaks1), ces) end diff --git a/test/gpu_support.jl b/test/gpu_support.jl new file mode 100644 index 00000000..3acb691b --- /dev/null +++ b/test/gpu_support.jl @@ -0,0 +1,103 @@ +using JLArrays, Adapt +JLArrays.allowscalar(false) + +@testset "1d GPU Interpolation" begin + A_x = 1.:2.:40. + A = [log(x) for x in A_x] + itp = interpolate(A, BSpline(Cubic(Line(OnGrid())))) + jlitp = jl(itp) + idx = range(2, 19., 101) + jlidx = jl(collect(idx)) + @test itp.(idx) == collect(jlitp.(idx)) == collect(jlitp.(jlidx)) + @test gradient.(Ref(itp), idx) == + collect(gradient.(Ref(jlitp), idx)) == + collect(gradient.(Ref(jlitp), jlidx)) + + sitp = scale(itp, A_x) + jlsitp = jl(sitp) + idx = range(1., 39., 99) + jlidx = jl(collect(idx)) + @test sitp.(idx) == collect(jlsitp.(idx)) == collect(jlsitp.(jlidx)) + @test gradient.(Ref(sitp), idx) == + collect(gradient.(Ref(jlsitp), idx)) == + collect(gradient.(Ref(jlsitp), jlidx)) + + + esitp = extrapolate(sitp, Flat()) + jlesitp = jl(esitp) + idx = range(-1., 41., 51) + jlidx = jl(collect(idx)) + @test esitp.(idx) == collect(jlesitp.(idx)) == collect(jlesitp.(jlidx)) + @test gradient.(Ref(esitp), idx) == + collect(gradient.(Ref(jlesitp), idx)) == + collect(gradient.(Ref(jlesitp), jlidx)) +end + +@testset "2d GPU Interpolation" begin + A_x = 1.:2.:40. + A = [log(x + y) for x in A_x, y in 1.:2.:40.] 
+ itp = interpolate(A, (BSpline(Cubic(Line(OnGrid()))), BSpline(Linear()))) + jlitp = jl(itp) + idx = range(2, 19., 101) + jlidx = jl(collect(idx)) + @test itp.(idx, idx') == collect(jlitp.(idx, idx')) == collect(jlitp.(jlidx, jlidx')) + @test gradient.(Ref(itp), idx, idx') == + collect(gradient.(Ref(jlitp), idx, idx')) == + collect(gradient.(Ref(jlitp), jlidx, jlidx')) + @test hessian.(Ref(itp), idx, idx') == + collect(hessian.(Ref(jlitp), idx, idx')) == + collect(hessian.(Ref(jlitp), jlidx, jlidx')) + + sitp = scale(itp, A_x, A_x) + jlsitp = jl(sitp) + idx = range(1., 39., 99) + jlidx = jl(collect(idx)) + @test sitp.(idx, idx') == collect(jlsitp.(idx, idx')) == collect(jlsitp.(jlidx, jlidx')) + @test gradient.(Ref(sitp), idx, idx') == + collect(gradient.(Ref(jlsitp), idx, idx')) == + collect(gradient.(Ref(jlsitp), jlidx, jlidx')) + @test hessian.(Ref(sitp), idx, idx') == + collect(hessian.(Ref(jlsitp), idx, idx')) == + collect(hessian.(Ref(jlsitp), jlidx, jlidx')) + + esitp = extrapolate(sitp, Flat()) + jlesitp = jl(esitp) + idx = range(-1., 41., 51) + jlidx = jl(collect(idx)) + @test esitp.(idx, idx') == collect(jlesitp.(idx, idx')) == collect(jlesitp.(jlidx, jlidx')) + # gradient for `extrapolation` is currently broken under CUDA + @test gradient.(Ref(esitp), idx, idx') == + collect(gradient.(Ref(jlesitp), idx, idx')) == + collect(gradient.(Ref(jlesitp), jlidx, jlidx')) +end + +@testset "Lanczos on gpu" begin + X = 1:100 + X = [X; reverse(X)[2:end]] + for N = 2:4 + itp = interpolate(X, Lanczos(N)) + @test itp.(X) == collect(jl(itp).(jl(X))) + end + itp = interpolate(X, Lanczos4OpenCV()) + @test itp.(X) == collect(jl(itp).(jl(X))) +end + +@testset "Gridded on gpu" begin + itp1 = interpolate(Vector.((range(-1,1,101), range(-1,1,101))), randn(101,101), Gridded(Linear())) + itp2 = interpolate((range(-1,1,101), range(-1,1,101)), randn(101,101), Gridded(Linear())) + idx = range(-1, 1, 301) + jlidx = jl(collect(idx)) + @test itp1.(idx, idx') == collect(jl(itp1).(idx, idx')) == collect(jl(itp1).(jlidx, jlidx')) + @test itp2.(idx, idx') == collect(jl(itp2).(idx, idx')) == collect(jl(itp2).(jlidx, jlidx')) +end + +@testset "eltype after adaption" begin + A_x = 1.:2.:40. 
+ A = [log(x) for x in A_x] + itp = interpolate(A, BSpline(Cubic(Line(OnGrid())))) + @test eltype(adapt(Array{Float32}, itp)) === Float32 + @test eltype(adapt(Array{Float32}, scale(itp, A_x))) === Float32 + @test eltype(adapt(Array{Float32}, extrapolate(scale(itp, A_x), Flat()))) === Float32 + itp = interpolate((range(-1,1,10), range(-1,1,10)), randn(10,10), Gridded(Linear())) + @test eltype(adapt(Array{Float32}, itp)) === Float32 +end diff --git a/test/runtests.jl b/test/runtests.jl index 718d0c66..bd312092 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,4 +53,10 @@ const isci = get(ENV, "CI", "") in ("true", "True") # Chain rules interaction include("chainrules.jl") + + if VERSION >= v"1.6" + import Pkg + Pkg.add("JLArrays") + include("gpu_support.jl") + end end From c8cebe254aab1d9d05cb15b4bbbf75b44ea823c4 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 29 Jul 2022 19:11:02 +0800 Subject: [PATCH 2/5] Add some devdoc Co-Authored-By: Mark Kittisopikul --- docs/src/devdocs.md | 38 +++++++++++++++++++++++++++++++++----- src/gpu_support.jl | 8 +++++++- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/docs/src/devdocs.md b/docs/src/devdocs.md index 4d312d96..0f0aab1b 100644 --- a/docs/src/devdocs.md +++ b/docs/src/devdocs.md @@ -55,6 +55,8 @@ concrete subtypes: General `AbstractArray`s may be indexed with `WeightedIndex` indices, and the result produces the interpolated value. In other words, the end result is just `itp.coefs[wis...]`, where `wis` is a tuple of `WeightedIndex` indices. +To make sure this overloading is effective, we wrap the `coefs` with `InterpGetindex`, +i.e. `InterpGetindex(itp.coefs)[wis...]` Derivatives along a particular axis can be computed just by substituting a component of `wis` for one that has been designed to compute derivatives rather than values. @@ -107,7 +109,7 @@ Interpolations.WeightedAdjIndex{2, Float64}(1, (0.6000000000000001, 0.3999999999 julia> wis[3] Interpolations.WeightedAdjIndex{2, Float64}(1, (0.30000000000000004, 0.7)) -julia> A[wis...] +julia> Interpolations.InterpGetindex(A)[wis...] 8.7 ``` @@ -125,7 +127,7 @@ This computed the value of `itp` at `x...` because we called `weightedindexes` w !!! note Remember that prefiltering is not used for `Linear` interpolation. - In a case where prefiltering is used, we would substitute `itp.coefs[wis...]` for `A[wis...]` above. + In a case where prefiltering is used, we would substitute `InterpGetindex(itp.coefs)[wis...]` for `InterpGetindex(A)[wis...]` above. To compute derivatives, we *also* pass additional functions like [`Interpolations.gradient_weights`](@ref): @@ -142,13 +144,13 @@ julia> wis[2] julia> wis[3] (Interpolations.WeightedAdjIndex{2,Float64}(1, (0.8, 0.19999999999999996)), Interpolations.WeightedAdjIndex{2,Float64}(1, (0.6000000000000001, 0.3999999999999999)), Interpolations.WeightedAdjIndex{2,Float64}(1, (-1.0, 1.0))) -julia> A[wis[1]...] +julia> Interpolations.InterpGetindex(A)[wis[1]...] 1.0 -julia> A[wis[2]...] +julia> Interpolations.InterpGetindex(A)[wis[2]...] 3.000000000000001 -julia> A[wis[3]...] +julia> Interpolations.InterpGetindex(A)[wis[3]...] 9.0 ``` In this case you can see that `wis` is a 3-tuple-of-3-tuples. `A[wis[i]...]` can be used to compute the `i`th component of the gradient. @@ -168,3 +170,29 @@ The code to do this replacement is a bit complicated due to the need to support It makes good use of *tuple* manipulations, sometimes called "lispy tuple programming." 
You can search Julia's discourse forum for more tips about how to program this way.
It could alternatively be done using generated functions, but this would increase compile time
considerably and can lead to world-age problems.
+
+### GPU Support
+At present, `Interpolations.jl` supports using interpolant objects on the GPU via broadcasting.
+
+A basic workflow looks like:
+```julia
+julia> using Interpolations, Adapt, CUDA # or any other GPU package
+
+julia> itp = Interpolations.interpolate([1, 2, 3], BSpline(Linear())); # construct the interpolant object on the CPU
+
+julia> cuitp = adapt(CuArray{Float32}, itp); # adapt it to GPU memory
+
+julia> cuitp.(1:0.5:2) # call the interpolant object via broadcast
+
+julia> gradient.(Ref(cuitp), 1:0.5:2)
+```
+
+To achieve this, an `ITP <: AbstractInterpolation` should define its own `Adapt.adapt_structure(to, itp::ITP)`,
+which constructs a new `ITP` from the adapted fields (`adapt(to, itp.fieldname)`) of `itp`. Adapting a field
+can be skipped if it is already known to be GPU-compatible, e.g. an `isbits` range.
+
+!!! note
+    Some adaptors may change the storage type. Please ensure that the adapted `itp`
+    has the correct element type via the method `eltype`.
+
+Also, all GPU-compatible `AbstractInterpolation`s should define their own `Interpolations.root_storage_type`. This function allows us to modify the broadcast mechanism by overloading the default `BroadcastStyle`. See [Customizing broadcasting](https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting) for more details.
diff --git a/src/gpu_support.jl b/src/gpu_support.jl
index 0cbc185d..3d9760e3 100644
--- a/src/gpu_support.jl
+++ b/src/gpu_support.jl
@@ -56,13 +56,19 @@ function broadcasted(itp::AbstractInterpolation, args...)
     broadcasted(style, itp, args′...)
 end
 
-root_storage_type(::Type{T}) where {T<:AbstractArray} = T
+"""
+    Interpolations.root_storage_type(::Type{<:AbstractInterpolation}) -> Type{<:AbstractArray}
+
+Return the type of the root coefficient array of an `AbstractInterpolation`.
+Array wrappers, such as `OffsetArray`, are skipped.
+"""
 root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1))
 root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1))
 root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1))
 root_storage_type(::Type{T}) where {T<:LanczosInterpolation} = root_storage_type(fieldtype(T, 1))
 root_storage_type(::Type{T}) where {T<:GriddedInterpolation} = root_storage_type(fieldtype(T, 2))
 root_storage_type(::Type{T}) where {T<:OffsetArray} = root_storage_type(fieldtype(T, 1))
+root_storage_type(::Type{T}) where {T<:AbstractArray} = T
 
 BroadcastStyle(::Type{<:Ref{T}}) where {T<:AbstractInterpolation} = _to_scalar_style(BroadcastStyle(T))
 BroadcastStyle(::Type{T}) where {T<:AbstractInterpolation} = BroadcastStyle(root_storage_type(T))

From 494654743f6bc0c8cd035bfb37f007ca8c070720 Mon Sep 17 00:00:00 2001
From: N5N3 <2642243996@qq.com>
Date: Sat, 30 Jul 2022 05:48:03 +0800
Subject: [PATCH 3/5] Simplify `update_eltype` and add more tests

This should omit the `isempty` check if inference gives a good result.
fix test on 1.6, `range(2, 19.0, 101)` was introduced in 1.7 --- src/gpu_support.jl | 12 +++------ test/gpu_support.jl | 63 ++++++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/src/gpu_support.jl b/src/gpu_support.jl index 3d9760e3..4990f5bd 100644 --- a/src/gpu_support.jl +++ b/src/gpu_support.jl @@ -11,15 +11,9 @@ function update_eltype(T, coefs′, coefs) ET = eltype(coefs′) ET === eltype(coefs) && return T WT = tweight(coefs′) - if isempty(coefs) - T′ = Base.promote_op(*, WT, ET) - else - T′ = Base.promote_op(*, WT, ET) - if !isconcretetype(T′) - T′ = typeof(zero(WT) * convert(ET, first(coefs))) - end - end - return T′ + T′ = Base.promote_op(*, WT, ET) + (isconcretetype(T′) || isempty(coefs)) && return T′ + return typeof(zero(WT) * convert(ET, first(coefs))) end function adapt_structure(to, itp::LanczosInterpolation{T,N}) where {T,N} diff --git a/test/gpu_support.jl b/test/gpu_support.jl index 3acb691b..c252a811 100644 --- a/test/gpu_support.jl +++ b/test/gpu_support.jl @@ -2,73 +2,73 @@ using JLArrays, Adapt JLArrays.allowscalar(false) @testset "1d GPU Interpolation" begin - A_x = 1.:2.:40. + A_x = 1.0:2.0:40.0 A = [log(x) for x in A_x] itp = interpolate(A, BSpline(Cubic(Line(OnGrid())))) jlitp = jl(itp) - idx = range(2, 19., 101) + idx = 2.0:0.17:19.0 jlidx = jl(collect(idx)) @test itp.(idx) == collect(jlitp.(idx)) == collect(jlitp.(jlidx)) @test gradient.(Ref(itp), idx) == - collect(gradient.(Ref(jlitp), idx)) == - collect(gradient.(Ref(jlitp), jlidx)) + collect(gradient.(Ref(jlitp), idx)) == + collect(gradient.(Ref(jlitp), jlidx)) sitp = scale(itp, A_x) jlsitp = jl(sitp) - idx = range(1., 39., 99) + idx = 1.0:0.4:39.0 jlidx = jl(collect(idx)) @test sitp.(idx) == collect(jlsitp.(idx)) == collect(jlsitp.(jlidx)) @test gradient.(Ref(sitp), idx) == - collect(gradient.(Ref(jlsitp), idx)) == - collect(gradient.(Ref(jlsitp), jlidx)) + collect(gradient.(Ref(jlsitp), idx)) == + collect(gradient.(Ref(jlsitp), jlidx)) esitp = extrapolate(sitp, Flat()) jlesitp = jl(esitp) - idx = range(-1., 41., 51) + idx = -1.0:0.84:41.0 jlidx = jl(collect(idx)) @test esitp.(idx) == collect(jlesitp.(idx)) == collect(jlesitp.(jlidx)) @test gradient.(Ref(esitp), idx) == - collect(gradient.(Ref(jlesitp), idx)) == - collect(gradient.(Ref(jlesitp), jlidx)) + collect(gradient.(Ref(jlesitp), idx)) == + collect(gradient.(Ref(jlesitp), jlidx)) end @testset "2d GPU Interpolation" begin - A_x = 1.:2.:40. - A = [log(x + y) for x in A_x, y in 1.:2.:40.] 
+ A_x = 1.0:2.0:40.0 + A = [log(x + y) for x in A_x, y in 1.0:2.0:40.0] itp = interpolate(A, (BSpline(Cubic(Line(OnGrid()))), BSpline(Linear()))) jlitp = jl(itp) - idx = range(2, 19., 101) + idx = 2.0:0.17:19.0 jlidx = jl(collect(idx)) @test itp.(idx, idx') == collect(jlitp.(idx, idx')) == collect(jlitp.(jlidx, jlidx')) @test gradient.(Ref(itp), idx, idx') == - collect(gradient.(Ref(jlitp), idx, idx')) == - collect(gradient.(Ref(jlitp), jlidx, jlidx')) + collect(gradient.(Ref(jlitp), idx, idx')) == + collect(gradient.(Ref(jlitp), jlidx, jlidx')) @test hessian.(Ref(itp), idx, idx') == - collect(hessian.(Ref(jlitp), idx, idx')) == - collect(hessian.(Ref(jlitp), jlidx, jlidx')) + collect(hessian.(Ref(jlitp), idx, idx')) == + collect(hessian.(Ref(jlitp), jlidx, jlidx')) sitp = scale(itp, A_x, A_x) jlsitp = jl(sitp) - idx = range(1., 39., 99) + idx = 1.0:0.4:39.0 jlidx = jl(collect(idx)) @test sitp.(idx, idx') == collect(jlsitp.(idx, idx')) == collect(jlsitp.(jlidx, jlidx')) @test gradient.(Ref(sitp), idx, idx') == - collect(gradient.(Ref(jlsitp), idx, idx')) == - collect(gradient.(Ref(jlsitp), jlidx, jlidx')) + collect(gradient.(Ref(jlsitp), idx, idx')) == + collect(gradient.(Ref(jlsitp), jlidx, jlidx')) @test hessian.(Ref(sitp), idx, idx') == - collect(hessian.(Ref(jlsitp), idx, idx')) == - collect(hessian.(Ref(jlsitp), jlidx, jlidx')) + collect(hessian.(Ref(jlsitp), idx, idx')) == + collect(hessian.(Ref(jlsitp), jlidx, jlidx')) esitp = extrapolate(sitp, Flat()) jlesitp = jl(esitp) - idx = range(-1., 41., 51) + idx = -1.0:0.84:41.0 jlidx = jl(collect(idx)) @test esitp.(idx, idx') == collect(jlesitp.(idx, idx')) == collect(jlesitp.(jlidx, jlidx')) # gradient for `extrapolation` is currently broken under CUDA @test gradient.(Ref(esitp), idx, idx') == - collect(gradient.(Ref(jlesitp), idx, idx')) == - collect(gradient.(Ref(jlesitp), jlidx, jlidx')) + collect(gradient.(Ref(jlesitp), idx, idx')) == + collect(gradient.(Ref(jlesitp), jlidx, jlidx')) end @testset "Lanczos on gpu" begin @@ -83,21 +83,24 @@ end end @testset "Gridded on gpu" begin - itp1 = interpolate(Vector.((range(-1,1,101), range(-1,1,101))), randn(101,101), Gridded(Linear())) - itp2 = interpolate((range(-1,1,101), range(-1,1,101)), randn(101,101), Gridded(Linear())) - idx = range(-1, 1, 301) + itp1 = interpolate(Vector.((-1.0:0.02:1.0, -1.0:0.02:1.0)), randn(101, 101), Gridded(Linear())) + itp2 = interpolate((-1.0:0.02:1.0, -1.0:0.02:1.0), randn(101, 101), Gridded(Linear())) + idx = -1.0:0.01:1.0 jlidx = jl(collect(idx)) @test itp1.(idx, idx') == collect(jl(itp1).(idx, idx')) == collect(jl(itp1).(jlidx, jlidx')) @test itp2.(idx, idx') == collect(jl(itp2).(idx, idx')) == collect(jl(itp2).(jlidx, jlidx')) end @testset "eltype after adaption" begin - A_x = 1.:2.:40. 
+    A_x = 1.0:2.0:40.0
     A = [log(x) for x in A_x]
     itp = interpolate(A, BSpline(Cubic(Line(OnGrid()))))
     @test eltype(adapt(Array{Float32}, itp)) === Float32
+    @test eltype(adapt(Array{Real}, itp)) === Float64
     @test eltype(adapt(Array{Float32}, scale(itp, A_x))) === Float32
     @test eltype(adapt(Array{Float32}, extrapolate(scale(itp, A_x), Flat()))) === Float32
-    itp = interpolate((range(-1,1,10), range(-1,1,10)), randn(10,10), Gridded(Linear()))
+    itp = interpolate((-1:0.2:1, -1:0.2:1), randn(11, 11), Gridded(Linear()))
     @test eltype(adapt(Array{Float32}, itp)) === Float32
+    itp = interpolate((1.0:0.0, 1.:0.), randn(0, 0), Gridded(Linear()))
+    @test eltype(adapt(Array{Real}, itp)) isa DataType # we don't care about the result
 end

From 4ebd9754446ad39c8ea53df5a47276d87fb73fa6 Mon Sep 17 00:00:00 2001
From: N5N3 <2642243996@qq.com>
Date: Sat, 30 Jul 2022 01:33:16 +0800
Subject: [PATCH 4/5] Fully drop the old indexing overload for `WeightedIndex` and remove unneeded `@propagate_inbounds`.

---
 src/Interpolations.jl | 16 +++-------------
 test/core.jl          | 13 +++++++------
 2 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/src/Interpolations.jl b/src/Interpolations.jl
index e2662d3d..e176ca8d 100644
--- a/src/Interpolations.jl
+++ b/src/Interpolations.jl
@@ -269,26 +269,16 @@ Base.:(/)(wi::WeightedArbIndex, x::Number) = WeightedArbIndex(wi.indexes, wi.wei
 
 ### Indexing with WeightedIndex
 
+# We inject `WeightedIndex` as a non-exported indexing point with an `InterpGetindex` wrapper.
+# `InterpGetindex` is not a subtype of `AbstractArray`. This ensures that the overload applies to all array types.
 struct InterpGetindex{N,A<:AbstractArray{<:Any,N}}
     coeffs::A
     InterpGetindex(itp::AbstractInterpolation) = InterpGetindex(coefficients(itp))
     InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A)
 end
-Base.@propagate_inbounds Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} =
+@inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} =
     interp_getindex(A.coeffs, I, ntuple(d->0, Val(N))...)
 
-# TODO: should we drop the following injection?
-# We inject indexing with `WeightedIndex` at a non-exported point in the dispatch heirarchy.
-# This is to avoid ambiguities with methods that specialize on the array type rather than
-# the index type.
-Base.to_indices(A, I::Tuple{Vararg{Union{Int,WeightedIndex}}}) = I
-if VERSION < v"1.6.0-DEV.104"
-    @propagate_inbounds Base._getindex(::IndexLinear, A::AbstractVector, i::Int) = getindex(A, i) # ambiguity resolution
-end
-@inline function Base._getindex(::IndexStyle, A::AbstractArray{T,N}, I::Vararg{Union{Int,WeightedIndex},N}) where {T,N}
-    InterpGetindex(A)[I...]
-end
-
 # The non-generated version is currently disabled due to https://github.com/JuliaLang/julia/issues/29117
 # # This follows a "move processed indexes to the back" strategy, so J contains the yet-to-be-processed
 # # indexes and I all the processed indexes.
diff --git a/test/core.jl b/test/core.jl index 48e598b3..b126c7a5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -29,18 +29,19 @@ end @testset "Core" begin A = reshape([0], 1, 1, 1, 1, 1) + IA = Interpolations.InterpGetindex(A) wis = ntuple(d->Interpolations.WeightedAdjIndex(1, (1,)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedAdjIndex(1, (1.0,)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,), (1,)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,), (1.0,)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,1), (1,0)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,1), (1.0,0.0)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wi = Interpolations.WeightedAdjIndex(2, (0.2, 0.8)) @test wi*2 === Interpolations.WeightedAdjIndex(2, (0.4, 1.6)) From 074110ea365fabf73f842507776ab08feb674e93 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sat, 30 Jul 2022 04:32:47 +0800 Subject: [PATCH 5/5] Reenable non-generated `interp_getindex`. And fix the inference problem. --- src/Interpolations.jl | 79 ++++++++++--------------------------------- 1 file changed, 18 insertions(+), 61 deletions(-) diff --git a/src/Interpolations.jl b/src/Interpolations.jl index e176ca8d..c2ff615c 100644 --- a/src/Interpolations.jl +++ b/src/Interpolations.jl @@ -277,67 +277,24 @@ struct InterpGetindex{N,A<:AbstractArray{<:Any,N}} InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A) end @inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} = - interp_getindex(A.coeffs, I, ntuple(d->0, Val(N))...) - -# The non-generated version is currently disabled due to https://github.com/JuliaLang/julia/issues/29117 -# # This follows a "move processed indexes to the back" strategy, so J contains the yet-to-be-processed -# # indexes and I all the processed indexes. -# interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M} = -# interp_getindex(A, Base.tail(J), I..., J[1]) -# function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedIndex,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M} -# wi = J[1] -# interp_getindex1(A, indexes(wi), weights(wi), Base.tail(J), I...) -# end -# interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination -# @inbounds A[I...] # all bounds-checks have already happened -# -# ## Handle expansion of a single dimension -# # version for WeightedAdjIndex -# @inline interp_getindex1(A, i::Int, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} = -# weights[1] * interp_getindex(A, rest, I..., i) + interp_getindex1(A, i+1, Base.tail(weights), rest, I...) 
-# @inline interp_getindex1(A, i::Int, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M =
-#     weights[1] * interp_getindex(A, rest, I..., i)
-# interp_getindex1(A, i::Int, weights::Tuple{}, rest, I::Vararg{Int,M}) where M =
-#     error("exhausted the weights, this should never happen")
-#
-# # version for WeightedArbIndex
-# @inline interp_getindex1(A, indexes::NTuple{K,Int}, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} =
-#     weights[1] * interp_getindex(A, rest, I..., indexes[1]) + interp_getindex1(A, Base.tail(indexes), Base.tail(weights), rest, I...)
-# @inline interp_getindex1(A, indexes::Tuple{Int}, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M =
-#     weights[1] * interp_getindex(A, rest, I..., indexes[1])
-# interp_getindex1(A, indexes::Tuple{}, weights::Tuple{}, rest, I::Vararg{Int,M}) where M =
-#     error("exhausted the weights and indexes, this should never happen")
-
-@inline interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K} =
-    interp_getindex(A, Base.tail(J), Base.tail(I)..., J[1])
-@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedAdjIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W}
-    ex = :(w[1]*interp_getindex(A, Jtail, Itail..., j))
-    for l = 2:L
-        ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., j+$(l-1)) + $ex)
-    end
-    quote
-        $(Expr(:meta, :inline))
-        Jtail = Base.tail(J)
-        Itail = Base.tail(I)
-        j, w = J[1].istart, J[1].weights
-        $ex
-    end
-end
-@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedArbIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W}
-    ex = :(w[1]*interp_getindex(A, Jtail, Itail..., ij[1]))
-    for l = 2:L
-        ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., ij[$l]) + $ex)
-    end
-    quote
-        $(Expr(:meta, :inline))
-        Jtail = Base.tail(J)
-        Itail = Base.tail(I)
-        ij, w = J[1].indexes, J[1].weights
-        $ex
-    end
-end
-@inline interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination
-    @inbounds A[I...] # all bounds-checks have already happened
+    interp_getindex(A.coeffs, ntuple(_ -> 0, Val(N)), map(indexflag, I)...)
+indexflag(I::Int) = I
+@inline indexflag(I::WeightedIndex) = indextuple(I), weights(I)
+
+# A recursion-based `interp_getindex`, which follows a "move processed indexes to the back" strategy:
+# `I` holds the already-processed indexes, while `(wi1, wis...)` holds the yet-to-be-processed ones.
+# When we meet a no-interp dimension, we just append the index to the end of `I`.
+@inline interp_getindex(A, I, wi1::Int, wis...) =
+    interp_getindex(A, (Base.tail(I)..., wi1), wis...)
+# Here we handle the expansion of a single dimension.
+@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any,Vararg{Any,N}}}, wis...) where {N} =
+    wi1[2][end] * interp_getindex(A, (Base.tail(I)..., wi1[1][end]), wis...) +
+    interp_getindex(A, I, map(Base.front, wi1), wis...)
+@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any}}, wis...) =
+    wi1[2][1] * interp_getindex(A, (Base.tail(I)..., wi1[1][1]), wis...)
+# Termination
+@inline interp_getindex(A::AbstractArray{T,N}, I::Dims{N}) where {T,N} =
+    @inbounds A[I...] # all bounds-checks have already happened
 
 """
     w = value_weights(degree, δx)