diff --git a/Project.toml b/Project.toml index 496f350b..73d45bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" version = "0.14.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AxisAlgorithms = "13072b0f-2c55-5437-9ae7-d433b7a33950" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -31,6 +32,7 @@ DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -38,4 +40,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Test", "Zygote", "ColorVectorSpace"] +test = ["OffsetArrays", "Unitful", "SharedArrays", "ForwardDiff", "LinearAlgebra", "DualNumbers", "Random", "Pkg", "Test", "Zygote", "ColorVectorSpace"] diff --git a/docs/src/devdocs.md b/docs/src/devdocs.md index 4d312d96..0f0aab1b 100644 --- a/docs/src/devdocs.md +++ b/docs/src/devdocs.md @@ -55,6 +55,8 @@ concrete subtypes: General `AbstractArray`s may be indexed with `WeightedIndex` indices, and the result produces the interpolated value. In other words, the end result is just `itp.coefs[wis...]`, where `wis` is a tuple of `WeightedIndex` indices. +To make sure this overloading is effective, we wrap the `coefs` with `InterpGetindex`, +i.e. `InterpGetindex(itp.coefs)[wis...]` Derivatives along a particular axis can be computed just by substituting a component of `wis` for one that has been designed to compute derivatives rather than values. @@ -107,7 +109,7 @@ Interpolations.WeightedAdjIndex{2, Float64}(1, (0.6000000000000001, 0.3999999999 julia> wis[3] Interpolations.WeightedAdjIndex{2, Float64}(1, (0.30000000000000004, 0.7)) -julia> A[wis...] +julia> Interpolations.InterpGetindex(A)[wis...] 8.7 ``` @@ -125,7 +127,7 @@ This computed the value of `itp` at `x...` because we called `weightedindexes` w !!! note Remember that prefiltering is not used for `Linear` interpolation. - In a case where prefiltering is used, we would substitute `itp.coefs[wis...]` for `A[wis...]` above. + In a case where prefiltering is used, we would substitute `InterpGetindex(itp.coefs)[wis...]` for `InterpGetindex(A)[wis...]` above. To compute derivatives, we *also* pass additional functions like [`Interpolations.gradient_weights`](@ref): @@ -142,13 +144,13 @@ julia> wis[2] julia> wis[3] (Interpolations.WeightedAdjIndex{2,Float64}(1, (0.8, 0.19999999999999996)), Interpolations.WeightedAdjIndex{2,Float64}(1, (0.6000000000000001, 0.3999999999999999)), Interpolations.WeightedAdjIndex{2,Float64}(1, (-1.0, 1.0))) -julia> A[wis[1]...] +julia> Interpolations.InterpGetindex(A)[wis[1]...] 1.0 -julia> A[wis[2]...] +julia> Interpolations.InterpGetindex(A)[wis[2]...] 3.000000000000001 -julia> A[wis[3]...] +julia> Interpolations.InterpGetindex(A)[wis[3]...] 9.0 ``` In this case you can see that `wis` is a 3-tuple-of-3-tuples. `A[wis[i]...]` can be used to compute the `i`th component of the gradient. @@ -168,3 +170,29 @@ The code to do this replacement is a bit complicated due to the need to support It makes good use of *tuple* manipulations, sometimes called "lispy tuple programming." You can search Julia's discourse forum for more tips about how to program this way. It could alternatively be done using generated functions, but this would increase compile time considerably and can lead to world-age problems. + +### GPU Support +At present, `Interpolations.jl` supports interpolant usage on GPU via broadcasting. + +A basic work flow looks like: +```julia +julia> using Interpolations, Adapt, CUDA # Or any other GPU package + +julia > itp = Interpolations.interpolate([1, 2, 3], (BSpline(Linear()))); # construct the interpolant object on CPU + +julia> cuitp = adapt(CuArray{Float32}, itp); # adapt it to GPU memory + +julia > cuitp.(1:0.5:2) # call interpolant object via broadcast + +julia> gradient.(Ref(cuitp), 1:0.5:2) +``` + +To achieve this, an `ITP <: AbstractInterpolation` should define it's own `Adapt.adapt_structure(to, itp::ITP)`, which constructs a new `ITP` with the adapted +fields (`adapt(to, itp.fieldname)`) of `itp`. The field adaption could be skipped +if we know that it has been GPU-compatable, e.g. a `isbit` range. + +!!! note + Some adaptors may change the storage type. Please ensure that the adapted `itp` + has the correct element type via the method `eltype`. + +Also, all GPU-compatable `AbstractInterpolation`s should define their own `Interpolations.root_storage_type`. This function allows us to modify the broadcast mechanism by overloading the default `BroadcastStyle`. See [Customizing broadcasting](https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting) for more details. diff --git a/src/Interpolations.jl b/src/Interpolations.jl index cb095f0f..c2ff615c 100644 --- a/src/Interpolations.jl +++ b/src/Interpolations.jl @@ -41,7 +41,7 @@ using StaticArrays, WoodburyMatrices, Ratios, AxisAlgorithms, OffsetArrays using ChainRulesCore, Requires using Base: @propagate_inbounds, HasEltype, EltypeUnknown, HasLength, IsInfinite, - SizeUnknown + SizeUnknown, Indices import Base: convert, size, axes, promote_rule, ndims, eltype, checkbounds, axes1, iterate, length, IteratorEltype, IteratorSize, firstindex, getindex, LogicalIndex @@ -269,76 +269,32 @@ Base.:(/)(wi::WeightedArbIndex, x::Number) = WeightedArbIndex(wi.indexes, wi.wei ### Indexing with WeightedIndex -# We inject indexing with `WeightedIndex` at a non-exported point in the dispatch heirarchy. -# This is to avoid ambiguities with methods that specialize on the array type rather than -# the index type. -Base.to_indices(A, I::Tuple{Vararg{Union{Int,WeightedIndex}}}) = I -if VERSION < v"1.6.0-DEV.104" - @propagate_inbounds Base._getindex(::IndexLinear, A::AbstractVector, i::Int) = getindex(A, i) # ambiguity resolution +# We inject `WeightedIndex` as a non-exported indexing point with a `InterpGetindex` wrapper. +# `InterpGetindex` is not a subtype of `AbstractArray`. This ensures that the overload applies to all array types. +struct InterpGetindex{N,A<:AbstractArray{<:Any,N}} + coeffs::A + InterpGetindex(itp::AbstractInterpolation) = InterpGetindex(coefficients(itp)) + InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A) end -@inline function Base._getindex(::IndexStyle, A::AbstractArray{T,N}, I::Vararg{Union{Int,WeightedIndex},N}) where {T,N} - interp_getindex(A, I, ntuple(d->0, Val(N))...) -end - -# The non-generated version is currently disabled due to https://github.com/JuliaLang/julia/issues/29117 -# # This follows a "move processed indexes to the back" strategy, so J contains the yet-to-be-processed -# # indexes and I all the processed indexes. -# interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M} = -# interp_getindex(A, Base.tail(J), I..., J[1]) -# function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedIndex,Vararg{Any,L}}, I::Vararg{Int,M}) where {T,N,L,M} -# wi = J[1] -# interp_getindex1(A, indexes(wi), weights(wi), Base.tail(J), I...) -# end -# interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination -# @inbounds A[I...] # all bounds-checks have already happened -# -# ## Handle expansion of a single dimension -# # version for WeightedAdjIndex -# @inline interp_getindex1(A, i::Int, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} = -# weights[1] * interp_getindex(A, rest, I..., i) + interp_getindex1(A, i+1, Base.tail(weights), rest, I...) -# @inline interp_getindex1(A, i::Int, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M = -# weights[1] * interp_getindex(A, rest, I..., i) -# interp_getindex1(A, i::Int, weights::Tuple{}, rest, I::Vararg{Int,M}) where M = -# error("exhausted the weights, this should never happen") -# -# # version for WeightedArbIndex -# @inline interp_getindex1(A, indexes::NTuple{K,Int}, weights::NTuple{K,Any}, rest, I::Vararg{Int,M}) where {M,K} = -# weights[1] * interp_getindex(A, rest, I..., indexes[1]) + interp_getindex1(A, Base.tail(indexes), Base.tail(weights), rest, I...) -# @inline interp_getindex1(A, indexes::Tuple{Int}, weights::Tuple{Any}, rest, I::Vararg{Int,M}) where M = -# weights[1] * interp_getindex(A, rest, I..., indexes[1]) -# interp_getindex1(A, indexes::Tuple{}, weights::Tuple{}, rest, I::Vararg{Int,M}) where M = -# error("exhausted the weights and indexes, this should never happen") - -@inline interp_getindex(A::AbstractArray{T,N}, J::Tuple{Int,Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K} = - interp_getindex(A, Base.tail(J), Base.tail(I)..., J[1]) -@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedAdjIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W} - ex = :(w[1]*interp_getindex(A, Jtail, Itail..., j)) - for l = 2:L - ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., j+$(l-1)) + $ex) - end - quote - $(Expr(:meta, :inline)) - Jtail = Base.tail(J) - Itail = Base.tail(I) - j, w = J[1].istart, J[1].weights - $ex - end -end -@generated function interp_getindex(A::AbstractArray{T,N}, J::Tuple{WeightedArbIndex{L,W},Vararg{Any,K}}, I::Vararg{Int,N}) where {T,N,K,L,W} - ex = :(w[1]*interp_getindex(A, Jtail, Itail..., ij[1])) - for l = 2:L - ex = :(w[$l]*interp_getindex(A, Jtail, Itail..., ij[$l]) + $ex) - end - quote - $(Expr(:meta, :inline)) - Jtail = Base.tail(J) - Itail = Base.tail(I) - ij, w = J[1].indexes, J[1].weights - $ex - end -end -@inline interp_getindex(A::AbstractArray{T,N}, ::Tuple{}, I::Vararg{Int,N}) where {T,N} = # termination - @inbounds A[I...] # all bounds-checks have already happened +@inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} = + interp_getindex(A.coeffs, ntuple(_ -> 0, Val(N)), map(indexflag, I)...) +indexflag(I::Int) = I +@inline indexflag(I::WeightedIndex) = indextuple(I), weights(I) + +# A recursion-based `interp_getindex`, which follows a "move processed indexes to the back" strategy +# `I` contains the processed index, and (wi1, wis...) contains the yet-to-be-processed indexes +# Here we meet a no-interp dim, just append the index to `I`'s end. +@inline interp_getindex(A, I, wi1::Int, wis...) = + interp_getindex(A, (Base.tail(I)..., wi1), wis...) +# Here we handle the expansion of a single dimension. +@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any,Vararg{Any,N}}}, wis...) where {N} = + wi1[2][end] * interp_getindex(A, (Base.tail(I)..., wi1[1][end]), wis...) + + interp_getindex(A, I, map(Base.front, wi1), wis...) +@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any}}, wis...) = + wi1[2][1] * interp_getindex(A, (Base.tail(I)..., wi1[1][1]), wis...) +# Termination +@inline interp_getindex(A::AbstractArray{T,N}, I::Dims{N}) where {T,N} = + @inbounds A[I...] # all bounds-checks have already happened """ w = value_weights(degree, δx) @@ -489,6 +445,9 @@ include("lanczos/lanczos_opencv.jl") include("iterate.jl") include("chainrules/chainrules.jl") include("hermite/cubic.jl") +if VERSION >= v"1.6" + include("gpu_support.jl") +end function __init__() @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" include("requires/unitful.jl") diff --git a/src/b-splines/b-splines.jl b/src/b-splines/b-splines.jl index bdb0476d..209e4fb3 100644 --- a/src/b-splines/b-splines.jl +++ b/src/b-splines/b-splines.jl @@ -66,7 +66,7 @@ However, for customized control you may also construct them with where `T` gets computed from the product of `TWeights` and `eltype(coefs)`. (This is equivalent to indicating that you'll be evaluating at locations `itp(x::TWeights, y::TWeights, ...)`.) """ -struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Tuple{Vararg{AbstractUnitRange,N}}} <: AbstractInterpolation{T,N,IT} +struct BSplineInterpolation{T,N,TCoefs<:AbstractArray,IT<:DimSpec{BSpline},Axs<:Indices{N}} <: AbstractInterpolation{T,N,IT} coefs::TCoefs parentaxes::Axs it::IT @@ -78,6 +78,9 @@ function Base.:(==)(o1::BSplineInterpolation, o2::BSplineInterpolation) o1.coefs == o2.coefs end +BSplineInterpolation{T,N}(A::AbstractArray, axs::Indices{N}, it::IT) where {T,N,IT} = + BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, axs, it) + function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,TWeights<:Real,IT<:DimSpec{BSpline}} # String interpolation causes allocation, noinline avoids that unless they get called @noinline err_concrete(IT) = error("The b-spline type must be a concrete type (was $IT)") @@ -98,7 +101,7 @@ function BSplineInterpolation(::Type{TWeights}, A::AbstractArray{Tel,N}, it::IT, else T = typeof(zero(TWeights) * first(A)) end - BSplineInterpolation{T,N,typeof(A),IT,typeof(axs)}(A, fix_axis.(axs), it) + BSplineInterpolation{T,N}(A, fix_axis.(axs), it) end function BSplineInterpolation(A::AbstractArray{Tel,N}, it::IT, axs) where {N,Tel,IT<:DimSpec{BSpline}} diff --git a/src/b-splines/indexing.jl b/src/b-splines/indexing.jl index 4ab7a381..26f16c95 100644 --- a/src/b-splines/indexing.jl +++ b/src/b-splines/indexing.jl @@ -5,7 +5,7 @@ itpinfo(itp) = (tcollect(itpflag, itp), axes(itp)) @inline function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - itp.coefs[wis...] + InterpGetindex(itp)[wis...] end @propagate_inbounds function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N} inds, trailing = split_trailing(itp, x) @@ -17,7 +17,7 @@ end @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) itps = tcollect(itpflag, itp) wis = dimension_wis(value_weights, itps, axes(itp), x) - coefs = coefficients(itp) + coefs = InterpGetindex(itp) ret = [coefs[i...] for i in Iterators.product(wis...)] reshape(ret, shape(wis...)) end @@ -25,7 +25,7 @@ end @propagate_inbounds function gradient(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x) wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x) - return SVector(_gradient(itp.coefs, wis...)) # work around #311 + return SVector(_gradient(InterpGetindex(itp), wis...)) # work around #311 end @inline _gradient(coefs, inds, moreinds...) = (coefs[inds...], _gradient(coefs, moreinds...)...) _gradient(coefs) = () @@ -37,7 +37,7 @@ end @propagate_inbounds function hessian(itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x) wis = weightedindexes((value_weights, gradient_weights, hessian_weights), itpinfo(itp)..., x) - symmatrix(map(inds->itp.coefs[inds...], wis)) + symmatrix(map(inds->InterpGetindex(itp)[inds...], wis)) end @propagate_inbounds function hessian!(dest, itp::BSplineInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} dest .= hessian(itp, x...) diff --git a/src/extrapolation/extrapolation.jl b/src/extrapolation/extrapolation.jl index 43b5edc3..1eb772a4 100644 --- a/src/extrapolation/extrapolation.jl +++ b/src/extrapolation/extrapolation.jl @@ -71,8 +71,8 @@ end eflag = tcollect(etpflag, etp) xs = inbounds_position(eflag, bounds(itp), x, etp, x) g = @inbounds gradient(itp, xs...) - skipni = t->skip_flagged_nointerp(itp, t) - SVector(extrapolate_gradient.(skipni(eflag), skipni(x), skipni(xs), Tuple(g))) + skipni = Base.Fix1(skip_flagged_nointerp, itp) + SVector(map(extrapolate_gradient, skipni(eflag), skipni(x), skipni(xs), Tuple(g))) end end diff --git a/src/gpu_support.jl b/src/gpu_support.jl new file mode 100644 index 00000000..4990f5bd --- /dev/null +++ b/src/gpu_support.jl @@ -0,0 +1,72 @@ +import Adapt: adapt_structure +using Adapt: adapt + +function adapt_structure(to, itp::BSplineInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + T′ = update_eltype(T, coefs′, itp.coefs) + BSplineInterpolation{T′,N}(coefs′, itp.parentaxes, itp.it) +end + +function update_eltype(T, coefs′, coefs) + ET = eltype(coefs′) + ET === eltype(coefs) && return T + WT = tweight(coefs′) + T′ = Base.promote_op(*, WT, ET) + (isconcretetype(T′) || isempty(coefs)) && return T′ + return typeof(zero(WT) * convert(ET, first(coefs))) +end + +function adapt_structure(to, itp::LanczosInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + parentaxes′ = adapt(to, itp.parentaxes) + LanczosInterpolation{eltype(coefs′),N}(coefs′, parentaxes′, itp.it) +end + +function adapt_structure(to, itp::GriddedInterpolation{T,N}) where {T,N} + coefs′ = adapt(to, itp.coefs) + knots′ = adapt(to, itp.knots) + T′ = update_eltype(T, coefs′, itp.coefs) + GriddedInterpolation{T′,N,typeof(coefs′),itptype(itp),typeof(knots′)}(knots′, coefs′, itp.it) +end + +function adapt_structure(to, itp::ScaledInterpolation{T,N,<:Any,IT,RT}) where {T,N,IT,RT<:NTuple{N,AbstractRange}} + ranges = itp.ranges + itp′ = adapt(to, itp.itp) + ScaledInterpolation{eltype(itp′),N,typeof(itp′),IT,RT}(itp′, ranges) +end + +function adapt_structure(to, itp::Extrapolation{T,N}) where {T,N} + et = itp.et + itp′ = adapt(to, itp.itp) + Extrapolation{eltype(itp′),N,typeof(itp′),itptype(itp),typeof(et)}(itp′, et) +end + +import Base.Broadcast: broadcasted, BroadcastStyle +using Base.Broadcast: broadcastable, combine_styles, AbstractArrayStyle +function broadcasted(itp::AbstractInterpolation, args...) + args′ = map(broadcastable, args) + # we overload BroadcastStyle here (try our best to do broadcast on GPU) + style = combine_styles(Ref(itp), args′...) + broadcasted(style, itp, args′...) +end + +""" + Interpolations.root_storage_type(::Type{<:AbstractInterpolation}) -> Type{<:AbstractArray} + +This function returns the type of the root cofficients array of an `AbstractInterpolation`. +Some array wrappers, like `OffsetArray`, should be skipped. +""" +root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:LanczosInterpolation} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:GriddedInterpolation} = root_storage_type(fieldtype(T, 2)) +root_storage_type(::Type{T}) where {T<:OffsetArray} = root_storage_type(fieldtype(T, 1)) +root_storage_type(::Type{T}) where {T<:AbstractArray} = T + +BroadcastStyle(::Type{<:Ref{T}}) where {T<:AbstractInterpolation} = _to_scalar_style(BroadcastStyle(T)) +BroadcastStyle(::Type{T}) where {T<:AbstractInterpolation} = BroadcastStyle(root_storage_type(T)) + +_to_scalar_style(::S) where {S<:AbstractArrayStyle} = S(Val(0)) +_to_scalar_style(S::AbstractArrayStyle{Any}) = S +_to_scalar_style(S) = S diff --git a/src/gridded/indexing.jl b/src/gridded/indexing.jl index aa685120..43664b72 100644 --- a/src/gridded/indexing.jl +++ b/src/gridded/indexing.jl @@ -2,7 +2,7 @@ @inline function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - coefficients(itp)[wis...] + InterpGetindex(itp)[wis...] end @propagate_inbounds function (itp::GriddedInterpolation{T,N})(x::Vararg{Number,M}) where {T,M,N} inds, trailing = split_trailing(itp, x) @@ -19,7 +19,7 @@ end @inline function gradient(itp::GriddedInterpolation{T,N}, x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights, gradient_weights), itpinfo(itp)..., x) - SVector(map(inds->coefficients(itp)[inds...], wis)) + SVector(map(inds->InterpGetindex(itp)[inds...], wis)) end itpinfo(itp::GriddedInterpolation) = (tcollect(itpflag, itp), itp.knots) @@ -87,7 +87,7 @@ rescale_gridded(::typeof(hessian_weights), coefs, Δx) = coefs./Δx.^2 else wis = dimension_wis(value_weights, itps, itp.knots, x) end - coefs = coefficients(itp) + coefs = InterpGetindex(itp) ret = [coefs[i...] for i in Iterators.product(wis...)] reshape(ret, shape(wis...)) end diff --git a/src/lanczos/lanczos.jl b/src/lanczos/lanczos.jl index 57f6940a..62d0ed55 100644 --- a/src/lanczos/lanczos.jl +++ b/src/lanczos/lanczos.jl @@ -29,6 +29,9 @@ struct LanczosInterpolation{T,N,IT <: DimSpec{AbstractLanczos},A <: AbstractArra it::IT end +LanczosInterpolation{T,N}(coefs::AbstractArray{T,N}, parentaxes::NTuple{N,AbstractArray}, it::IT) where {T,N,IT} = + LanczosInterpolation{T,N,IT,typeof(coefs),typeof(parentaxes)}(coefs, parentaxes, it) + @generated degree(::Lanczos{N}) where {N} = :($N) getknots(itp::LanczosInterpolation) = axes(itp) @@ -48,7 +51,7 @@ end @inline function (itp::LanczosInterpolation{T,N})(x::Vararg{Number,N}) where {T,N} @boundscheck (checkbounds(Bool, itp, x...) || Base.throw_boundserror(itp, x)) wis = weightedindexes((value_weights,), itpinfo(itp)..., x) - itp.coefs[wis...] + InterpGetindex(itp)[wis...] end function weightedindex_parts(fs, it::AbstractLanczos, ax::AbstractUnitRange{<:Integer}, x) diff --git a/src/scaling/scaling.jl b/src/scaling/scaling.jl index d94ffafb..8d0cfb49 100644 --- a/src/scaling/scaling.jl +++ b/src/scaling/scaling.jl @@ -209,7 +209,7 @@ function Base.iterate(iter::ScaledIterator) ret === nothing && return nothing item, cistate = ret wis = getindex.(iter.wis, Tuple(item)) - ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis)) + ces = cache_evaluations(InterpGetindex(iter.sitp), indexes(wis[1]), weights(wis[1]), Base.tail(wis)) return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, firstindex(iter.breaks1), ces) end @@ -226,7 +226,7 @@ function Base.iterate(iter::ScaledIterator, state) end # Re-evaluate. We're being a bit lazy here: in some cases, some of the cached values could be reused wis = getindex.(iter.wis, Tuple(item)) - ces = cache_evaluations(iter.sitp.itp.coefs, indexes(wis[1]), weights(wis[1]), Base.tail(wis)) + ces = cache_evaluations(InterpGetindex(iter.sitp), indexes(wis[1]), weights(wis[1]), Base.tail(wis)) return _reduce(+, weights(wis[1]).*ces), ScaledIterState(cistate, isnext1 ? state.ibreak+1 : firstindex(iter.breaks1), ces) end diff --git a/test/core.jl b/test/core.jl index 48e598b3..b126c7a5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -29,18 +29,19 @@ end @testset "Core" begin A = reshape([0], 1, 1, 1, 1, 1) + IA = Interpolations.InterpGetindex(A) wis = ntuple(d->Interpolations.WeightedAdjIndex(1, (1,)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedAdjIndex(1, (1.0,)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,), (1,)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,), (1.0,)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,1), (1,0)), ndims(A)) - @test @inferred(A[wis...]) === 0 + @test @inferred(IA[wis...]) === 0 wis = ntuple(d->Interpolations.WeightedArbIndex((1,1), (1.0,0.0)), ndims(A)) - @test @inferred(A[wis...]) === 0.0 + @test @inferred(IA[wis...]) === 0.0 wi = Interpolations.WeightedAdjIndex(2, (0.2, 0.8)) @test wi*2 === Interpolations.WeightedAdjIndex(2, (0.4, 1.6)) diff --git a/test/gpu_support.jl b/test/gpu_support.jl new file mode 100644 index 00000000..c252a811 --- /dev/null +++ b/test/gpu_support.jl @@ -0,0 +1,106 @@ +using JLArrays, Adapt +JLArrays.allowscalar(false) + +@testset "1d GPU Interpolation" begin + A_x = 1.0:2.0:40.0 + A = [log(x) for x in A_x] + itp = interpolate(A, BSpline(Cubic(Line(OnGrid())))) + jlitp = jl(itp) + idx = 2.0:0.17:19.0 + jlidx = jl(collect(idx)) + @test itp.(idx) == collect(jlitp.(idx)) == collect(jlitp.(jlidx)) + @test gradient.(Ref(itp), idx) == + collect(gradient.(Ref(jlitp), idx)) == + collect(gradient.(Ref(jlitp), jlidx)) + + sitp = scale(itp, A_x) + jlsitp = jl(sitp) + idx = 1.0:0.4:39.0 + jlidx = jl(collect(idx)) + @test sitp.(idx) == collect(jlsitp.(idx)) == collect(jlsitp.(jlidx)) + @test gradient.(Ref(sitp), idx) == + collect(gradient.(Ref(jlsitp), idx)) == + collect(gradient.(Ref(jlsitp), jlidx)) + + + esitp = extrapolate(sitp, Flat()) + jlesitp = jl(esitp) + idx = -1.0:0.84:41.0 + jlidx = jl(collect(idx)) + @test esitp.(idx) == collect(jlesitp.(idx)) == collect(jlesitp.(jlidx)) + @test gradient.(Ref(esitp), idx) == + collect(gradient.(Ref(jlesitp), idx)) == + collect(gradient.(Ref(jlesitp), jlidx)) +end + +@testset "2d GPU Interpolation" begin + A_x = 1.0:2.0:40.0 + A = [log(x + y) for x in A_x, y in 1.0:2.0:40.0] + itp = interpolate(A, (BSpline(Cubic(Line(OnGrid()))), BSpline(Linear()))) + jlitp = jl(itp) + idx = 2.0:0.17:19.0 + jlidx = jl(collect(idx)) + @test itp.(idx, idx') == collect(jlitp.(idx, idx')) == collect(jlitp.(jlidx, jlidx')) + @test gradient.(Ref(itp), idx, idx') == + collect(gradient.(Ref(jlitp), idx, idx')) == + collect(gradient.(Ref(jlitp), jlidx, jlidx')) + @test hessian.(Ref(itp), idx, idx') == + collect(hessian.(Ref(jlitp), idx, idx')) == + collect(hessian.(Ref(jlitp), jlidx, jlidx')) + + sitp = scale(itp, A_x, A_x) + jlsitp = jl(sitp) + idx = 1.0:0.4:39.0 + jlidx = jl(collect(idx)) + @test sitp.(idx, idx') == collect(jlsitp.(idx, idx')) == collect(jlsitp.(jlidx, jlidx')) + @test gradient.(Ref(sitp), idx, idx') == + collect(gradient.(Ref(jlsitp), idx, idx')) == + collect(gradient.(Ref(jlsitp), jlidx, jlidx')) + @test hessian.(Ref(sitp), idx, idx') == + collect(hessian.(Ref(jlsitp), idx, idx')) == + collect(hessian.(Ref(jlsitp), jlidx, jlidx')) + + esitp = extrapolate(sitp, Flat()) + jlesitp = jl(esitp) + idx = -1.0:0.84:41.0 + jlidx = jl(collect(idx)) + @test esitp.(idx, idx') == collect(jlesitp.(idx, idx')) == collect(jlesitp.(jlidx, jlidx')) + # gradient for `extrapolation` is currently broken under CUDA + @test gradient.(Ref(esitp), idx, idx') == + collect(gradient.(Ref(jlesitp), idx, idx')) == + collect(gradient.(Ref(jlesitp), jlidx, jlidx')) +end + +@testset "Lanczos on gpu" begin + X = 1:100 + X = [X; reverse(X)[2:end]] + for N = 2:4 + itp = interpolate(X, Lanczos(N)) + @test itp.(X) == collect(jl(itp).(jl(X))) + end + itp = interpolate(X, Lanczos4OpenCV()) + @test itp.(X) == collect(jl(itp).(jl(X))) +end + +@testset "Gridded on gpu" begin + itp1 = interpolate(Vector.((-1.0:0.02:1.0, -1.0:0.02:1.0)), randn(101, 101), Gridded(Linear())) + itp2 = interpolate((-1.0:0.02:1.0, -1.0:0.02:1.0), randn(101, 101), Gridded(Linear())) + idx = -1.0:0.01:1.0 + jlidx = jl(collect(idx)) + @test itp1.(idx, idx') == collect(jl(itp1).(idx, idx')) == collect(jl(itp1).(jlidx, jlidx')) + @test itp2.(idx, idx') == collect(jl(itp2).(idx, idx')) == collect(jl(itp2).(jlidx, jlidx')) +end + +@testset "eltype after adaption" begin + A_x = 1.0:2.0:40.0 + A = [log(x) for x in A_x] + itp = interpolate(A, BSpline(Cubic(Line(OnGrid())))) + @test eltype(adapt(Array{Float32}, itp)) === Float32 + @test eltype(adapt(Array{Real}, itp)) === Float64 + @test eltype(adapt(Array{Float32}, scale(itp, A_x))) === Float32 + @test eltype(adapt(Array{Float32}, extrapolate(scale(itp, A_x), Flat()))) === Float32 + itp = interpolate((-1:0.2:1, -1:0.2:1), randn(11, 11), Gridded(Linear())) + @test eltype(adapt(Array{Float32}, itp)) === Float32 + itp = interpolate((1.0:0.0, 1.:0.), randn(0, 0), Gridded(Linear())) + @test eltype(adapt(Array{Real}, itp)) isa DataType # we don't care the result +end diff --git a/test/runtests.jl b/test/runtests.jl index 718d0c66..bd312092 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,4 +53,10 @@ const isci = get(ENV, "CI", "") in ("true", "True") # Chain rules interaction include("chainrules.jl") + + if VERSION >= v"1.6" + import Pkg + Pkg.add("JLArrays") + include("gpu_support.jl") + end end