diff --git a/.github/workflows/IntegrationTestRequest.yml b/.github/workflows/IntegrationTestRequest.yml new file mode 100644 index 0000000..d42fcca --- /dev/null +++ b/.github/workflows/IntegrationTestRequest.yml @@ -0,0 +1,14 @@ +name: "Integration Test Request" + +on: + issue_comment: + types: [created] + +jobs: + integrationrequest: + if: | + github.event.issue.pull_request && + contains(fromJSON('["OWNER", "COLLABORATOR", "MEMBER"]'), github.event.comment.author_association) + uses: ITensor/ITensorActions/.github/workflows/IntegrationTestRequest.yml@main + with: + localregistry: https://github.com/ITensor/ITensorRegistry.git diff --git a/Project.toml b/Project.toml index 55cf1b9..d1ca535 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,17 @@ version = "0.3.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" -MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [compat] Adapt = "4.1.1" -ArrayLayouts = "1.11.0" +Compat = "3.47,4.10" ExproniconLite = "0.10.13" LinearAlgebra = "1.10" MLStyle = "0.4.17" -MapBroadcast = "0.1.5" TypeParameterAccessors = "0.2, 0.3" julia = "1.10" diff --git a/docs/make.jl b/docs/make.jl index 5655860..8ba16e2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,7 +16,7 @@ makedocs(; edit_link="main", assets=String[], ), - pages=["Home" => "index.md"], + pages=["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; diff --git a/docs/src/reference.md b/docs/src/reference.md new file mode 100644 index 0000000..7a417ab --- /dev/null +++ b/docs/src/reference.md @@ -0,0 +1,5 @@ +# Reference + +```@autodocs +Modules = [DerivableInterfaces] +``` diff --git a/src/DerivableInterfaces.jl b/src/DerivableInterfaces.jl index 42ac21f..5ba6f2c 100644 --- a/src/DerivableInterfaces.jl +++ b/src/DerivableInterfaces.jl @@ -1,12 +1,19 @@ module DerivableInterfaces +export @derive, @interface +export interface, AbstractInterface, AbstractArrayInterface +export zero! + include("interface_function.jl") include("abstractinterface.jl") include("derive_macro.jl") include("interface_macro.jl") include("wrappedarrays.jl") -include("abstractarrayinterface.jl") -include("defaultarrayinterface.jl") -include("traits.jl") +include("arrayinterface.jl") + +# Specific AbstractArray alternatives + +include("concatenate.jl") +include("zero.jl") end diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl deleted file mode 100644 index 7ea82a5..0000000 --- a/src/abstractarrayinterface.jl +++ /dev/null @@ -1,351 +0,0 @@ -# TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end - -function interface(::Type{<:Broadcast.AbstractArrayStyle}) - return DefaultArrayInterface() -end - -function interface(::Type{<:Broadcast.Broadcasted{Nothing}}) - return DefaultArrayInterface() -end - -function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} - return interface(Style) -end - -# TODO: Define as `Array{T}`. -arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") - -using ArrayLayouts: ArrayLayouts - -@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...) - return ArrayLayouts.layout_getindex(a, I...) -end - -@interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I... -) - # TODO: Change to this once broadcasting in `@interface` is supported: - # @interface interface a[I...] .= value - @interface interface map!(identity, @view(a[I...]), value) - return a -end - -# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or -# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`. -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") -end - -# TODO: Make this more general, use `Base.to_index`. -@interface interface::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} -) where {N} - return @interface interface getindex(a, Tuple(I)...) -end - -# Linear indexing. -@interface interface ::AbstractArrayInterface function Base.getindex( - a::AbstractArray, I::Int -) - return @interface interface getindex(a, CartesianIndices(a)[I]) -end - -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") -end - -# Linear indexing. -@interface interface ::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I::Int -) - return @interface interface setindex!(a, value, CartesianIndices(a)[I]) -end - -# TODO: Make this more general, use `Base.to_index`. -@interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N} -) where {N} - return @interface interface setindex!(a, value, Tuple(I)...) -end - -@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type) - return Broadcast.DefaultArrayStyle{ndims(type)}() -end - -# TODO: Maybe define as `Array{T}(undef, size...)` or -# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`. -# TODO: Use `MethodError`? -@interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} -) - return similar(arraytype(interface, T), size) -end - -@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) - a_dest = similar(a) - return a_dest .= a -end - -# TODO: Use `Base.to_shape(axes)` or -# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`. -# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`). -@interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} -) - return @interface interface similar(a, T, Base.to_shape(axes)) -end - -@interface interface::AbstractArrayInterface function Base.similar( - bc::Broadcast.Broadcasted, T::Type, axes::Tuple -) - # `arraytype(::AbstractInterface)` determines the default array type associated with the interface. - return similar(arraytype(interface, T), axes) -end - -using MapBroadcast: Mapped -# TODO: Turn this into an `@interface AbstractArrayInterface` function? -# TODO: Look into `SparseArrays.capturescalars`: -# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted -) - m = Mapped(bc) - return @interface interface map!(m.f, a_dest, m.args...) -end - -# This captures broadcast expressions such as `a .= 2`. -# Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet: -# https://github.com/JuliaLang/julia/issues/31677 -# https://github.com/JuliaLang/julia/pull/40632 -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} -) - @interface interface fill!(a_dest, bc.f(bc.args...)[]) -end - -# This is defined in this way so we can rely on the Broadcast logic -# for determining the destination of the operation (element type, shape, etc.). -@interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...) - # TODO: Should this be `@interface interface ...`? That doesn't support - # broadcasting yet. - # Broadcasting is used here to determine the destination array but that - # could be done manually here. - return f.(as...) -end - -# TODO: Maybe define as -# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`. -# TODO: Use `MethodError`? -@interface ::AbstractArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return error("Not implemented.") -end - -@interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value) - @interface interface map!(Returns(value), a, a) -end - -using ArrayLayouts: zero! - -# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` -# and is useful for sparse array logic, since it can be used to empty -# the sparse array storage. -# We use a single function definition to minimize method ambiguities. -@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray) - # More generally, the first codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! - return @interface interface map!(f, a, a) -end - -# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`. -# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier -# to handle the logic of dropping all elements of the sparse array when possible. -# We use a single function definition to minimize method ambiguities. -@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray) - # More generally, the first codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - if eltype(a) <: Number - return @interface interface zero!(similar(a)) - end - return @interface interface map(interface(zero), a) -end - -@interface ::AbstractArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return error("Not implemented.") -end - -# TODO: Generalize to multiple inputs. -@interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...) - return @interface interface mapreduce(identity, f, a; kwargs...) -end - -@interface interface::AbstractArrayInterface function Base.all(a::AbstractArray) - return @interface interface reduce(&, a; init=true) -end - -@interface interface::AbstractArrayInterface function Base.all( - f::Function, a::AbstractArray -) - return @interface interface mapreduce(f, &, a; init=true) -end - -@interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray) - return @interface interface all(iszero, a) -end - -@interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray) - return @interface interface all(isreal, a) -end - -@interface interface::AbstractArrayInterface function Base.permutedims!( - a_dest::AbstractArray, a_src::AbstractArray, perm -) - return @interface interface map!(identity, a_dest, PermutedDimsArray(a_src, perm)) -end - -@interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) -end - -@interface interface::AbstractArrayInterface function Base.copy!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) -end - -using LinearAlgebra: LinearAlgebra -# This then requires overloading: -# function ArrayLayouts.materialize!( -# m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout} -# ) -# # Matmul implementation. -# end -@interface ::AbstractArrayInterface function LinearAlgebra.mul!( - a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number -) - return ArrayLayouts.mul!(a_dest, a1, a2, α, β) -end - -@interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - # TODO: Define as `UnknownLayout()`? - # TODO: Use `MethodError`? - return error("Not implemented.") -end - -## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`, -## requires adding `ndims` type parameter to `AbstractArrayInterface`. -## @interface ::AbstractMatrixInterface function Base.*(a1, a2) -## return ArrayLayouts.mul(a1, a2) -## end - -# Concatenation - -axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) -function axis_cat( - a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... -) - return axis_cat(axis_cat(a1, a2), a_rest...) -end - -unval(x) = x -unval(::Val{x}) where {x} = x - -function cat_axes(as::AbstractArray...; dims) - return ntuple(length(first(axes.(as)))) do dim - return if dim in unval(dims) - axis_cat(map(axes -> axes[dim], axes.(as))...) - else - axes(first(as))[dim] - end - end -end - -function cat! end - -# Represents concatenating `args` over `dims`. -struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims} - args::Args -end -to_cat_dims(dim::Integer) = Int(dim) -to_cat_dims(dim::Int) = (dim,) -to_cat_dims(dims::Val) = to_cat_dims(unval(dims)) -to_cat_dims(dims::Tuple) = dims -Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args) -cat_dims(::Cat{<:Any,dims}) where {dims} = dims - -function Base.axes(a::Cat) - return cat_axes(a.args...; dims=cat_dims(a)) -end -Base.eltype(a::Cat) = promote_type(eltype.(a.args)...) -function Base.similar(a::Cat) - ax = axes(a) - elt = eltype(a) - # TODO: This drops GPU information, maybe use MemoryLayout? - return similar(arraytype(interface(a.args...), elt), ax) -end - -# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 -# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation -# This is very similar to the `Base.cat` implementation but handles zero values better. -function cat_offset!( - a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims -) - inds = ntuple(ndims(a_dest)) do dim - dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) - end - a_dest[inds...] = a1 - new_offsets = ntuple(ndims(a_dest)) do dim - dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim] - end - cat_offset!(a_dest, new_offsets, a_rest...; dims) - return a_dest -end -function cat_offset!(a_dest::AbstractArray, offsets; dims) - return a_dest -end - -@interface ::AbstractArrayInterface function cat!( - a_dest::AbstractArray, as::AbstractArray...; dims -) - offsets = ntuple(zero, ndims(a_dest)) - # TODO: Fill `a_dest` with zeros if needed using `zero!`. - cat_offset!(a_dest, offsets, as...; dims) - return a_dest -end - -function cat_along(dims, as::AbstractArray...) - return @interface interface(as...) cat_along(dims, as...) -end - -@interface interface::AbstractArrayInterface function cat_along(dims, as::AbstractArray...) - a_dest = similar(Cat(as...; dims)) - @interface interface cat!(a_dest, as...; dims) - return a_dest -end - -@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims) - return @interface interface cat_along(dims, as...) -end - -# TODO: Use `@derive` instead: -# ```julia -# @derive (T=AbstractArray,) begin -# cat!(a_dest::AbstractArray, as::T...; dims) -# end -# ``` -function cat!(a_dest::AbstractArray, as::AbstractArray...; dims) - return @interface interface(as...) cat!(a_dest, as...; dims) -end diff --git a/src/abstractinterface.jl b/src/abstractinterface.jl index 1cee893..65c2790 100644 --- a/src/abstractinterface.jl +++ b/src/abstractinterface.jl @@ -6,6 +6,7 @@ interface(x1, x_rest...) = combine_interfaces(x1, x_rest...) # Adapted from `Base.Broadcast.combine_styles`. # Get the combined interfaces of the input objects. +# TODO: make rule definitions symmetric function combine_interfaces(x1, x2, x_rest...) return combine_interfaces(combine_interfaces(x1, x2), x_rest...) end diff --git a/src/arrayinterface.jl b/src/arrayinterface.jl new file mode 100644 index 0000000..b03e866 --- /dev/null +++ b/src/arrayinterface.jl @@ -0,0 +1,71 @@ +""" +`AbstractArrayInterface <: AbstractInterface` is the abstract supertype for any interface +using Base: BroadcastStyle +associated with an `AbstractArray` type. +""" +abstract type AbstractArrayInterface <: AbstractInterface end + +""" +`DefaultArrayInterface()` is the interface indicating that an object behaves as an +array, but hasn't defined a specialized interface. In the absence of overrides from other +`AbstractArrayInterface` arguments, this results in non-overdubbed function calls. +""" +struct DefaultArrayInterface <: AbstractArrayInterface end +# this effectively has almost no implementations, as they are inherited from the supertype +# either explicitly or will throw an error. It is simply a concrete instance to use the +# abstractarrayinterface implementations. + +using TypeParameterAccessors: parenttype +# attempt to figure out interface type from parent +function interface(::Type{A}) where {A<:AbstractArray} + pA = parenttype(A) + return pA === A ? DefaultArrayInterface() : interface(pA) +end + +function interface(::Type{B}) where {B<:Broadcast.AbstractArrayStyle} + return DefaultArrayInterface() +end + +function interface(::Type{B}) where {B<:Broadcast.Broadcasted} + return interface(Broadcast.BroadcastStyle(B)) +end + +# Combination rules +# ----------------- +function combine_interface_rule(::DefaultArrayInterface, I::AbstractArrayInterface) + return I +end +function combine_interface_rule(I::AbstractArrayInterface, ::DefaultArrayInterface) + return I +end +function combine_interface_rule(::DefaultArrayInterface, ::DefaultArrayInterface) + return DefaultArrayInterface() +end + +# Fallback implementations +# ------------------------ +# whenever we want to overload new interface implementations, we better have a fallback that +# sends us back to the default implementation. + +@interface ::AbstractArrayInterface Base.getindex(A::AbstractArray, I...) = ( + @inline; getindex(A, I...) +) +@interface ::AbstractArrayInterface Base.setindex!(A::AbstractArray, v, I...) = ( + @inline; setindex!(A, v, I...) +) + +@interface ::AbstractArrayInterface Base.similar( + A::AbstractArray, ::Type{T}, axes +) where {T} = similar(A, T, axes) + +@interface ::AbstractArrayInterface Base.map(f, A::AbstractArray, As::AbstractArray...) = + map(f, A, As...) +@interface ::AbstractArrayInterface Base.map!(f, A::AbstractArray, As::AbstractArray...) = + map!(f, A, As...) + +@interface ::AbstractArrayInterface Base.reduce( + op, A::AbstractArray, As::AbstractArray... +) = reduce(op, A, As...) +@interface ::AbstractArrayInterface Base.mapreduce( + f, op, A::AbstractArray, As::AbstractArray... +) = mapreduce(f, op, A, As...) diff --git a/src/concatenate.jl b/src/concatenate.jl new file mode 100644 index 0000000..287a6d1 --- /dev/null +++ b/src/concatenate.jl @@ -0,0 +1,161 @@ +""" + module Concatenate + +Alternative implementation for `Base.cat` through [`cat(!)`](@ref cat). + +This is mostly a copy of the Base implementation, with the main difference being +that the destination is chosen based on all inputs instead of just the first. + +Additionally, we have an intermediate representation in terms of a Concatenated object, +reminiscent of how Broadcast works. + +The various entry points for specializing behavior are: + +* Destination selection can be achieved through + + Base.similar(concat::Concatenated{Interface}, ::Type{T}, axes) where {Interface} + +* Implementation for moving one or more arguments into the destionation through + + copy_offset!(dest, shape, catdims, offsets, args...) + copy_offset1!(dest, shape, catdims, offsets, x) + +* Custom implementations: + + Base.copy(concat::Concatenated{Interface}) # custom implementation of cat + Base.copyto!(dest, concat::Concatenated{Interface}) # custom implementation of cat! based on interface + Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest) +""" +module Concatenate + +using Compat: @compat +@compat public Concatenated + +using Base: promote_eltypeof +using ..DerivableInterfaces: DerivableInterfaces, AbstractInterface, interface, zero! + +""" + Concatenated{Interface,Dims,Args<:Tuple} + +Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide +hooks to customize the implementation. +""" +struct Concatenated{Interface,Dims,Args<:Tuple} + interface::Interface + dims::Val{Dims} + args::Args + + function Concatenated( + interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple + ) where {Dims} + return new{typeof(interface),Dims,typeof(args)}(interface, dims, args) + end + function Concatenated(dims, args::Tuple) + return Concatenated(interface(args...), dims, args) + end + function Concatenated{Interface}(dims, args) where {Interface} + return Concatenated(Interface(), dims, args) + end + function Concatenated{Interface,Dims}(args) where {Interface,Dims} + return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args) + end +end + +dims(::Concatenated{A,D}) where {A,D} = D +DerivableInterfaces.interface(concat::Concatenated) = concat.interface + +concatenated(args...; dims) = Concatenated(Val(dims), args) + +function Base.convert( + ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args} +) where {NewInterface,Dims,Args} + return Concatenated{NewInterface}( + concat.dims, concat.args + )::Concatenated{NewInterface,Dims,Args} +end + +# allocating the destination container +# ------------------------------------ +Base.similar(concat::Concatenated) = similar(concat, eltype(cat)) +Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(cat)) +function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} + return similar(interface(concat), T, ax) +end + +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) + +# For now, simply couple back to base implementation +function Base.axes(concat::Concatenated) + catdims = Base.dims2cat(dims(concat)) + return Base.cat_size_shape(catdims, concat.args...) +end + +# Main logic +# ---------- +""" + Concatenate.cat(args...; dims) + +Concatenate the supplied `args` along dimensions `dims`. + +See also [`cat!`](@ref). +""" +cat(args...; dims) = Base.materialize(concatenated(args...; dims)) +Base.materialize(concat::Concatenated) = copy(concat) + +""" + Concatenate.cat!(dest, args...; dims) + +Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. +""" +function cat!(dest, args...; dims) + Base.materialize!(dest, concatenated(dims, args...)) + return dest +end +Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) + +Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) + +# default falls back to replacing interface with Nothing +# this permits specializing on typeof(dest) without ambiguities +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. +@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) = + copyto!(dest, convert(Concatenated{Nothing}, concat)) + +# couple back to Base implementation if no specialization exists: +# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852 +function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) + catdims = Base.dims2cat(dims(concat)) + shape = Base.cat_size_shape(catdims, concat.args...) + count(!iszero, catdims)::Int > 1 && zero!(dest) + return Base.__cat(dest, shape, catdims, concat.args...) +end + +# Array implementation +# -------------------- +# Write in terms of a generic cat_offset!, which in term aims to specialize on 1 argument +# at a time via cat_offset1! to avoid having to write too many specializations +# function cat_offset!(dest, shape, catdims, offsets, x, X...) +# dest, newoffsets = cat_offset1!(dest, shape, catdims, offsets, x) +# return cat_offset!(dest, shape, catdims, newoffsets, X...) +# end +# cat_offset!(dest, shape, catdims, offsets) = dest + +# this is the typical specialization point, which is no longer vararg. +# it simply computes indices and calls out to copy_or_fill!, so if that +# pattern works you can also overload that function +# function cat_offset1!(dest, shape, catdims, offsets, x) +# inds = ntuple(length(offsets)) do i +# (i ≤ length(catdims) && catdims[i]) ? offsets[i] .+ axes(x, i) : 1:shape[i] +# end +# copy_or_fill!(dest, inds, x) +# newoffsets = ntuple(length(offsets)) do i +# (i ≤ length(catdims) && catdims[i]) ? offsets[i] + size(x, i) : offsets[i] +# end +# return dest, newoffsets +# end + +# copy of Base._copy_or_fill! +# copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +# copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + +end diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl deleted file mode 100644 index bc2f9b6..0000000 --- a/src/defaultarrayinterface.jl +++ /dev/null @@ -1,32 +0,0 @@ -# TODO: Add `ndims` type parameter. -struct DefaultArrayInterface <: AbstractArrayInterface end - -using TypeParameterAccessors: parenttype -function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface() - return interface(parenttype(a)) -end - -@interface ::DefaultArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return Base.getindex(a, I...) -end - -@interface ::DefaultArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return Base.setindex!(a, value, I...) -end - -@interface ::DefaultArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return Base.map!(f, a_dest, a_srcs...) -end - -@interface ::DefaultArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return Base.mapreduce(f, op, as...; kwargs...) -end diff --git a/src/derive_macro.jl b/src/derive_macro.jl index 0d17527..2799b67 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -10,7 +10,7 @@ argname(i::Int) = Symbol(:arg, i) function rmlines(expr) return @match expr begin e::Expr => Expr(e.head, filter(!isnothing, map(rmlines, e.args))...) - _::LineNumberNode => nothing + # _::LineNumberNode => nothing a => a end end @@ -68,26 +68,6 @@ function derive_expr(interface::Union{Symbol,Expr}, types::Expr, funcs::Expr) end end -#== -```julia -@derive SparseArrayDOK AbstractArrayOps -``` -==# -function derive_expr(type::Union{Symbol,Expr}, trait::Symbol) - return derive_trait(type, trait) -end - -#== -```julia -@derive SparseArrayInterface() SparseArrayDOK AbstractArrayOps -``` -==# -function derive_expr( - interface::Union{Symbol,Expr}, types::Union{Symbol,Expr}, trait::Symbol -) - return derive_trait(interface, types, trait) -end - function derive_funcs(args...) interface_and_or_types = Base.front(args) funcs = last(args) @@ -116,7 +96,7 @@ function replace_typevars(types::Expr, func::Expr) :($x = $y) => (x, y) end # TODO: Handle type parameters in other positions besides the first one. - new_args = map(args) do arg + new_args = map(new_args) do arg return @match arg begin :(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T}) :(::$T...) => T == typevar ? :(::$type...) : :(::$T...) @@ -150,6 +130,8 @@ function derive_func(interface_or_types::Union{Symbol,Expr}, func::Expr) return derive_interface_func(interface, func) end +derive_func(::Union{Symbol,Expr}, l::LineNumberNode) = l + #= ```julia @derive (T=SparseArrayDOK,) Base.getindex(::T, ::Int...) @@ -205,9 +187,7 @@ function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) # TODO: Use the `@interface` macro rather than `DerivableInterfaces.call` # directly, in case we want to change the implementation. body_args = [interface; name; body_args...] - body_name = @match name begin - :($M.$f) => :(DerivableInterfaces.call) - end + body_name = :(DerivableInterfaces.call) # TODO: Remove defaults from `kwargs`. _, body, _ = split_function( codegen_ast(JLFunction(; name=body_name, args=body_args, kwargs)) @@ -217,26 +197,3 @@ function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) # namespace when `@derive` is called. return globalref_derive(codegen_ast(jlfn)) end - -#= -```julia -@derive SparseArrayInterface() SparseArrayDOK AbstractArrayOps -``` -=# -function derive_trait( - interface::Union{Symbol,Expr}, type::Union{Symbol,Expr}, trait::Symbol -) - funcs = Expr(:block, derive(Val(trait), type).args...) - return derive_funcs(interface, funcs) -end - -#= -```julia -@derive SparseArrayDOK AbstractArrayOps -``` -=# -function derive_trait(type::Union{Symbol,Expr}, trait::Symbol) - types = :((T=$type,)) - funcs = Expr(:block, derive(Val(trait), :T).args...) - return derive_funcs(types, funcs) -end diff --git a/src/interface_function.jl b/src/interface_function.jl index 0ebde58..62d64b2 100644 --- a/src/interface_function.jl +++ b/src/interface_function.jl @@ -1,15 +1,24 @@ -#= -Rewrite `f(args...)` to `DerivableInterfaces.call(interface, f, args...)`. -Similar to `Cassette.overdub`. +""" + call(interface, f, args...; kwargs...) -This errors for debugging, but probably should be defined as: -```julia -call(interface, f, args...) = f(args...) -``` -=# -call(interface, f, args...; kwargs...) = error("Not implemented") +Call the overdubbed function implementing `f(args...; kwargs...)` for a given interface. -# Change the behavior of a function to use a certain interface. +See also [`@interface`](@ref). +""" +call(interface, f, args...; kwargs...) = throw(MethodError(interface(f), args)) +# TODO: do we want to methoderror for `call` instead? + +""" + struct InterfaceFunction{I,F} <: Function + +Callable struct to overdub a function `f::F` with a custom implementation based on +an interface `interface::I`. + +## Fields + +- `interface::I`: interface struct +- `f::F`: function to overdub +""" struct InterfaceFunction{Interface,F} <: Function interface::Interface f::F diff --git a/src/traits.jl b/src/traits.jl deleted file mode 100644 index a05bd75..0000000 --- a/src/traits.jl +++ /dev/null @@ -1,59 +0,0 @@ -using ArrayLayouts: ArrayLayouts -using LinearAlgebra: LinearAlgebra - -# TODO: Create a macro: -#= -``` -@derive_def AbstractArrayOps T begin - Base.getindex(::T, ::Any...) - Base.getindex(::T, ::Int...) - Base.setindex!(::T, ::Any, ::Int...) - Base.similar(::T, ::Type, ::Tuple{Vararg{Int}}) -end -``` -=# -# TODO: Define an `AbstractMatrixOps` trait, which is where -# matrix multiplication should be defined (both `mul!` and `*`). -#= -```julia -@derive SparseArrayDOK AbstractArrayOps -@derive SparseArrayInterface SparseArrayDOK AbstractArrayOps -``` -=# -function derive(::Val{:AbstractArrayOps}, type) - return quote - Base.getindex(::$type, ::Any...) - Base.getindex(::$type, ::Int...) - Base.setindex!(::$type, ::Any, ::Any...) - Base.setindex!(::$type, ::Any, ::Int...) - Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}}) - Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}}) - Base.copy(::$type) - Base.copy!(::AbstractArray, ::$type) - Base.copyto!(::AbstractArray, ::$type) - Base.map(::Any, ::$type...) - Base.map!(::Any, ::AbstractArray, ::$type...) - Base.mapreduce(::Any, ::Any, ::$type...; kwargs...) - Base.reduce(::Any, ::$type...; kwargs...) - Base.all(::Function, ::$type) - Base.all(::$type) - Base.iszero(::$type) - Base.real(::$type) - Base.fill!(::$type, ::Any) - ArrayLayouts.zero!(::$type) - Base.zero(::$type) - Base.permutedims!(::Any, ::$type, ::Any) - Broadcast.BroadcastStyle(::Type{<:$type}) - Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) - Base.cat(::$type...; kwargs...) - ArrayLayouts.MemoryLayout(::Type{<:$type}) - LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) - end -end - -function derive(::Val{:AbstractArrayStyleOps}, type) - return quote - Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple) - Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type}) - end -end diff --git a/src/zero.jl b/src/zero.jl new file mode 100644 index 0000000..dae9277 --- /dev/null +++ b/src/zero.jl @@ -0,0 +1,8 @@ +""" + zero!(x::AbstractArray) + +In-place function for zero-ing out an array. +""" +zero!(x::AbstractArray) = @interface interface(x) zero!(x) + +@interface ::AbstractArrayInterface zero!(x::AbstractArray) = fill!(x, zero(eltype(x))) diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl deleted file mode 100644 index bf54893..0000000 --- a/test/SparseArrayDOKs.jl +++ /dev/null @@ -1,271 +0,0 @@ -module SparseArrayDOKs - -isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...) -getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...) -getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...) -function setstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setstoredindex!(a, value, Tuple(I)...) -end -function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setunstoredindex!(a, value, Tuple(I)...) -end - -# A view of the stored values of an array. -# Similar to: `@view a[collect(eachstoredindex(a))]`, but the issue -# with that is it returns a `SubArray` wrapping a sparse array, which -# is then interpreted as a sparse array. Also, that involves extra -# logic for determining if the indices are stored or not, but we know -# the indices are stored. -struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} - array::A - storedindices::I -end -StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a))) -Base.size(a::StoredValues) = size(a.storedindices) -Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) -function Base.setindex!(a::StoredValues, value, I::Int) - return setstoredindex!(a.array, value, a.storedindices[I]) -end - -storedvalues(a::AbstractArray) = StoredValues(a) - -using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout -using DerivableInterfaces: - DerivableInterfaces, - @array_aliases, - @derive, - @interface, - AbstractArrayInterface, - interface -using LinearAlgebra: LinearAlgebra - -# Define an interface. -struct SparseArrayInterface <: AbstractArrayInterface end - -# Define interface functions. -@interface ::SparseArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - !isstored(a, I...) && return getunstoredindex(a, I...) - return getstoredindex(a, I...) -end -@interface ::SparseArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - if !isstored(a, I...) - iszero(value) && return a - setunstoredindex!(a, value, I...) - return a - end - setstoredindex!(a, value, I...) - return a -end - -struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end -SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() - -DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface() - -@derive SparseArrayStyle AbstractArrayStyleOps - -DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T} - -# Interface functions. -@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) - return SparseArrayStyle{ndims(type)}() -end - -struct SparseLayout <: MemoryLayout end - -@interface ::SparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - return SparseLayout() -end - -@interface ::SparseArrayInterface function Base.map!( - f, a_dest::AbstractArray, as::AbstractArray... -) - # TODO: Define a function `preserves_unstored(a_dest, f, as...)` - # to determine if a function preserves the stored values - # of the destination sparse array. - # The current code may be inefficient since it actually - # accesses an unstored element, which in the case of a - # sparse array of arrays can allocate an array. - # Sparse arrays could be expected to define a cheap - # unstored element allocator, for example - # `get_prototypical_unstored(a::AbstractArray)`. - I = first(eachindex(as...)) - preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) - if !preserves_unstored - # Doesn't preserve unstored values, loop over all elements. - for I in eachindex(as...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end - end - # TODO: Define `eachstoredindex(as...)`. - for I in union(eachstoredindex.(as)...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end - return a_dest -end - -@interface ::SparseArrayInterface function Base.mapreduce( - f, op, a::AbstractArray; kwargs... -) - # TODO: Need to select a better `init`. - return mapreduce(f, op, storedvalues(a); kwargs...) -end - -# ArrayLayouts functionality. - -function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple) - a_dest = similar(a) - a_dest .= a - return a_dest -end - -function ArrayLayouts.materialize!( - m::MatMulMatAdd{<:SparseLayout,<:SparseLayout,<:SparseLayout} -) - a_dest, a1, a2, α, β = m.C, m.A, m.B, m.α, m.β - for I1 in eachstoredindex(a1) - for I2 in eachstoredindex(a2) - if I1[2] == I2[1] - I_dest = CartesianIndex(I1[1], I2[2]) - a_dest[I_dest] = a1[I1] * a2[I2] * α + a_dest[I_dest] * β - end - end - end - return a_dest -end - -# Sparse array minimal interface -using LinearAlgebra: Adjoint -function isstored(a::Adjoint, i::Int, j::Int) - return isstored(parent(a), j, i) -end -function getstoredindex(a::Adjoint, i::Int, j::Int) - return getstoredindex(parent(a), j, i)' -end -function getunstoredindex(a::Adjoint, i::Int, j::Int) - return getunstoredindex(parent(a), j, i)' -end -function eachstoredindex(a::Adjoint) - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) -end - -perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p -iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip - -# TODO: Use `Base.PermutedDimsArrays.genperm` or -# https://github.com/jipolanco/StaticPermutations.jl? -genperm(v, perm) = map(j -> v[j], perm) - -function isstored(a::PermutedDimsArray, I::Int...) - return isstored(parent(a), genperm(I, iperm(a))...) -end -function getstoredindex(a::PermutedDimsArray, I::Int...) - return getstoredindex(parent(a), genperm(I, iperm(a))...) -end -function getunstoredindex(a::PermutedDimsArray, I::Int...) - return getunstoredindex(parent(a), genperm(I, iperm(a))...) -end -function eachstoredindex(a::PermutedDimsArray) - return map(collect(eachstoredindex(parent(a)))) do I - return CartesianIndex(genperm(I, perm(a))) - end -end - -tuple_oneto(n) = ntuple(identity, n) -## This is an optimization for `storedvalues` for DOK. -## function valuesview(d::Dict, keys) -## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]] -## end - -function eachstoredparentindex(a::SubArray) - return filter(eachstoredindex(parent(a))) do I - return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) - end -end -function storedvalues(a::SubArray) - return @view parent(a)[collect(eachstoredparentindex(a))] -end -function isstored(a::SubArray, I::Int...) - return isstored(parent(a), Base.reindex(parentindices(a), I)...) -end -function getstoredindex(a::SubArray, I::Int...) - return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...) -end -function getunstoredindex(a::SubArray, I::Int...) - return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...) -end -function setstoredindex!(a::SubArray, value, I::Int...) - return setstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) -end -function setunstoredindex!(a::SubArray, value, I::Int...) - return setunstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) -end -function eachstoredindex(a::SubArray) - nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d - return !(parentindices(a)[d] isa Real) - end - return collect(( - CartesianIndex( - map(nonscalardims) do d - return findfirst(==(I[d]), parentindices(a)[d]) - end, - ) for I in eachstoredparentindex(a) - )) -end - -# Define a type that will derive the interface. -struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storage::Dict{CartesianIndex{N},T} - size::NTuple{N,Int} -end -storage(a::SparseArrayDOK) = a.storage -Base.size(a::SparseArrayDOK) = a.size -function SparseArrayDOK{T}(size::Int...) where {T} - N = length(size) - return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) -end -# Used in `Base.similar`. -function SparseArrayDOK{T}(::UndefInitializer, size::Tuple{Vararg{Int}}) where {T} - return SparseArrayDOK{T}(size...) -end -function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storage(a)) -end -function getstoredindex(a::SparseArrayDOK, I::Int...) - return storage(a)[CartesianIndex(I)] -end -function getunstoredindex(a::SparseArrayDOK, I::Int...) - return zero(eltype(a)) -end -function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a -end -function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a -end -eachstoredindex(a::SparseArrayDOK) = keys(storage(a)) -storedlength(a::SparseArrayDOK) = length(eachstoredindex(a)) - -function ArrayLayouts.zero!(a::SparseArrayDOK) - empty!(storage(a)) - return a -end - -# Specify the interface the type adheres to. -DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() - -# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. -@array_aliases SparseArrayDOK - -# DerivableInterfaces the interface for the type. -@derive AnySparseArrayDOK AbstractArrayOps - -end diff --git a/test/runtests.jl b/test/runtests.jl index 2b74a89..1c52c3e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,9 +24,11 @@ isexamplefile(fn) = # tests in groups based on folder structure for testgroup in filter(isdir, readdir(@__DIR__)) if GROUP == "ALL" || GROUP == uppercase(testgroup) - for file in filter(istestfile, readdir(joinpath(@__DIR__, testgroup); join=true)) - @eval @safetestset $(last(splitdir(file))) begin - include($file) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) end end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a01787c..1798cfd 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - # Aqua.test_all(DerivableInterfaces) + Aqua.test_all(DerivableInterfaces) end diff --git a/test/test_basics.jl b/test/test_basics.jl index d088f5a..282598a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,153 +1,61 @@ -using ArrayLayouts: zero! -include("SparseArrayDOKs.jl") -using .SparseArrayDOKs: SparseArrayDOK, storedlength -using Test: @test, @testset - -elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) -@testset "DerivableInterfaces" for elt in elts - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - @test a isa SparseArrayDOK{elt,2} - @test size(a) == (2, 2) - @test a[1, 1] == 0 - @test a[1, 1, 1] == 0 - @test a[1, 2] == 12 - @test a[1, 2, 1] == 12 - @test storedlength(a) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3)))) - @test b isa SparseArrayDOK{Float32,2} - @test b == zeros(Float32, 3, 3) - @test size(b) == (3, 3) - end - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - bc = Broadcast.Broadcasted(x -> 2x, (a,)) - copyto!(b, bc) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(3, 3, 3) - a[1, 2, 3] = 123 - b = permutedims(a, (2, 3, 1)) - @test b isa SparseArrayDOK{elt,3} - @test b[2, 3, 1] == 123 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = copy(a') - @test b isa SparseArrayDOK{elt,2} - @test b == [0 0; 12 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = map(x -> 2x, a) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a * a' - @test b isa SparseArrayDOK{elt,2} - @test b == [144 0; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a .+ 2 .* a' - @test b isa SparseArrayDOK{elt,2} - @test b == [0 12; 24 0] - @test storedlength(b) == 2 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a[1:2, 2] - @test b isa SparseArrayDOK{elt,1} - @test b == [12, 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - @test iszero(a) - a[2, 1] = 21 - a[1, 2] = 12 - @test !iszero(a) - @test isreal(a) - @test sum(a) == 33 - @test mapreduce(x -> 2x, +, a) == 66 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - copyto!(b, a) - @test b isa SparseArrayDOK{elt,2} - @test b == a - @test b[1, 2] == 12 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a .= 2 - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end - - a = SparseArrayDOK{elt}(2, 2) - fill!(a, 2) - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end +using Test: @test, @testset, @test_throws +using DerivableInterfaces: DerivableInterfaces as DI +using DerivableInterfaces: @derive, @interface, AbstractArrayInterface + +# Test setup +# ---------- +struct MyArray{T,N} <: AbstractArray{T,N} + parent::Array{T,N} +end +Base.parent(A::MyArray) = A.parent - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - zero!(a) - @test iszero(a) - @test iszero(storedlength(a)) +@derive (T=MyArray,) Base.getindex(::T, ::Int) - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = zero(a) - @test b isa SparseArrayDOK{elt,2} - @test iszero(b) - @test iszero(storedlength(b)) +# Interfacetype +struct MyInterface <: DI.AbstractArrayInterface end +DI.interface(::Type{A}) where {A<:MyArray} = MyInterface() - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +const f_ctr = Ref(0) # used to verify if function was actually called +@interface ::MyInterface function Base.getindex(A::MyArray, i::Int) + f_ctr[] += 1 + return getindex(parent(A), i) +end +@interface ::MyInterface function Base.getindex(A::AbstractArray, i::Int) + f_ctr[] += 1 + return getindex(A::AbstractArray, i::Int) +end - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] = a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +@derive (TA=Any, TB=Any) f(::TA, ::TB) +@interface ::AbstractArrayInterface f(A, B) = -1 +@interface ::MyInterface f(A, B) = 1 + +# Tests +# ----- +# TODO: test type stability +@testset "@derived types" begin + ctr = f_ctr[] + A = rand(Int, 3) + B = MyArray(A) + @test A[1] == B[1] + @test f_ctr[] == ctr + 1 +end - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - c = @view b[2:3, 2:3] - c .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 +@testset "using @interface functions for non-derived types" begin + ctr = f_ctr[] + A = zeros(Int, 3) + @test A[1] == @interface MyInterface() A[1] + @test f_ctr[] == ctr + 1 +end - a1 = SparseArrayDOK{elt}(2, 2) - a1[1, 2] = 12 - a2 = SparseArrayDOK{elt}(2, 2) - a2[2, 1] = 21 - b = cat(a1, a2; dims=(1, 2)) - @test b isa SparseArrayDOK{elt,2} - @test storedlength(b) == 2 - @test b[1, 2] == 12 - @test b[4, 3] == 21 +@testset "interface promotion rules" begin + # DefaultArrayInterface should give default + @test f(zeros(1), zeros(1)) == -1 + @test f(zeros(1), zeros(1, 1)) == -1 + # MyInterface + @test f(MyArray(zeros(1)), MyArray(zeros(1))) == 1 + # Mix + @test f(MyArray(zeros(1)), zeros(1)) == 1 + @test f(zeros(1), MyArray(zeros(1))) == 1 + # undefined mix + @test_throws ErrorException f((1,), zeros(1)) == -1 end diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index d12bade..55be97c 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -33,6 +33,6 @@ end @testset "Broadcast.DefaultArrayStyle" begin @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() - @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface() + @test interface(Broadcast.broadcasted(+, randn(2), randn(2))) == + DefaultArrayInterface() end