diff --git a/Project.toml b/Project.toml index 460f026..3dda832 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +version = "0.5.6" authors = ["ITensor developers and contributors"] -version = "0.5.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/make.jl b/docs/make.jl index 1ac9178..11f5f43 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,23 @@ using DerivableInterfaces: DerivableInterfaces using Documenter: Documenter, DocMeta, deploydocs, makedocs DocMeta.setdocmeta!( - DerivableInterfaces, :DocTestSetup, :(using DerivableInterfaces); recursive=true + DerivableInterfaces, :DocTestSetup, :(using DerivableInterfaces); recursive = true ) include("make_index.jl") makedocs(; - modules=[DerivableInterfaces], - authors="ITensor developers and contributors", - sitename="DerivableInterfaces.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/DerivableInterfaces.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [DerivableInterfaces], + authors = "ITensor developers and contributors", + sitename = "DerivableInterfaces.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/DerivableInterfaces.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/DerivableInterfaces.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/DerivableInterfaces.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 3bca76a..66ce11c 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using DerivableInterfaces: DerivableInterfaces function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(DerivableInterfaces), "examples", "README.jl"), - joinpath(pkgdir(DerivableInterfaces), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(DerivableInterfaces), "examples", "README.jl"), + joinpath(pkgdir(DerivableInterfaces), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index c583f65..0933627 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using DerivableInterfaces: DerivableInterfaces function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(DerivableInterfaces), "examples", "README.jl"), - joinpath(pkgdir(DerivableInterfaces)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(DerivableInterfaces), "examples", "README.jl"), + joinpath(pkgdir(DerivableInterfaces)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 8eaf78e..de3f2e7 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # DerivableInterfaces.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/DerivableInterfaces.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/DerivableInterfaces.jl/dev/) # [![Build Status](https://github.com/ITensor/DerivableInterfaces.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/DerivableInterfaces.jl/actions/workflows/Tests.yml?query=branch%3Amain) @@ -60,7 +60,7 @@ julia> Pkg.add("DerivableInterfaces") # ## Examples using DerivableInterfaces: - DerivableInterfaces, @array_aliases, @derive, @interface, interface + DerivableInterfaces, @array_aliases, @derive, @interface, interface using Test: @test # Define an interface. @@ -68,48 +68,48 @@ struct SparseArrayInterface end # Define interface functions. @interface ::SparseArrayInterface function Base.getindex(a, I::Int...) - checkbounds(a, I...) - !isstored(a, I...) && return getunstoredindex(a, I...) - return getstoredindex(a, I...) + checkbounds(a, I...) + !isstored(a, I...) && return getunstoredindex(a, I...) + return getstoredindex(a, I...) end @interface ::SparseArrayInterface function Base.setindex!(a, value, I::Int...) - checkbounds(a, I...) - iszero(value) && return a - if !isstored(a, I...) - setunstoredindex!(a, value, I...) + checkbounds(a, I...) + iszero(value) && return a + if !isstored(a, I...) + setunstoredindex!(a, value, I...) + return a + end + setstoredindex!(a, value, I...) return a - end - setstoredindex!(a, value, I...) - return a end # Define a type that will derive the interface. -struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storage::Dict{CartesianIndex{N},T} - size::NTuple{N,Int} +struct SparseArrayDOK{T, N} <: AbstractArray{T, N} + storage::Dict{CartesianIndex{N}, T} + size::NTuple{N, Int} end storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size function SparseArrayDOK{T}(size::Int...) where {T} - N = length(size) - return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) + N = length(size) + return SparseArrayDOK{T, N}(Dict{CartesianIndex{N}, T}(), size) end function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storage(a)) + return CartesianIndex(I) in keys(storage(a)) end function getstoredindex(a::SparseArrayDOK, I::Int...) - return storage(a)[CartesianIndex(I)] + return storage(a)[CartesianIndex(I)] end function getunstoredindex(a::SparseArrayDOK, I::Int...) - return zero(eltype(a)) + return zero(eltype(a)) end function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a + storage(a)[CartesianIndex(I)] = value + return a end function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a + storage(a)[CartesianIndex(I)] = value + return a end # Specify the interface the type adheres to. @@ -119,9 +119,9 @@ DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() @array_aliases SparseArrayDOK # DerivableInterfaces the interface for the type. -@derive (T=SparseArrayDOK,) begin - Base.getindex(::T, ::Int...) - Base.setindex!(::T, ::Any, ::Int...) +@derive (T = SparseArrayDOK,) begin + Base.getindex(::T, ::Int...) + Base.setindex!(::T, ::Any, ::Int...) end a = SparseArrayDOK{Float64}(2, 2) diff --git a/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl index 000fe6b..506d23e 100644 --- a/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl +++ b/ext/DerivableInterfacesBlockArraysExt/DerivableInterfacesBlockArraysExt.jl @@ -4,7 +4,7 @@ using BlockArrays: BlockedOneTo, blockedrange, blocklengths using DerivableInterfaces.Concatenate: Concatenate function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo) - return blockedrange([blocklengths(a1); blocklengths(a2)]) + return blockedrange([blocklengths(a1); blocklengths(a2)]) end end diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 6eba847..d1088b5 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -2,115 +2,115 @@ abstract type AbstractArrayInterface{N} <: AbstractInterface end function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N} - return DefaultArrayInterface{N}() + return DefaultArrayInterface{N}() end function interface(::Type{<:Broadcast.AbstractArrayStyle}) - return DefaultArrayInterface() + return DefaultArrayInterface() end function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}}) - return DefaultArrayInterface{ndims(BC)}() + return DefaultArrayInterface{ndims(BC)}() end function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} - return interface(Style) + return interface(Style) end # TODO: Define as `similar(Array{T}, ax)`. function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple) - return error("Not implemented.") + return error("Not implemented.") end using ArrayLayouts: ArrayLayouts @interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...) - return ArrayLayouts.layout_getindex(a, I...) + return ArrayLayouts.layout_getindex(a, I...) end @interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I... -) - # TODO: Change to this once broadcasting in `@interface` is supported: - # @interface interface a[I...] .= value - @interface interface map!(identity, @view(a[I...]), value) - return a + a::AbstractArray, value, I... + ) + # TODO: Change to this once broadcasting in `@interface` is supported: + # @interface interface a[I...] .= value + @interface interface map!(identity, @view(a[I...]), value) + return a end # TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or # `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`. # TODO: Use `MethodError`? @interface ::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") + a::AbstractArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + return error("Not implemented.") end # TODO: Make this more general, use `Base.to_index`. @interface interface::AbstractArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::CartesianIndex{N} -) where {N} - return @interface interface getindex(a, Tuple(I)...) + a::AbstractArray{<:Any, N}, I::CartesianIndex{N} + ) where {N} + return @interface interface getindex(a, Tuple(I)...) end # Linear indexing. @interface interface::AbstractArrayInterface function Base.getindex( - a::AbstractArray, I::Int -) - return @interface interface getindex(a, CartesianIndices(a)[I]) + a::AbstractArray, I::Int + ) + return @interface interface getindex(a, CartesianIndices(a)[I]) end # TODO: Use `MethodError`? @interface ::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return error("Not implemented.") + a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N} + ) where {N} + return error("Not implemented.") end # Linear indexing. @interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray, value, I::Int -) - return @interface interface setindex!(a, value, CartesianIndices(a)[I]) + a::AbstractArray, value, I::Int + ) + return @interface interface setindex!(a, value, CartesianIndices(a)[I]) end # TODO: Make this more general, use `Base.to_index`. @interface interface::AbstractArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N} -) where {N} - return @interface interface setindex!(a, value, Tuple(I)...) + a::AbstractArray{<:Any, N}, value, I::CartesianIndex{N} + ) where {N} + return @interface interface setindex!(a, value, Tuple(I)...) end @interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type) - return Broadcast.DefaultArrayStyle{ndims(type)}() + return Broadcast.DefaultArrayStyle{ndims(type)}() end # TODO: Maybe define as `Array{T}(undef, size...)` or # `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`. # TODO: Use `MethodError`? @interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} -) - return similar(interface, T, size) + a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} + ) + return similar(interface, T, size) end @interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) - a_dest = similar(a) - return a_dest .= a + a_dest = similar(a) + return a_dest .= a end # TODO: Use `Base.to_shape(axes)` or # `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`. # TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`). @interface interface::AbstractArrayInterface function Base.similar( - a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} -) - return @interface interface similar(a, T, Base.to_shape(axes)) + a::AbstractArray, T::Type, axes::Tuple{Base.OneTo, Vararg{Base.OneTo}} + ) + return @interface interface similar(a, T, Base.to_shape(axes)) end @interface interface::AbstractArrayInterface function Base.similar( - bc::Broadcast.Broadcasted, T::Type, axes::Tuple -) - return similar(interface, T, axes) + bc::Broadcast.Broadcasted, T::Type, axes::Tuple + ) + return similar(interface, T, axes) end using MapBroadcast: Mapped @@ -118,10 +118,10 @@ using MapBroadcast: Mapped # TODO: Look into `SparseArrays.capturescalars`: # https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 @interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted -) - m = Mapped(bc) - return @interface interface map!(m.f, a_dest, m.args...) + a_dest::AbstractArray, bc::Broadcast.Broadcasted + ) + m = Mapped(bc) + return @interface interface map!(m.f, a_dest, m.args...) end # This captures broadcast expressions such as `a .= 2`. @@ -129,37 +129,37 @@ end # https://github.com/JuliaLang/julia/issues/31677 # https://github.com/JuliaLang/julia/pull/40632 @interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} -) - @interface interface fill!(a_dest, bc.f(bc.args...)[]) + a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} + ) + @interface interface fill!(a_dest, bc.f(bc.args...)[]) end # This is defined in this way so we can rely on the Broadcast logic # for determining the destination of the operation (element type, shape, etc.). @interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...) - # TODO: Should this be `@interface interface ...`? That doesn't support - # broadcasting yet. - # Broadcasting is used here to determine the destination array but that - # could be done manually here. - return f.(as...) + # TODO: Should this be `@interface interface ...`? That doesn't support + # broadcasting yet. + # Broadcasting is used here to determine the destination array but that + # could be done manually here. + return f.(as...) end # TODO: Maybe define as # `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`. # TODO: Use `MethodError`? @interface ::AbstractArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return error("Not implemented.") + f, a_dest::AbstractArray, a_srcs::AbstractArray... + ) + return error("Not implemented.") end @interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value) - @interface interface map!(Returns(value), a, a) + @interface interface map!(Returns(value), a, a) end # TODO: should this be recursive? `map!(zero!, A, A)` might also work? @interface ::AbstractArrayInterface DerivableInterfaces.zero!(A::AbstractArray) = fill!( - A, zero(eltype(A)) + A, zero(eltype(A)) ) # Specialized version of `Base.zero` written in terms of `zero!`. @@ -167,59 +167,59 @@ end # to handle the logic of dropping all elements of the sparse array when possible. # We use a single function definition to minimize method ambiguities. @interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray) - # More generally, the first codepath could be taking if `zero(eltype(a))` - # is defined and the elements are immutable. - if eltype(a) <: Number - return @interface interface zero!(similar(a)) - end - return @interface interface map(interface(zero), a) + # More generally, the first codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + if eltype(a) <: Number + return @interface interface zero!(similar(a)) + end + return @interface interface map(interface(zero), a) end @interface ::AbstractArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return error("Not implemented.") + f, op, as::AbstractArray...; kwargs... + ) + return error("Not implemented.") end # TODO: Generalize to multiple inputs. @interface interface::AbstractInterface function Base.reduce(f, a::AbstractArray; kwargs...) - return @interface interface mapreduce(identity, f, a; kwargs...) + return @interface interface mapreduce(identity, f, a; kwargs...) end @interface interface::AbstractArrayInterface function Base.all(a::AbstractArray) - return @interface interface reduce(&, a; init=true) + return @interface interface reduce(&, a; init = true) end @interface interface::AbstractArrayInterface function Base.all( - f::Function, a::AbstractArray -) - return @interface interface mapreduce(f, &, a; init=true) + f::Function, a::AbstractArray + ) + return @interface interface mapreduce(f, &, a; init = true) end @interface interface::AbstractArrayInterface function Base.iszero(a::AbstractArray) - return @interface interface all(iszero, a) + return @interface interface all(iszero, a) end @interface interface::AbstractArrayInterface function Base.isreal(a::AbstractArray) - return @interface interface all(isreal, a) + return @interface interface all(isreal, a) end @interface interface::AbstractArrayInterface function Base.permutedims!( - a_dest::AbstractArray, a_src::AbstractArray, perm -) - return @interface interface map!(identity, a_dest, PermutedDimsArray(a_src, perm)) + a_dest::AbstractArray, a_src::AbstractArray, perm + ) + return @interface interface map!(identity, a_dest, PermutedDimsArray(a_src, perm)) end @interface interface::AbstractArrayInterface function Base.copyto!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) + a_dest::AbstractArray, a_src::AbstractArray + ) + return @interface interface map!(identity, a_dest, a_src) end @interface interface::AbstractArrayInterface function Base.copy!( - a_dest::AbstractArray, a_src::AbstractArray -) - return @interface interface map!(identity, a_dest, a_src) + a_dest::AbstractArray, a_src::AbstractArray + ) + return @interface interface map!(identity, a_dest, a_src) end using LinearAlgebra: LinearAlgebra @@ -230,15 +230,15 @@ using LinearAlgebra: LinearAlgebra # # Matmul implementation. # end @interface ::AbstractArrayInterface function LinearAlgebra.mul!( - a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number -) - return ArrayLayouts.mul!(a_dest, a1, a2, α, β) + a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number + ) + return ArrayLayouts.mul!(a_dest, a1, a2, α, β) end @interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - # TODO: Define as `UnknownLayout()`? - # TODO: Use `MethodError`? - return error("Not implemented.") + # TODO: Define as `UnknownLayout()`? + # TODO: Use `MethodError`? + return error("Not implemented.") end ## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`, diff --git a/src/abstractinterface.jl b/src/abstractinterface.jl index dd89129..a5b144b 100644 --- a/src/abstractinterface.jl +++ b/src/abstractinterface.jl @@ -13,22 +13,22 @@ interface(x::AbstractInterface) = x # Adapted from `Base.Broadcast.combine_styles`. # Get the combined interfaces of the input objects. function combine_interfaces( - inter1::AbstractInterface, inter2::AbstractInterface, inter_rest::AbstractInterface... -) - return combine_interfaces(combine_interface_rule(inter1, inter2), inter_rest...) + inter1::AbstractInterface, inter2::AbstractInterface, inter_rest::AbstractInterface... + ) + return combine_interfaces(combine_interface_rule(inter1, inter2), inter_rest...) end function combine_interfaces(inter1::AbstractInterface, inter2::AbstractInterface) - return combine_interface_rule(inter1, inter2) + return combine_interface_rule(inter1, inter2) end combine_interfaces(inter::AbstractInterface) = inter # Rules for combining interfaces. function combine_interface_rule( - inter1::Interface, inter2::Interface -) where {Interface<:AbstractInterface} - return inter1 + inter1::Interface, inter2::Interface + ) where {Interface <: AbstractInterface} + return inter1 end # TODO: Define as `UnknownInterface()`. function combine_interface_rule(inter1::AbstractInterface, inter2::AbstractInterface) - return error("No rule for combining interfaces.") + return error("No rule for combining interfaces.") end diff --git a/src/concatenate.jl b/src/concatenate.jl index 5f7b019..d3b640d 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -41,43 +41,43 @@ function _Concatenated end Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide hooks to customize the implementation. """ -struct Concatenated{Interface,Dims,Args<:Tuple} - interface::Interface - dims::Val{Dims} - args::Args - global @inline function _Concatenated( - interface::Interface, dims::Val{Dims}, args::Args - ) where {Interface,Dims,Args<:Tuple} - return new{Interface,Dims,Args}(interface, dims, args) - end +struct Concatenated{Interface, Dims, Args <: Tuple} + interface::Interface + dims::Val{Dims} + args::Args + global @inline function _Concatenated( + interface::Interface, dims::Val{Dims}, args::Args + ) where {Interface, Dims, Args <: Tuple} + return new{Interface, Dims, Args}(interface, dims, args) + end end function Concatenated( - interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple -) - return _Concatenated(interface, dims, args) + interface::Union{AbstractArrayInterface, Nothing}, dims::Val, args::Tuple + ) + return _Concatenated(interface, dims, args) end function Concatenated(dims::Val, args::Tuple) - return Concatenated(cat_interface(dims, args...), dims, args) + return Concatenated(cat_interface(dims, args...), dims, args) end function Concatenated{Interface}( - dims::Val, args::Tuple -) where {Interface<:Union{AbstractArrayInterface,Nothing}} - return Concatenated(Interface(), dims, args) + dims::Val, args::Tuple + ) where {Interface <: Union{AbstractArrayInterface, Nothing}} + return Concatenated(Interface(), dims, args) end -dims(::Concatenated{<:Any,D}) where {D} = D +dims(::Concatenated{<:Any, D}) where {D} = D DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface) concatenated(dims, args...) = concatenated(Val(dims), args...) concatenated(dims::Val, args...) = Concatenated(dims, args) function Base.convert( - ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args} -) where {NewInterface,Dims,Args} - return Concatenated{NewInterface}( - concat.dims, concat.args - )::Concatenated{NewInterface,Dims,Args} + ::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any, Dims, Args} + ) where {NewInterface, Dims, Args} + return Concatenated{NewInterface}( + concat.dims, concat.args + )::Concatenated{NewInterface, Dims, Args} end # allocating the destination container @@ -85,38 +85,38 @@ end Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) function Base.similar(concat::Concatenated, ax::Tuple) - return similar(interface(concat), eltype(concat), ax) + return similar(interface(concat), eltype(concat), ax) end function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T} - return similar(interface(concat), T, ax) + return similar(interface(concat), T, ax) end function cat_axis( - a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... -) - return cat_axis(cat_axis(a1, a2), a_rest...) + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... + ) + return cat_axis(cat_axis(a1, a2), a_rest...) end cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) function cat_ndims(dims, as::AbstractArray...) - return max(maximum(dims), maximum(ndims, as)) + return max(maximum(dims), maximum(ndims, as)) end function cat_ndims(dims::Val, as::AbstractArray...) - return cat_ndims(unval(dims), as...) + return cat_ndims(unval(dims), as...) end function cat_axes(dims, a::AbstractArray, as::AbstractArray...) - return ntuple(cat_ndims(dims, a, as...)) do dim - return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) - end + return ntuple(cat_ndims(dims, a, as...)) do dim + return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim) + end end function cat_axes(dims::Val, as::AbstractArray...) - return cat_axes(unval(dims), as...) + return cat_axes(unval(dims), as...) end function cat_interface(dims, as::AbstractArray...) - N = cat_ndims(dims, as...) - return typeof(interface(as...))(Val(N)) + N = cat_ndims(dims, as...) + return typeof(interface(as...))(Val(N)) end Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) @@ -151,8 +151,8 @@ Base.materialize(concat::Concatenated) = copy(concat) Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. """ function cat!(dest, args...; dims) - Base.materialize!(dest, concatenated(dims, args...)) - return dest + Base.materialize!(dest, concatenated(dims, args...)) + return dest end Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) @@ -172,45 +172,45 @@ cat_indices(A, d) = Base.OneTo(1) cat_indices(A::AbstractArray, d) = axes(A, d) function __cat!(A, shape, catdims, X...) - return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) end function __cat_offset!(A, shape, catdims, offsets, x, X...) - # splitting the "work" on x from X... may reduce latency (fewer costly specializations) - newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) - return __cat_offset!(A, shape, catdims, newoffsets, X...) + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) + return __cat_offset!(A, shape, catdims, newoffsets, X...) end __cat_offset!(A, shape, catdims, offsets) = A function __cat_offset1!(A, shape, catdims, offsets, x) - inds = ntuple(length(offsets)) do i - (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] - end - _copy_or_fill!(A, inds, x) - newoffsets = ntuple(length(offsets)) do i - (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] - end - return newoffsets + inds = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i] + end + _copy_or_fill!(A, inds, x) + newoffsets = ntuple(length(offsets)) do i + (i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i] + end + return newoffsets end dims2cat(dims::Val) = dims2cat(unval(dims)) function dims2cat(dims) - if any(≤(0), dims) - throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) - end - return ntuple(in(dims), maximum(dims)) + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + return ntuple(in(dims), maximum(dims)) end # default falls back to replacing interface with Nothing # this permits specializing on typeof(dest) without ambiguities # Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. @inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) - return copyto!(dest, convert(Concatenated{Nothing}, concat)) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) end function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) - catdims = dims2cat(dims(concat)) - shape = size(concat) - count(!iszero, catdims)::Int > 1 && zero!(dest) - return __cat!(dest, shape, catdims, concat.args...) + catdims = dims2cat(dims(concat)) + shape = size(concat) + count(!iszero, catdims)::Int > 1 && zero!(dest) + return __cat!(dest, shape, catdims, concat.args...) end end diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index d0d7a2e..95f90a3 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -1,81 +1,81 @@ using TypeParameterAccessors: parenttype, set_eltype, unspecify_type_parameters -struct DefaultArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end +struct DefaultArrayInterface{N, A <: AbstractArray} <: AbstractArrayInterface{N} end -DefaultArrayInterface{N}() where {N} = DefaultArrayInterface{N,AbstractArray}() +DefaultArrayInterface{N}() where {N} = DefaultArrayInterface{N, AbstractArray}() DefaultArrayInterface() = DefaultArrayInterface{Any}() DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() -DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() -DefaultArrayInterface{M,A}(::Val{N}) where {M,A,N} = DefaultArrayInterface{N,A}() +DefaultArrayInterface{M}(::Val{N}) where {M, N} = DefaultArrayInterface{N}() +DefaultArrayInterface{M, A}(::Val{N}) where {M, A, N} = DefaultArrayInterface{N, A}() # This version remembers the `ndims` of the wrapper type. function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} - arrayt′ = parenttype(arrayt) - if arrayt′ === arrayt - return DefaultArrayInterface{N,unspecify_type_parameters(arrayt)}() - end - return typeof(interface(arrayt′))(Val(N)) + arrayt′ = parenttype(arrayt) + if arrayt′ === arrayt + return DefaultArrayInterface{N, unspecify_type_parameters(arrayt)}() + end + return typeof(interface(arrayt′))(Val(N)) end -function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} - return _interface(Val(N), arrayt) +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any, N}}) where {N} + return _interface(Val(N), arrayt) end function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray}) - return _interface(Val(Any), arrayt) + return _interface(Val(Any), arrayt) end function Base.similar( - ::DefaultArrayInterface{<:Any,A}, T::Type, ax::Tuple -) where {A<:AbstractArray} - if isabstracttype(A) - # If the type is abstract, default to constructing the array on CPU. - return similar(Array{T}, ax) - else - return similar(set_eltype(A, T), ax) - end + ::DefaultArrayInterface{<:Any, A}, T::Type, ax::Tuple + ) where {A <: AbstractArray} + if isabstracttype(A) + # If the type is abstract, default to constructing the array on CPU. + return similar(Array{T}, ax) + else + return similar(set_eltype(A, T), ax) + end end function combine_interface_rule( - interface1::DefaultArrayInterface{N,A}, interface2::DefaultArrayInterface{N,A} -) where {N,A<:AbstractArray} - return DefaultArrayInterface{N,A}() + interface1::DefaultArrayInterface{N, A}, interface2::DefaultArrayInterface{N, A} + ) where {N, A <: AbstractArray} + return DefaultArrayInterface{N, A}() end function combine_interface_rule( - interface1::DefaultArrayInterface{<:Any,A}, interface2::DefaultArrayInterface{<:Any,A} -) where {A<:AbstractArray} - return DefaultArrayInterface{Any,A}() + interface1::DefaultArrayInterface{<:Any, A}, interface2::DefaultArrayInterface{<:Any, A} + ) where {A <: AbstractArray} + return DefaultArrayInterface{Any, A}() end function combine_interface_rule( - interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} -) where {N} - return DefaultArrayInterface{N}() + interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} + ) where {N} + return DefaultArrayInterface{N}() end function combine_interface_rule( - interface1::DefaultArrayInterface, interface2::DefaultArrayInterface -) - return DefaultArrayInterface() + interface1::DefaultArrayInterface, interface2::DefaultArrayInterface + ) + return DefaultArrayInterface() end @interface ::DefaultArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - return Base.getindex(a, I...) + a::AbstractArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + return Base.getindex(a, I...) end @interface ::DefaultArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - return Base.setindex!(a, value, I...) + a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N} + ) where {N} + return Base.setindex!(a, value, I...) end @interface ::DefaultArrayInterface function Base.map!( - f, a_dest::AbstractArray, a_srcs::AbstractArray... -) - return Base.map!(f, a_dest, a_srcs...) + f, a_dest::AbstractArray, a_srcs::AbstractArray... + ) + return Base.map!(f, a_dest, a_srcs...) end @interface ::DefaultArrayInterface function Base.mapreduce( - f, op, as::AbstractArray...; kwargs... -) - return Base.mapreduce(f, op, as...; kwargs...) + f, op, as::AbstractArray...; kwargs... + ) + return Base.mapreduce(f, op, as...; kwargs...) end diff --git a/src/derive_macro.jl b/src/derive_macro.jl index 52b2657..d3e9e90 100644 --- a/src/derive_macro.jl +++ b/src/derive_macro.jl @@ -8,23 +8,23 @@ argname(i::Int) = Symbol(:arg, i) # TODO: Use this type of function to replace `DerivableInterfaces.f` with `GlobalRef(DerivableInterfaces, f)` # and also replace `T` with `SparseArrayDOK`. function rmlines(expr) - return @match expr begin - e::Expr => Expr(e.head, filter(!isnothing, map(rmlines, e.args))...) - _::LineNumberNode => nothing - a => a - end + return @match expr begin + e::Expr => Expr(e.head, filter(!isnothing, map(rmlines, e.args))...) + _::LineNumberNode => nothing + a => a + end end function globalref_derive(expr) - return @match expr begin - :(DerivableInterfaces.$f) => :($(GlobalRef(DerivableInterfaces, :($f)))) - e::Expr => Expr(e.head, map(globalref_derive, e.args)...) - a => a - end + return @match expr begin + :(DerivableInterfaces.$f) => :($(GlobalRef(DerivableInterfaces, :($f)))) + e::Expr => Expr(e.head, map(globalref_derive, e.args)...) + a => a + end end macro derive(expr...) - return esc(derive_expr(expr...)) + return esc(derive_expr(expr...)) end #== @@ -44,11 +44,11 @@ end end ``` ==# -function derive_expr(interface_or_types::Union{Symbol,Expr}, funcs::Expr) - return @match funcs begin - Expr(:call, _...) => derive_func(interface_or_types, funcs) - Expr(:block, _...) => derive_funcs(interface_or_types, funcs) - end +function derive_expr(interface_or_types::Union{Symbol, Expr}, funcs::Expr) + return @match funcs begin + Expr(:call, _...) => derive_func(interface_or_types, funcs) + Expr(:block, _...) => derive_funcs(interface_or_types, funcs) + end end #== @@ -61,11 +61,11 @@ end end ``` ==# -function derive_expr(interface::Union{Symbol,Expr}, types::Expr, funcs::Expr) - return @match funcs begin - Expr(:call, _...) => derive_func(interface, types, funcs) - Expr(:block, _...) => derive_funcs(interface, types, funcs) - end +function derive_expr(interface::Union{Symbol, Expr}, types::Expr, funcs::Expr) + return @match funcs begin + Expr(:call, _...) => derive_func(interface, types, funcs) + Expr(:block, _...) => derive_funcs(interface, types, funcs) + end end #== @@ -73,8 +73,8 @@ end @derive SparseArrayDOK AbstractArrayOps ``` ==# -function derive_expr(type::Union{Symbol,Expr}, trait::Symbol) - return derive_trait(type, trait) +function derive_expr(type::Union{Symbol, Expr}, trait::Symbol) + return derive_trait(type, trait) end #== @@ -83,19 +83,19 @@ end ``` ==# function derive_expr( - interface::Union{Symbol,Expr}, types::Union{Symbol,Expr}, trait::Symbol -) - return derive_trait(interface, types, trait) + interface::Union{Symbol, Expr}, types::Union{Symbol, Expr}, trait::Symbol + ) + return derive_trait(interface, types, trait) end function derive_funcs(args...) - interface_and_or_types = Base.front(args) - funcs = last(args) - Meta.isexpr(funcs, :block) || error("Expected a block.") - funcs = rmlines(funcs) - return Expr( - :block, map(func -> derive_func(interface_and_or_types..., func), funcs.args)... - ) + interface_and_or_types = Base.front(args) + funcs = last(args) + Meta.isexpr(funcs, :block) || error("Expected a block.") + funcs = rmlines(funcs) + return Expr( + :block, map(func -> derive_func(interface_and_or_types..., func), funcs.args)... + ) end #= @@ -107,31 +107,31 @@ In: replace `T` with `SparseArrayDOK`. =# function replace_typevars(types::Expr, func::Expr) - Meta.isexpr(types, :tuple) && all(arg -> Meta.isexpr(arg, :(=)), types.args) || - error("Wrong types format.") - name, args, kwargs, whereparams, rettype = split_function_head(func) - new_args = args - for type_expr in types.args - typevar, type = @match type_expr begin - :($x = $y) => (x, y) + Meta.isexpr(types, :tuple) && all(arg -> Meta.isexpr(arg, :(=)), types.args) || + error("Wrong types format.") + name, args, kwargs, whereparams, rettype = split_function_head(func) + new_args = args + for type_expr in types.args + typevar, type = @match type_expr begin + :($x = $y) => (x, y) + end + # TODO: Handle type parameters in other positions besides the first one. + new_args = map(args) do arg + return @match arg begin + :(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T}) + :(::$T...) => T == typevar ? :(::$type...) : :(::$T...) + :(::$T) => T == typevar ? :(::$type) : :(::$T) + end + end end - # TODO: Handle type parameters in other positions besides the first one. - new_args = map(args) do arg - return @match arg begin - :(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T}) - :(::$T...) => T == typevar ? :(::$type...) : :(::$T...) - :(::$T) => T == typevar ? :(::$type) : :(::$T) - end - end - end - _, new_func = split_function( - codegen_ast(JLFunction(; name, args=new_args, kwargs, whereparams, rettype)) - ) - return new_func + _, new_func = split_function( + codegen_ast(JLFunction(; name, args = new_args, kwargs, whereparams, rettype)) + ) + return new_func end function derive_func(interface::Symbol, func::Expr) - return derive_interface_func(:($(interface)()), func) + return derive_interface_func(:($(interface)()), func) end #= @@ -140,14 +140,14 @@ end @derive (T=SparseArrayDOK,) Base.getindex(::T, ::Int...) ``` =# -function derive_func(interface_or_types::Union{Symbol,Expr}, func::Expr) - if Meta.isexpr(interface_or_types, :tuple) && - all(arg -> Meta.isexpr(arg, :(=)), interface_or_types.args) - types = interface_or_types - return derive_func_from_types(types, func) - end - interface = interface_or_types - return derive_interface_func(interface, func) +function derive_func(interface_or_types::Union{Symbol, Expr}, func::Expr) + if Meta.isexpr(interface_or_types, :tuple) && + all(arg -> Meta.isexpr(arg, :(=)), interface_or_types.args) + types = interface_or_types + return derive_func_from_types(types, func) + end + interface = interface_or_types + return derive_interface_func(interface, func) end #= @@ -156,17 +156,17 @@ end ``` =# function derive_func_from_types(types::Expr, func::Expr) - new_func = replace_typevars(types, func) - _, args = split_function_head(func) - _, new_args = split_function_head(new_func) - active_argnames = map(findall(args .≠ new_args)) do i - if Meta.isexpr(args[i], :...) - return :($(argname(i))...) + new_func = replace_typevars(types, func) + _, args = split_function_head(func) + _, new_args = split_function_head(new_func) + active_argnames = map(findall(args .≠ new_args)) do i + if Meta.isexpr(args[i], :...) + return :($(argname(i))...) + end + return argname(i) end - return argname(i) - end - interface = globalref_derive(:(DerivableInterfaces.interface($(active_argnames...)))) - return derive_interface_func(interface, new_func) + interface = globalref_derive(:(DerivableInterfaces.interface($(active_argnames...)))) + return derive_interface_func(interface, new_func) end #= @@ -174,46 +174,46 @@ end @derive SparseArrayInterface() (T=SparseArrayDOK,) Base.getindex(::T, ::Int...) ``` =# -function derive_func(interface::Union{Symbol,Expr}, types::Expr, func::Expr) - new_func = replace_typevars(types, func) - return derive_interface_func(:($(interface)), new_func) +function derive_func(interface::Union{Symbol, Expr}, types::Expr, func::Expr) + new_func = replace_typevars(types, func) + return derive_interface_func(:($(interface)), new_func) end #= Core implementation of `@derive`. =# -function derive_interface_func(interface::Union{Symbol,Expr}, func::Expr) - name, args, kwargs, whereparams, rettype = split_function_head(func) - argnames = map(argname, 1:length(args)) - named_args = map(1:length(args)) do i - argname, arg = argnames[i], args[i] - return @match arg begin - :(::$T) => :($argname::$T) - :(::$T...) => :($argname::$T...) +function derive_interface_func(interface::Union{Symbol, Expr}, func::Expr) + name, args, kwargs, whereparams, rettype = split_function_head(func) + argnames = map(argname, 1:length(args)) + named_args = map(1:length(args)) do i + argname, arg = argnames[i], args[i] + return @match arg begin + :(::$T) => :($argname::$T) + :(::$T...) => :($argname::$T...) + end + end + # TODO: Insert `interface` as first argument. + body_args = map(1:length(args)) do i + argname, arg = argnames[i], args[i] + return @match arg begin + :(::$T) => :($argname) + :(::$T...) => :($argname...) + end end - end - # TODO: Insert `interface` as first argument. - body_args = map(1:length(args)) do i - argname, arg = argnames[i], args[i] - return @match arg begin - :(::$T) => :($argname) - :(::$T...) => :($argname...) + # TODO: Use the `@interface` macro rather than `DerivableInterfaces.call` + # directly, in case we want to change the implementation. + body_args = [interface; name; body_args...] + body_name = @match name begin + :($M.$f) => :(DerivableInterfaces.call) end - end - # TODO: Use the `@interface` macro rather than `DerivableInterfaces.call` - # directly, in case we want to change the implementation. - body_args = [interface; name; body_args...] - body_name = @match name begin - :($M.$f) => :(DerivableInterfaces.call) - end - # TODO: Remove defaults from `kwargs`. - _, body, _ = split_function( - codegen_ast(JLFunction(; name=body_name, args=body_args, kwargs)) - ) - jlfn = JLFunction(; name, args=named_args, kwargs, whereparams, rettype, body) - # Use `globalref_derive` to not require having `DerivableInterfaces` in the - # namespace when `@derive` is called. - return globalref_derive(codegen_ast(jlfn)) + # TODO: Remove defaults from `kwargs`. + _, body, _ = split_function( + codegen_ast(JLFunction(; name = body_name, args = body_args, kwargs)) + ) + jlfn = JLFunction(; name, args = named_args, kwargs, whereparams, rettype, body) + # Use `globalref_derive` to not require having `DerivableInterfaces` in the + # namespace when `@derive` is called. + return globalref_derive(codegen_ast(jlfn)) end #= @@ -222,10 +222,10 @@ end ``` =# function derive_trait( - interface::Union{Symbol,Expr}, type::Union{Symbol,Expr}, trait::Symbol -) - funcs = Expr(:block, derive(Val(trait), type).args...) - return derive_funcs(interface, funcs) + interface::Union{Symbol, Expr}, type::Union{Symbol, Expr}, trait::Symbol + ) + funcs = Expr(:block, derive(Val(trait), type).args...) + return derive_funcs(interface, funcs) end #= @@ -233,8 +233,8 @@ end @derive SparseArrayDOK AbstractArrayOps ``` =# -function derive_trait(type::Union{Symbol,Expr}, trait::Symbol) - types = :((T=($type),)) - funcs = Expr(:block, derive(Val(trait), :T).args...) - return derive_funcs(types, funcs) +function derive_trait(type::Union{Symbol, Expr}, trait::Symbol) + types = :((T = ($type),)) + funcs = Expr(:block, derive(Val(trait), :T).args...) + return derive_funcs(types, funcs) end diff --git a/src/fillarrays.jl b/src/fillarrays.jl index aef5ae1..9b1a7b9 100644 --- a/src/fillarrays.jl +++ b/src/fillarrays.jl @@ -3,7 +3,7 @@ # in Julia v1.10. using FillArrays: RectDiagonal function permuteddims(a::RectDiagonal, perm) - (ndims(a) == length(perm) && isperm(perm)) || - throw(ArgumentError("no valid permutation of dimensions")) - return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a))) + (ndims(a) == length(perm) && isperm(perm)) || + throw(ArgumentError("no valid permutation of dimensions")) + return RectDiagonal(parent(a), ntuple(d -> axes(a)[perm[d]], ndims(a))) end diff --git a/src/interface_function.jl b/src/interface_function.jl index 0ebde58..ce7eda9 100644 --- a/src/interface_function.jl +++ b/src/interface_function.jl @@ -10,8 +10,8 @@ call(interface, f, args...) = f(args...) call(interface, f, args...; kwargs...) = error("Not implemented") # Change the behavior of a function to use a certain interface. -struct InterfaceFunction{Interface,F} <: Function - interface::Interface - f::F +struct InterfaceFunction{Interface, F} <: Function + interface::Interface + f::F end (f::InterfaceFunction)(args...; kwargs...) = call(f.interface, f.f, args...; kwargs...) diff --git a/src/interface_macro.jl b/src/interface_macro.jl index 1627385..0861195 100644 --- a/src/interface_macro.jl +++ b/src/interface_macro.jl @@ -2,7 +2,7 @@ using ExproniconLite: JLFunction, codegen_ast, split_function, split_function_he using MLStyle: @match macro interface(expr...) - return esc(interface_expr(expr...)) + return esc(interface_expr(expr...)) end # TODO: Use `MLStyle.@match`/`Moshi.@match`. @@ -13,16 +13,16 @@ isrefexpr(expr) = Meta.isexpr(expr, :ref) # a[I...] = value issetrefexpr(expr) = Meta.isexpr(expr, :(=)) && isrefexpr(expr.args[1]) -function interface_expr(interface::Union{Symbol,Expr}, func::Expr) - # TODO: Use `MLStyle.@match`/`Moshi.@match`. - # f(args...) - iscallexpr(func) && return interface_call(interface, func) - # a[I...] - isrefexpr(func) && return interface_ref(interface, func) - # a[I...] = value - issetrefexpr(func) && return interface_setref(interface, func) - # Assume it is a function definition. - return interface_definition(interface, func) +function interface_expr(interface::Union{Symbol, Expr}, func::Expr) + # TODO: Use `MLStyle.@match`/`Moshi.@match`. + # f(args...) + iscallexpr(func) && return interface_call(interface, func) + # a[I...] + isrefexpr(func) && return interface_ref(interface, func) + # a[I...] = value + issetrefexpr(func) && return interface_setref(interface, func) + # Assume it is a function definition. + return interface_definition(interface, func) end #= @@ -39,17 +39,21 @@ to: DerivableInterfaces.call(SparseArrayInterface(), Base.getindex, a, I...) ``` =# -function interface_call(interface::Union{Symbol,Expr}, func::Expr) - return @match func begin - :($name($(args...))) => - :($(GlobalRef(DerivableInterfaces, :InterfaceFunction))($interface, $name)( - $(args...) - )) - :($name($(args...); $(kwargs...))) => - :($(GlobalRef(DerivableInterfaces, :InterfaceFunction))($interface, $name)( - $(args...); $(kwargs...) - )) - end +function interface_call(interface::Union{Symbol, Expr}, func::Expr) + return @match func begin + :($name($(args...))) => + :( + $(GlobalRef(DerivableInterfaces, :InterfaceFunction))($interface, $name)( + $(args...) + ) + ) + :($name($(args...); $(kwargs...))) => + :( + $(GlobalRef(DerivableInterfaces, :InterfaceFunction))($interface, $name)( + $(args...); $(kwargs...) + ) + ) + end end #= @@ -62,11 +66,11 @@ to: DerivableInterfaces.call(SparseArrayInterface(), Base.getindex, a, I...) ``` =# -function interface_ref(interface::Union{Symbol,Expr}, func::Expr) - func = @match func begin - :($a[$(I...)]) => :(Base.getindex($a, $(I...))) - end - return interface_call(interface, func) +function interface_ref(interface::Union{Symbol, Expr}, func::Expr) + func = @match func begin + :($a[$(I...)]) => :(Base.getindex($a, $(I...))) + end + return interface_call(interface, func) end #= @@ -79,12 +83,12 @@ to: DerivableInterfaces.call(SparseArrayInterface(), Base.setindex!, a, value, I...) ``` =# -function interface_setref(interface::Union{Symbol,Expr}, func::Expr) - return @match func begin - :($a[$(I...)] = $value) => Expr( - :block, interface_call(interface, :(Base.setindex!($a, $value, $(I...)))), :($value) - ) - end +function interface_setref(interface::Union{Symbol, Expr}, func::Expr) + return @match func begin + :($a[$(I...)] = $value) => Expr( + :block, interface_call(interface, :(Base.setindex!($a, $value, $(I...)))), :($value) + ) + end end #= @@ -103,17 +107,17 @@ function DerivableInterfaces.call(interface::SparseArrayInterface, Base.getindex end ``` =# -function interface_definition(interface::Union{Symbol,Expr}, func::Expr) - head, call, body = split_function(func) - name, args, kwargs, whereparams, rettype = split_function_head(call) - new_name = :(DerivableInterfaces.call) - # We use `Core.Typeof` here because `name` can either be a function or type, - # and `typeof(T::Type)` outputs things like `DataType`, `UnionAll`, etc. - # while `Core.Typeof(T::Type)` returns `Type{T}`. - new_args = [:($interface); :(::Core.Typeof($name)); args] - return globalref_derive( - codegen_ast( - JLFunction(; name=new_name, args=new_args, kwargs, rettype, whereparams, body) - ), - ) +function interface_definition(interface::Union{Symbol, Expr}, func::Expr) + head, call, body = split_function(func) + name, args, kwargs, whereparams, rettype = split_function_head(call) + new_name = :(DerivableInterfaces.call) + # We use `Core.Typeof` here because `name` can either be a function or type, + # and `typeof(T::Type)` outputs things like `DataType`, `UnionAll`, etc. + # while `Core.Typeof(T::Type)` returns `Type{T}`. + new_args = [:($interface); :(::Core.Typeof($name)); args] + return globalref_derive( + codegen_ast( + JLFunction(; name = new_name, args = new_args, kwargs, rettype, whereparams, body) + ), + ) end diff --git a/src/permuteddims.jl b/src/permuteddims.jl index d64615a..396796b 100644 --- a/src/permuteddims.jl +++ b/src/permuteddims.jl @@ -9,7 +9,7 @@ permuteddims(a::AbstractArray, perm) = PermutedDimsArray(a, perm) using LinearAlgebra: Diagonal function permuteddims(a::Diagonal, perm) - (ndims(a) == length(perm) && isperm(perm)) || - throw(ArgumentError("no valid permutation of dimensions")) - return a + (ndims(a) == length(perm) && isperm(perm)) || + throw(ArgumentError("no valid permutation of dimensions")) + return a end diff --git a/src/traits.jl b/src/traits.jl index 52a1a9b..59482da 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -21,38 +21,38 @@ end ``` =# function derive(::Val{:AbstractArrayOps}, type) - return quote - Base.getindex(::$type, ::Any...) - Base.getindex(::$type, ::Int...) - Base.setindex!(::$type, ::Any, ::Any...) - Base.setindex!(::$type, ::Any, ::Int...) - Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}}) - Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}}) - Base.copy(::$type) - Base.copy!(::AbstractArray, ::$type) - Base.copyto!(::AbstractArray, ::$type) - Base.map(::Any, ::$type...) - Base.map!(::Any, ::AbstractArray, ::$type...) - Base.mapreduce(::Any, ::Any, ::$type...; kwargs...) - Base.reduce(::Any, ::$type...; kwargs...) - Base.all(::Function, ::$type) - Base.all(::$type) - Base.iszero(::$type) - Base.real(::$type) - Base.fill!(::$type, ::Any) - DerivableInterfaces.zero!(::$type) - Base.zero(::$type) - Base.permutedims!(::Any, ::$type, ::Any) - Broadcast.BroadcastStyle(::Type{<:$type}) - Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) - ArrayLayouts.MemoryLayout(::Type{<:$type}) - LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) - end + return quote + Base.getindex(::$type, ::Any...) + Base.getindex(::$type, ::Int...) + Base.setindex!(::$type, ::Any, ::Any...) + Base.setindex!(::$type, ::Any, ::Int...) + Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}}) + Base.similar(::$type, ::Type, ::Tuple{Base.OneTo, Vararg{Base.OneTo}}) + Base.copy(::$type) + Base.copy!(::AbstractArray, ::$type) + Base.copyto!(::AbstractArray, ::$type) + Base.map(::Any, ::$type...) + Base.map!(::Any, ::AbstractArray, ::$type...) + Base.mapreduce(::Any, ::Any, ::$type...; kwargs...) + Base.reduce(::Any, ::$type...; kwargs...) + Base.all(::Function, ::$type) + Base.all(::$type) + Base.iszero(::$type) + Base.real(::$type) + Base.fill!(::$type, ::Any) + DerivableInterfaces.zero!(::$type) + Base.zero(::$type) + Base.permutedims!(::Any, ::$type, ::Any) + Broadcast.BroadcastStyle(::Type{<:$type}) + Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) + ArrayLayouts.MemoryLayout(::Type{<:$type}) + LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) + end end function derive(::Val{:AbstractArrayStyleOps}, type) - return quote - Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple) - Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type}) - end + return quote + Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple) + Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type}) + end end diff --git a/src/wrappedarrays.jl b/src/wrappedarrays.jl index 42b6f35..5e424bf 100644 --- a/src/wrappedarrays.jl +++ b/src/wrappedarrays.jl @@ -1,8 +1,8 @@ -function symbol_replace(symbol::Symbol, replacement::Pair{Symbol,Symbol}) - return Symbol(replace(String(symbol), String(replacement[1]) => String(replacement[2]))) +function symbol_replace(symbol::Symbol, replacement::Pair{Symbol, Symbol}) + return Symbol(replace(String(symbol), String(replacement[1]) => String(replacement[2]))) end function symbol_cat(symbol1::Symbol, symbol2::Symbol) - return Symbol(symbol1, symbol2) + return Symbol(symbol1, symbol2) end vectype(type::Symbol) = symbol_replace(type, :Array => :Vector) @@ -12,15 +12,15 @@ anytype(type::Symbol) = symbol_cat(:Any, type) wrappedtype(type::Symbol) = symbol_cat(:Wrapped, type) macro vecmat_aliases(type) - return esc(vecmat_aliases(type)) + return esc(vecmat_aliases(type)) end function vecmat_aliases(type::Symbol) - return quote - const $(vectype(type)){T} = $type{T,1} - const $(mattype(type)){T} = $type{T,2} - const $(vecormattype(type)){T} = Union{$(vectype(type)){T},$(mattype(type)){T}} - end + return quote + const $(vectype(type)){T} = $type{T, 1} + const $(mattype(type)){T} = $type{T, 2} + const $(vecormattype(type)){T} = Union{$(vectype(type)){T}, $(mattype(type)){T}} + end end using Adapt: Adapt, WrappedArray @@ -29,29 +29,29 @@ using Adapt: Adapt, WrappedArray # i.e. `Adjoint`, Transpose`, `Diagonal`, etc. Maybe call it # `wrapped_vec_aliases` and `wrapped_mat_aliases`. macro wrapped_aliases(type) - return esc(wrapped_aliases(type)) + return esc(wrapped_aliases(type)) end function wrapped_aliases(type::Symbol) - return quote - const $(wrappedtype(type)){T,N} = $(GlobalRef(Adapt, :WrappedArray)){ - T,N,$type,$type{T,N} - } - const $(anytype(type)){T,N} = Union{$type{T,N},$(wrappedtype(type)){T,N}} - end + return quote + const $(wrappedtype(type)){T, N} = $(GlobalRef(Adapt, :WrappedArray)){ + T, N, $type, $type{T, N}, + } + const $(anytype(type)){T, N} = Union{$type{T, N}, $(wrappedtype(type)){T, N}} + end end macro array_aliases(type) - return esc(array_aliases(type)) + return esc(array_aliases(type)) end function array_aliases(type::Symbol) - # TODO: I tried to implement this by using `quote` and calling - # out to the macros but I couldn't get it to work with `GlobalRef`. - return Expr( - :block, - vecmat_aliases(type).args..., - wrapped_aliases(type).args..., - vecmat_aliases(anytype(type)).args..., - ) + # TODO: I tried to implement this by using `quote` and calling + # out to the macros but I couldn't get it to work with `GlobalRef`. + return Expr( + :block, + vecmat_aliases(type).args..., + wrapped_aliases(type).args..., + vecmat_aliases(anytype(type)).args..., + ) end diff --git a/src/zero.jl b/src/zero.jl index d5c013c..cf563f2 100644 --- a/src/zero.jl +++ b/src/zero.jl @@ -5,6 +5,6 @@ In-place version of `Base.zero`. """ function zero! end -@derive (T=AbstractArray,) begin - DerivableInterfaces.zero!(::T) +@derive (T = AbstractArray,) begin + DerivableInterfaces.zero!(::T) end diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 192cf64..2549a1f 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -4,10 +4,10 @@ isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...) getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...) getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...) function setstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setstoredindex!(a, value, Tuple(I)...) + return setstoredindex!(a, value, Tuple(I)...) end function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex) - return setunstoredindex!(a, value, Tuple(I)...) + return setunstoredindex!(a, value, Tuple(I)...) end # A view of the stored values of an array. @@ -16,171 +16,171 @@ end # is then interpreted as a sparse array. Also, that involves extra # logic for determining if the indices are stored or not, but we know # the indices are stored. -struct StoredValues{T,A<:AbstractArray{T},I} <: AbstractVector{T} - array::A - storedindices::I +struct StoredValues{T, A <: AbstractArray{T}, I} <: AbstractVector{T} + array::A + storedindices::I end StoredValues(a::AbstractArray) = StoredValues(a, collect(eachstoredindex(a))) Base.size(a::StoredValues) = size(a.storedindices) Base.getindex(a::StoredValues, I::Int) = getstoredindex(a.array, a.storedindices[I]) function Base.setindex!(a::StoredValues, value, I::Int) - return setstoredindex!(a.array, value, a.storedindices[I]) + return setstoredindex!(a.array, value, a.storedindices[I]) end storedvalues(a::AbstractArray) = StoredValues(a) using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout using DerivableInterfaces: - DerivableInterfaces, - @array_aliases, - @derive, - @interface, - AbstractArrayInterface, - interface + DerivableInterfaces, + @array_aliases, + @derive, + @interface, + AbstractArrayInterface, + interface using LinearAlgebra: LinearAlgebra # Define an interface. struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}() -SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}() +SparseArrayInterface{M}(::Val{N}) where {M, N} = SparseArrayInterface{N}() # Define interface functions. @interface ::SparseArrayInterface function Base.getindex( - a::AbstractArray{<:Any,N}, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - !isstored(a, I...) && return getunstoredindex(a, I...) - return getstoredindex(a, I...) + a::AbstractArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + checkbounds(a, I...) + !isstored(a, I...) && return getunstoredindex(a, I...) + return getstoredindex(a, I...) end @interface ::SparseArrayInterface function Base.setindex!( - a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N} -) where {N} - checkbounds(a, I...) - if !isstored(a, I...) - iszero(value) && return a - setunstoredindex!(a, value, I...) + a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N} + ) where {N} + checkbounds(a, I...) + if !isstored(a, I...) + iszero(value) && return a + setunstoredindex!(a, value, I...) + return a + end + setstoredindex!(a, value, I...) return a - end - setstoredindex!(a, value, I...) - return a end struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end -SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() +SparseArrayStyle{M}(::Val{N}) where {M, N} = SparseArrayStyle{N}() function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N} - return SparseArrayInterface{N}() + return SparseArrayInterface{N}() end @derive SparseArrayStyle AbstractArrayStyleOps function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple) - return similar(SparseArrayDOK{T}, ax) + return similar(SparseArrayDOK{T}, ax) end # Interface functions. @interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) - return SparseArrayStyle{ndims(type)}() + return SparseArrayStyle{ndims(type)}() end struct SparseLayout <: MemoryLayout end @interface ::SparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type) - return SparseLayout() + return SparseLayout() end @interface ::SparseArrayInterface function Base.map!( - f, a_dest::AbstractArray, as::AbstractArray... -) - # TODO: Define a function `preserves_unstored(a_dest, f, as...)` - # to determine if a function preserves the stored values - # of the destination sparse array. - # The current code may be inefficient since it actually - # accesses an unstored element, which in the case of a - # sparse array of arrays can allocate an array. - # Sparse arrays could be expected to define a cheap - # unstored element allocator, for example - # `get_prototypical_unstored(a::AbstractArray)`. - I = first(eachindex(as...)) - preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) - if !preserves_unstored - # Doesn't preserve unstored values, loop over all elements. - for I in eachindex(as...) - a_dest[I] = map(f, map(a -> a[I], as)...) + f, a_dest::AbstractArray, as::AbstractArray... + ) + # TODO: Define a function `preserves_unstored(a_dest, f, as...)` + # to determine if a function preserves the stored values + # of the destination sparse array. + # The current code may be inefficient since it actually + # accesses an unstored element, which in the case of a + # sparse array of arrays can allocate an array. + # Sparse arrays could be expected to define a cheap + # unstored element allocator, for example + # `get_prototypical_unstored(a::AbstractArray)`. + I = first(eachindex(as...)) + preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) + if !preserves_unstored + # Doesn't preserve unstored values, loop over all elements. + for I in eachindex(as...) + a_dest[I] = map(f, map(a -> a[I], as)...) + end + end + # TODO: Define `eachstoredindex(as...)`. + for I in union(eachstoredindex.(as)...) + a_dest[I] = map(f, map(a -> a[I], as)...) end - end - # TODO: Define `eachstoredindex(as...)`. - for I in union(eachstoredindex.(as)...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end - return a_dest + return a_dest end @interface ::SparseArrayInterface function Base.mapreduce( - f, op, a::AbstractArray; kwargs... -) - # TODO: Need to select a better `init`. - return mapreduce(f, op, storedvalues(a); kwargs...) + f, op, a::AbstractArray; kwargs... + ) + # TODO: Need to select a better `init`. + return mapreduce(f, op, storedvalues(a); kwargs...) end # ArrayLayouts functionality. function ArrayLayouts.sub_materialize(::SparseLayout, a::AbstractArray, axes::Tuple) - a_dest = similar(a) - a_dest .= a - return a_dest + a_dest = similar(a) + a_dest .= a + return a_dest end function ArrayLayouts.materialize!( - m::MatMulMatAdd{<:SparseLayout,<:SparseLayout,<:SparseLayout} -) - a_dest, a1, a2, α, β = m.C, m.A, m.B, m.α, m.β - for I1 in eachstoredindex(a1) - for I2 in eachstoredindex(a2) - if I1[2] == I2[1] - I_dest = CartesianIndex(I1[1], I2[2]) - a_dest[I_dest] = a1[I1] * a2[I2] * α + a_dest[I_dest] * β - end + m::MatMulMatAdd{<:SparseLayout, <:SparseLayout, <:SparseLayout} + ) + a_dest, a1, a2, α, β = m.C, m.A, m.B, m.α, m.β + for I1 in eachstoredindex(a1) + for I2 in eachstoredindex(a2) + if I1[2] == I2[1] + I_dest = CartesianIndex(I1[1], I2[2]) + a_dest[I_dest] = a1[I1] * a2[I2] * α + a_dest[I_dest] * β + end + end end - end - return a_dest + return a_dest end # Sparse array minimal interface using LinearAlgebra: Adjoint function isstored(a::Adjoint, i::Int, j::Int) - return isstored(parent(a), j, i) + return isstored(parent(a), j, i) end function getstoredindex(a::Adjoint, i::Int, j::Int) - return getstoredindex(parent(a), j, i)' + return getstoredindex(parent(a), j, i)' end function getunstoredindex(a::Adjoint, i::Int, j::Int) - return getunstoredindex(parent(a), j, i)' + return getunstoredindex(parent(a), j, i)' end function eachstoredindex(a::Adjoint) - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) + return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) end -perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p -iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip +perm(::PermutedDimsArray{<:Any, <:Any, p}) where {p} = p +iperm(::PermutedDimsArray{<:Any, <:Any, <:Any, ip}) where {ip} = ip # TODO: Use `Base.PermutedDimsArrays.genperm` or # https://github.com/jipolanco/StaticPermutations.jl? genperm(v, perm) = map(j -> v[j], perm) function isstored(a::PermutedDimsArray, I::Int...) - return isstored(parent(a), genperm(I, iperm(a))...) + return isstored(parent(a), genperm(I, iperm(a))...) end function getstoredindex(a::PermutedDimsArray, I::Int...) - return getstoredindex(parent(a), genperm(I, iperm(a))...) + return getstoredindex(parent(a), genperm(I, iperm(a))...) end function getunstoredindex(a::PermutedDimsArray, I::Int...) - return getunstoredindex(parent(a), genperm(I, iperm(a))...) + return getunstoredindex(parent(a), genperm(I, iperm(a))...) end function eachstoredindex(a::PermutedDimsArray) - return map(collect(eachstoredindex(parent(a)))) do I - return CartesianIndex(genperm(I, perm(a))) - end + return map(collect(eachstoredindex(parent(a)))) do I + return CartesianIndex(genperm(I, perm(a))) + end end tuple_oneto(n) = ntuple(identity, n) @@ -190,84 +190,86 @@ tuple_oneto(n) = ntuple(identity, n) ## end function eachstoredparentindex(a::SubArray) - return filter(eachstoredindex(parent(a))) do I - return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) - end + return filter(eachstoredindex(parent(a))) do I + return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) + end end function storedvalues(a::SubArray) - return @view parent(a)[collect(eachstoredparentindex(a))] + return @view parent(a)[collect(eachstoredparentindex(a))] end function isstored(a::SubArray, I::Int...) - return isstored(parent(a), Base.reindex(parentindices(a), I)...) + return isstored(parent(a), Base.reindex(parentindices(a), I)...) end function getstoredindex(a::SubArray, I::Int...) - return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...) + return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...) end function getunstoredindex(a::SubArray, I::Int...) - return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...) + return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...) end function setstoredindex!(a::SubArray, value, I::Int...) - return setstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) + return setstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) end function setunstoredindex!(a::SubArray, value, I::Int...) - return setunstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) + return setunstoredindex!(parent(a), value, Base.reindex(parentindices(a), I)...) end function eachstoredindex(a::SubArray) - nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d - return !(parentindices(a)[d] isa Real) - end - return collect(( - CartesianIndex( - map(nonscalardims) do d - return findfirst(==(I[d]), parentindices(a)[d]) - end, - ) for I in eachstoredparentindex(a) - )) + nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d + return !(parentindices(a)[d] isa Real) + end + return collect( + ( + CartesianIndex( + map(nonscalardims) do d + return findfirst(==(I[d]), parentindices(a)[d]) + end, + ) for I in eachstoredparentindex(a) + ) + ) end # Define a type that will derive the interface. -struct SparseArrayDOK{T,N} <: AbstractArray{T,N} - storage::Dict{CartesianIndex{N},T} - size::NTuple{N,Int} +struct SparseArrayDOK{T, N} <: AbstractArray{T, N} + storage::Dict{CartesianIndex{N}, T} + size::NTuple{N, Int} end storage(a::SparseArrayDOK) = a.storage Base.size(a::SparseArrayDOK) = a.size function SparseArrayDOK{T}(size::Int...) where {T} - N = length(size) - return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size) + N = length(size) + return SparseArrayDOK{T, N}(Dict{CartesianIndex{N}, T}(), size) end # Used in `Base.similar`. function SparseArrayDOK{T}(::UndefInitializer, size::Tuple{Vararg{Int}}) where {T} - return SparseArrayDOK{T}(size...) + return SparseArrayDOK{T}(size...) end function isstored(a::SparseArrayDOK, I::Int...) - return CartesianIndex(I) in keys(storage(a)) + return CartesianIndex(I) in keys(storage(a)) end function getstoredindex(a::SparseArrayDOK, I::Int...) - return storage(a)[CartesianIndex(I)] + return storage(a)[CartesianIndex(I)] end function getunstoredindex(a::SparseArrayDOK, I::Int...) - return zero(eltype(a)) + return zero(eltype(a)) end function setstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a + storage(a)[CartesianIndex(I)] = value + return a end function setunstoredindex!(a::SparseArrayDOK, value, I::Int...) - storage(a)[CartesianIndex(I)] = value - return a + storage(a)[CartesianIndex(I)] = value + return a end eachstoredindex(a::SparseArrayDOK) = keys(storage(a)) storedlength(a::SparseArrayDOK) = length(eachstoredindex(a)) function DerivableInterfaces.zero!(a::SparseArrayDOK) - empty!(storage(a)) - return a + empty!(storage(a)) + return a end # Specify the interface the type adheres to. function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK}) - return SparseArrayInterface{ndims(arrayt)}() + return SparseArrayInterface{ndims(arrayt)}() end # Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. @@ -278,7 +280,7 @@ end # avoid overloading `Base.cat` because of method invalidations function Base._cat(dims, args::SparseArrayDOK...) - return DerivableInterfaces.Concatenate.concatenate(dims, args...) + return DerivableInterfaces.Concatenate.concatenate(dims, args...) end end diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b..0008050 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,60 +6,62 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - groupdir = joinpath(@__DIR__, testgroup) - for file in filter(istestfile, readdir(groupdir)) - filename = joinpath(groupdir, file) - @eval @safetestset $file begin - include($filename) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end end - end end - end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a01787c..fbec055 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - # Aqua.test_all(DerivableInterfaces) + # Aqua.test_all(DerivableInterfaces) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 747a4ac..ddb9db8 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,149 +5,149 @@ using Test: @test, @testset elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "DerivableInterfaces" for elt in elts - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - @test a isa SparseArrayDOK{elt,2} - @test size(a) == (2, 2) - @test a[1, 1] == 0 - @test a[1, 1, 1] == 0 - @test a[1, 2] == 12 - @test a[1, 2, 1] == 12 - @test storedlength(a) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3)))) - @test b isa SparseArrayDOK{Float32,2} - @test b == zeros(Float32, 3, 3) - @test size(b) == (3, 3) - end - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - bc = Broadcast.Broadcasted(x -> 2x, (a,)) - copyto!(b, bc) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(3, 3, 3) - a[1, 2, 3] = 123 - b = permutedims(a, (2, 3, 1)) - @test b isa SparseArrayDOK{elt,3} - @test b[2, 3, 1] == 123 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = copy(a') - @test b isa SparseArrayDOK{elt,2} - @test b == [0 0; 12 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = map(x -> 2x, a) - @test b isa SparseArrayDOK{elt,2} - @test b == [0 24; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a * a' - @test b isa SparseArrayDOK{elt,2} - @test b == [144 0; 0 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a .+ 2 .* a' - @test b isa SparseArrayDOK{elt,2} - @test b == [0 12; 24 0] - @test storedlength(b) == 2 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = a[1:2, 2] - @test b isa SparseArrayDOK{elt,1} - @test b == [12, 0] - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - @test iszero(a) - a[2, 1] = 21 - a[1, 2] = 12 - @test !iszero(a) - @test isreal(a) - @test sum(a) == 33 - @test mapreduce(x -> 2x, +, a) == 66 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = similar(a) - copyto!(b, a) - @test b isa SparseArrayDOK{elt,2} - @test b == a - @test b[1, 2] == 12 - @test storedlength(b) == 1 - - a = SparseArrayDOK{elt}(2, 2) - a .= 2 - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end - - a = SparseArrayDOK{elt}(2, 2) - fill!(a, 2) - @test storedlength(a) == length(a) - for I in eachindex(a) - @test a[I] == 2 - end - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - zero!(a) - @test iszero(a) - @test iszero(storedlength(a)) - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = zero(a) - @test b isa SparseArrayDOK{elt,2} - @test iszero(b) - @test iszero(storedlength(b)) - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - b[2:3, 2:3] = a - @test isone(storedlength(b)) - @test b[2, 3] == 12 - - a = SparseArrayDOK{elt}(2, 2) - a[1, 2] = 12 - b = SparseArrayDOK{elt}(4, 4) - c = @view b[2:3, 2:3] - c .= a - @test isone(storedlength(b)) - @test b[2, 3] == 12 - - a1 = SparseArrayDOK{elt}(2, 2) - a1[1, 2] = 12 - a2 = SparseArrayDOK{elt}(2, 2) - a2[2, 1] = 21 - b = cat(a1, a2; dims=(1, 2)) - @test b isa SparseArrayDOK{elt,2} - @test storedlength(b) == 2 - @test b[1, 2] == 12 - @test b[4, 3] == 21 + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + @test a isa SparseArrayDOK{elt, 2} + @test size(a) == (2, 2) + @test a[1, 1] == 0 + @test a[1, 1, 1] == 0 + @test a[1, 2] == 12 + @test a[1, 2, 1] == 12 + @test storedlength(a) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + for b in (similar(a, Float32, (3, 3)), similar(a, Float32, Base.OneTo.((3, 3)))) + @test b isa SparseArrayDOK{Float32, 2} + @test b == zeros(Float32, 3, 3) + @test size(b) == (3, 3) + end + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = similar(a) + bc = Broadcast.Broadcasted(x -> 2x, (a,)) + copyto!(b, bc) + @test b isa SparseArrayDOK{elt, 2} + @test b == [0 24; 0 0] + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(3, 3, 3) + a[1, 2, 3] = 123 + b = permutedims(a, (2, 3, 1)) + @test b isa SparseArrayDOK{elt, 3} + @test b[2, 3, 1] == 123 + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = copy(a') + @test b isa SparseArrayDOK{elt, 2} + @test b == [0 0; 12 0] + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = map(x -> 2x, a) + @test b isa SparseArrayDOK{elt, 2} + @test b == [0 24; 0 0] + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = a * a' + @test b isa SparseArrayDOK{elt, 2} + @test b == [144 0; 0 0] + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = a .+ 2 .* a' + @test b isa SparseArrayDOK{elt, 2} + @test b == [0 12; 24 0] + @test storedlength(b) == 2 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = a[1:2, 2] + @test b isa SparseArrayDOK{elt, 1} + @test b == [12, 0] + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + @test iszero(a) + a[2, 1] = 21 + a[1, 2] = 12 + @test !iszero(a) + @test isreal(a) + @test sum(a) == 33 + @test mapreduce(x -> 2x, +, a) == 66 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = similar(a) + copyto!(b, a) + @test b isa SparseArrayDOK{elt, 2} + @test b == a + @test b[1, 2] == 12 + @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a .= 2 + @test storedlength(a) == length(a) + for I in eachindex(a) + @test a[I] == 2 + end + + a = SparseArrayDOK{elt}(2, 2) + fill!(a, 2) + @test storedlength(a) == length(a) + for I in eachindex(a) + @test a[I] == 2 + end + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + zero!(a) + @test iszero(a) + @test iszero(storedlength(a)) + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = zero(a) + @test b isa SparseArrayDOK{elt, 2} + @test iszero(b) + @test iszero(storedlength(b)) + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = SparseArrayDOK{elt}(4, 4) + b[2:3, 2:3] .= a + @test isone(storedlength(b)) + @test b[2, 3] == 12 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = SparseArrayDOK{elt}(4, 4) + b[2:3, 2:3] = a + @test isone(storedlength(b)) + @test b[2, 3] == 12 + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = SparseArrayDOK{elt}(4, 4) + c = @view b[2:3, 2:3] + c .= a + @test isone(storedlength(b)) + @test b[2, 3] == 12 + + a1 = SparseArrayDOK{elt}(2, 2) + a1[1, 2] = 12 + a2 = SparseArrayDOK{elt}(2, 2) + a2[2, 1] = 21 + b = cat(a1, a2; dims = (1, 2)) + @test b isa SparseArrayDOK{elt, 2} + @test storedlength(b) == 2 + @test b[1, 2] == 12 + @test b[4, 3] == 21 end diff --git a/test/test_concatenate.jl b/test/test_concatenate.jl index d35674b..cfd6c43 100644 --- a/test/test_concatenate.jl +++ b/test/test_concatenate.jl @@ -2,30 +2,30 @@ using DerivableInterfaces.Concatenate: concatenated using Test: @test, @testset @testset "Concatenated" begin - a = randn(Float32, 2, 2) - b = randn(Float64, 2, 2) + a = randn(Float32, 2, 2) + b = randn(Float64, 2, 2) - concat = concatenated((1, 2), a, b) - @test axes(concat) == Base.OneTo.((4, 4)) - @test size(concat) == (4, 4) - @test eltype(concat) === Float64 - @test copy(concat) == cat(a, b; dims=(1, 2)) + concat = concatenated((1, 2), a, b) + @test axes(concat) == Base.OneTo.((4, 4)) + @test size(concat) == (4, 4) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = (1, 2)) - concat = concatenated(1, a, b) - @test axes(concat) == Base.OneTo.((4, 2)) - @test size(concat) == (4, 2) - @test eltype(concat) === Float64 - @test copy(concat) == cat(a, b; dims=1) + concat = concatenated(1, a, b) + @test axes(concat) == Base.OneTo.((4, 2)) + @test size(concat) == (4, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 1) - concat = concatenated(3, a, b) - @test axes(concat) == Base.OneTo.((2, 2, 2)) - @test size(concat) == (2, 2, 2) - @test eltype(concat) === Float64 - @test copy(concat) == cat(a, b; dims=3) + concat = concatenated(3, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 2)) + @test size(concat) == (2, 2, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 3) - concat = concatenated(4, a, b) - @test axes(concat) == Base.OneTo.((2, 2, 1, 2)) - @test size(concat) == (2, 2, 1, 2) - @test eltype(concat) === Float64 - @test copy(concat) == cat(a, b; dims=4) + concat = concatenated(4, a, b) + @test axes(concat) == Base.OneTo.((2, 2, 1, 2)) + @test size(concat) == (2, 2, 1, 2) + @test eltype(concat) === Float64 + @test copy(concat) == cat(a, b; dims = 4) end diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index 6820882..af926e8 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -8,151 +8,151 @@ _getindex(A, i...) = @interface DefaultArrayInterface() A[i...] _setindex!(A, v, i...) = @interface DefaultArrayInterface() A[i...] = v _map!(args...) = @interface DefaultArrayInterface() map!(args...) function _mapreduce(args...; kwargs...) - @interface DefaultArrayInterface() mapreduce(args...; kwargs...) + return @interface DefaultArrayInterface() mapreduce(args...; kwargs...) end @testset "indexing" begin - for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3))) - a = @constinferred _getindex(A, i...) - @test a == A[i...] - v = 1.1 - A′ = @constinferred _setindex!(A, v, i...) - @test A′ == (A[i...] = v) - end + for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3))) + a = @constinferred _getindex(A, i...) + @test a == A[i...] + v = 1.1 + A′ = @constinferred _setindex!(A, v, i...) + @test A′ == (A[i...] = v) + end end @testset "map!" begin - A = zeros(3) - a = @constinferred _map!(Returns(2), copy(A), A) - @test a == map!(Returns(2), copy(A), A) + A = zeros(3) + a = @constinferred _map!(Returns(2), copy(A), A) + @test a == map!(Returns(2), copy(A), A) end @testset "mapreduce" begin - A = zeros(3) - a = @constinferred _mapreduce(Returns(2), +, A) - @test a == mapreduce(Returns(2), +, A) + A = zeros(3) + a = @constinferred _mapreduce(Returns(2), +, A) + @test a == mapreduce(Returns(2), +, A) end @testset "DefaultArrayInterface" begin - @test @constinferred(interface(Array)) === DefaultArrayInterface{Any,Array}() - @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() - @test @constinferred(interface(Matrix)) === DefaultArrayInterface{2,Array}() - @test @constinferred(interface(Matrix{Float32})) === DefaultArrayInterface{2,Array}() - @test @constinferred(DefaultArrayInterface()) === DefaultArrayInterface{Any}() - @test @constinferred(DefaultArrayInterface(Val(2))) === DefaultArrayInterface{2}() - @test @constinferred(DefaultArrayInterface{Any}(Val(2))) === DefaultArrayInterface{2}() - @test @constinferred(DefaultArrayInterface{3}(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(interface(Array)) === DefaultArrayInterface{Any, Array}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any, Array}() + @test @constinferred(interface(Matrix)) === DefaultArrayInterface{2, Array}() + @test @constinferred(interface(Matrix{Float32})) === DefaultArrayInterface{2, Array}() + @test @constinferred(DefaultArrayInterface()) === DefaultArrayInterface{Any}() + @test @constinferred(DefaultArrayInterface(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{Any}(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{3}(Val(2))) === DefaultArrayInterface{2}() - # DefaultArrayInterface - @test @constinferred(interface(AbstractArray)) === DefaultArrayInterface{Any}() - @test @constinferred(interface(AbstractArray{<:Any,3})) === DefaultArrayInterface{3}() - @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() - @test @constinferred(interface(Array{Float32,3})) === DefaultArrayInterface{3,Array}() - @test @constinferred(interface(SubArray{<:Any,<:Any,Array})) === - DefaultArrayInterface{Any,Array}() - @test @constinferred(interface(SubArray{<:Any,<:Any,AbstractArray})) === - DefaultArrayInterface{Any}() - @test @constinferred(interface(SubArray{<:Any,2,Array})) === - DefaultArrayInterface{2,Array}() - @test @constinferred(interface(randn(2, 2))) === DefaultArrayInterface{2,Array}() - @test @constinferred(interface(view(randn(2, 2), 1:2, 1))) === - DefaultArrayInterface{1,Array}() + # DefaultArrayInterface + @test @constinferred(interface(AbstractArray)) === DefaultArrayInterface{Any}() + @test @constinferred(interface(AbstractArray{<:Any, 3})) === DefaultArrayInterface{3}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any, Array}() + @test @constinferred(interface(Array{Float32, 3})) === DefaultArrayInterface{3, Array}() + @test @constinferred(interface(SubArray{<:Any, <:Any, Array})) === + DefaultArrayInterface{Any, Array}() + @test @constinferred(interface(SubArray{<:Any, <:Any, AbstractArray})) === + DefaultArrayInterface{Any}() + @test @constinferred(interface(SubArray{<:Any, 2, Array})) === + DefaultArrayInterface{2, Array}() + @test @constinferred(interface(randn(2, 2))) === DefaultArrayInterface{2, Array}() + @test @constinferred(interface(view(randn(2, 2), 1:2, 1))) === + DefaultArrayInterface{1, Array}() - # Combining DefaultArrayInterface - @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface())) === - DefaultArrayInterface() - @test @constinferred( - interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) - ) === DefaultArrayInterface{2}() - @test @constinferred( - interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) - ) === DefaultArrayInterface() - @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface{3}())) === - DefaultArrayInterface() - @test @constinferred(interface(randn(2, 2), randn(2, 2))) === - DefaultArrayInterface{2,Array}() - @test @constinferred(interface(randn(2, 2), randn(2))) === - DefaultArrayInterface{Any,Array}() - @test @constinferred(interface(randn(2, 2), randn(2, 2)')) === - DefaultArrayInterface{2,Array}() + # Combining DefaultArrayInterface + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface())) === + DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface{3}())) === + DefaultArrayInterface() + @test @constinferred(interface(randn(2, 2), randn(2, 2))) === + DefaultArrayInterface{2, Array}() + @test @constinferred(interface(randn(2, 2), randn(2))) === + DefaultArrayInterface{Any, Array}() + @test @constinferred(interface(randn(2, 2), randn(2, 2)')) === + DefaultArrayInterface{2, Array}() end @testset "similar(::DefaultArrayInterface, ...)" begin - a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2)) - @test typeof(a) === Matrix{Float32} - @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) - a = @constinferred similar(DefaultArrayInterface{Any,Array}(), Float32, (2, 2)) - @test typeof(a) === Matrix{Float32} - @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{Any, Array}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) - a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) - @test typeof(a) === Matrix{Float32} - @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) end @testset "Broadcast.DefaultArrayStyle" begin - @test @constinferred(interface(Broadcast.DefaultArrayStyle)) == DefaultArrayInterface() - @test @constinferred(interface(Broadcast.DefaultArrayStyle{2})) == - DefaultArrayInterface{2}() - @test @constinferred( - interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) - ) == DefaultArrayInterface{1}() + @test @constinferred(interface(Broadcast.DefaultArrayStyle)) == DefaultArrayInterface() + @test @constinferred(interface(Broadcast.DefaultArrayStyle{2})) == + DefaultArrayInterface{2}() + @test @constinferred( + interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) + ) == DefaultArrayInterface{1}() end @testset "DefaultArrayInterface with custom array type" begin - # ArrayInterface - a = jl(randn(2, 2)) - @test @constinferred(interface(JLArray{Float32})) === DefaultArrayInterface{Any,JLArray}() - @test @constinferred(interface(SubArray{<:Any,2,JLArray{Float32}})) === - DefaultArrayInterface{2,JLArray}() - @test @constinferred(interface(a)) === DefaultArrayInterface{2,JLArray}() - @test @constinferred(interface(a')) === DefaultArrayInterface{2,JLArray}() - @test @constinferred(interface(view(a, 1:2, 1))) === DefaultArrayInterface{1,JLArray}() - a′ = @constinferred similar(a, Float32, (2, 3, 3)) - @test a′ isa JLArray{Float32,3} - @test size(a′) == (2, 3, 3) + # ArrayInterface + a = jl(randn(2, 2)) + @test @constinferred(interface(JLArray{Float32})) === DefaultArrayInterface{Any, JLArray}() + @test @constinferred(interface(SubArray{<:Any, 2, JLArray{Float32}})) === + DefaultArrayInterface{2, JLArray}() + @test @constinferred(interface(a)) === DefaultArrayInterface{2, JLArray}() + @test @constinferred(interface(a')) === DefaultArrayInterface{2, JLArray}() + @test @constinferred(interface(view(a, 1:2, 1))) === DefaultArrayInterface{1, JLArray}() + a′ = @constinferred similar(a, Float32, (2, 3, 3)) + @test a′ isa JLArray{Float32, 3} + @test size(a′) == (2, 3, 3) - # Combining ArrayInterface - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,JLArray}()) - ) === DefaultArrayInterface{2,JLArray}() - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,JLArray}()) - ) === DefaultArrayInterface{Any,JLArray}() - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2}()) - ) === DefaultArrayInterface{2}() - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,Array}()) - ) === DefaultArrayInterface{2}() - @test @constinferred( - interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2,JLArray}()) - ) === DefaultArrayInterface{2}() - @test @constinferred( - interface(DefaultArrayInterface{2,Array}(), DefaultArrayInterface{2,JLArray}()) - ) === DefaultArrayInterface{2}() - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3}()) - ) === DefaultArrayInterface() - @test @constinferred( - interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,Array}()) - ) === DefaultArrayInterface() - @test @constinferred( - interface(DefaultArrayInterface{3}(), DefaultArrayInterface{2,JLArray}()) - ) === DefaultArrayInterface() - @test @constinferred( - interface(DefaultArrayInterface{3,Array}(), DefaultArrayInterface{2,JLArray}()) - ) === DefaultArrayInterface() - @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2)))) === - DefaultArrayInterface{2,JLArray}() - @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2))')) === - DefaultArrayInterface{2,JLArray}() - @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2, 2)))) === - DefaultArrayInterface{Any,JLArray}() - @test @constinferred(interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2)))) === - DefaultArrayInterface{1,JLArray}() - @test @constinferred(interface(randn(2, 2), jl(randn(2, 2)))) === - DefaultArrayInterface{2}() - @test @constinferred(interface(randn(2, 2), jl(randn(2)))) === DefaultArrayInterface() + # Combining ArrayInterface + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{2, JLArray}()) + ) === DefaultArrayInterface{2, JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{3, JLArray}()) + ) === DefaultArrayInterface{Any, JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{2, Array}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2, JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2, Array}(), DefaultArrayInterface{2, JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{2, JLArray}(), DefaultArrayInterface{3, Array}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3}(), DefaultArrayInterface{2, JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3, Array}(), DefaultArrayInterface{2, JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2)))) === + DefaultArrayInterface{2, JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2))')) === + DefaultArrayInterface{2, JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2, 2)))) === + DefaultArrayInterface{Any, JLArray}() + @test @constinferred(interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2)))) === + DefaultArrayInterface{1, JLArray}() + @test @constinferred(interface(randn(2, 2), jl(randn(2, 2)))) === + DefaultArrayInterface{2}() + @test @constinferred(interface(randn(2, 2), jl(randn(2)))) === DefaultArrayInterface() end diff --git a/test/test_permuteddims.jl b/test/test_permuteddims.jl index 9483acd..ad0e12d 100644 --- a/test/test_permuteddims.jl +++ b/test/test_permuteddims.jl @@ -4,14 +4,14 @@ using LinearAlgebra: Diagonal using Test: @test, @testset @testset "permuteddims" begin - a = randn(2, 3, 4) - @test permuteddims(a, (2, 1, 3)) ≡ PermutedDimsArray(a, (2, 1, 3)) + a = randn(2, 3, 4) + @test permuteddims(a, (2, 1, 3)) ≡ PermutedDimsArray(a, (2, 1, 3)) - a = Diagonal(randn(3)) - @test permuteddims(a, (1, 2)) ≡ a - @test permuteddims(a, (2, 1)) ≡ a + a = Diagonal(randn(3)) + @test permuteddims(a, (1, 2)) ≡ a + @test permuteddims(a, (2, 1)) ≡ a - a = RectDiagonal(randn(3), (3, 4)) - @test permuteddims(a, (1, 2)) ≡ a - @test permuteddims(a, (2, 1)) ≡ RectDiagonal(parent(a), (4, 3)) + a = RectDiagonal(randn(3), (3, 4)) + @test permuteddims(a, (1, 2)) ≡ a + @test permuteddims(a, (2, 1)) ≡ RectDiagonal(parent(a), (4, 3)) end