diff --git a/Project.toml b/Project.toml index aa5072a..ac8a237 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.4.5" +version = "0.5.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/Project.toml b/docs/Project.toml index 8b0bf74..6e4dd3d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" [compat] -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" Documenter = "1" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index cf1b8a6..ebb1aef 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" [compat] ArrayLayouts = "1" -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index 865787d..6eba847 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -1,20 +1,25 @@ # TODO: Add `ndims` type parameter. -abstract type AbstractArrayInterface <: AbstractInterface end +abstract type AbstractArrayInterface{N} <: AbstractInterface end +function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N} + return DefaultArrayInterface{N}() +end function interface(::Type{<:Broadcast.AbstractArrayStyle}) return DefaultArrayInterface() end -function interface(::Type{<:Broadcast.Broadcasted{Nothing}}) - return DefaultArrayInterface() +function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}}) + return DefaultArrayInterface{ndims(BC)}() end function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style} return interface(Style) end -# TODO: Define as `Array{T}`. -arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.") +# TODO: Define as `similar(Array{T}, ax)`. +function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple) + return error("Not implemented.") +end using ArrayLayouts: ArrayLayouts @@ -85,7 +90,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( a::AbstractArray, T::Type, size::Tuple{Vararg{Int}} ) - return similar(arraytype(interface, T), size) + return similar(interface, T, size) end @interface ::AbstractArrayInterface function Base.copy(a::AbstractArray) @@ -105,8 +110,7 @@ end @interface interface::AbstractArrayInterface function Base.similar( bc::Broadcast.Broadcasted, T::Type, axes::Tuple ) - # `arraytype(::AbstractInterface)` determines the default array type associated with the interface. - return similar(arraytype(interface, T), axes) + return similar(interface, T, axes) end using MapBroadcast: Mapped diff --git a/src/concatenate.jl b/src/concatenate.jl index fcefee8..5f7b019 100644 --- a/src/concatenate.jl +++ b/src/concatenate.jl @@ -28,8 +28,7 @@ export concatenate @compat public Concatenated, cat, cat!, concatenated using Base: promote_eltypeof -using ..DerivableInterfaces: - DerivableInterfaces, AbstractInterface, interface, zero!, arraytype +using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero! unval(x) = x unval(::Val{x}) where {x} = x @@ -53,13 +52,17 @@ struct Concatenated{Interface,Dims,Args<:Tuple} end end -function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple) +function Concatenated( + interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple +) return _Concatenated(interface, dims, args) end function Concatenated(dims::Val, args::Tuple) - return Concatenated(interface(args...), dims, args) + return Concatenated(cat_interface(dims, args...), dims, args) end -function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface} +function Concatenated{Interface}( + dims::Val, args::Tuple +) where {Interface<:Union{AbstractArrayInterface,Nothing}} return Concatenated(Interface(), dims, args) end @@ -81,8 +84,11 @@ end # ------------------------------------ Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) -function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} - return similar(arraytype(interface(concat), T), ax) +function Base.similar(concat::Concatenated, ax::Tuple) + return similar(interface(concat), eltype(concat), ax) +end +function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T} + return similar(interface(concat), T, ax) end function cat_axis( @@ -108,10 +114,15 @@ function cat_axes(dims::Val, as::AbstractArray...) return cat_axes(unval(dims), as...) end +function cat_interface(dims, as::AbstractArray...) + N = cat_ndims(dims, as...) + return typeof(interface(as...))(Val(N)) +end + Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) Base.size(concat::Concatenated) = length.(axes(concat)) -Base.ndims(concat::Concatenated) = length(axes(concat)) +Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...) # Main logic # ---------- diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 538b8e7..3861a3e 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -1,11 +1,29 @@ -# TODO: Add `ndims` type parameter. -struct DefaultArrayInterface <: AbstractArrayInterface end +struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end + +DefaultArrayInterface() = DefaultArrayInterface{Any}() +DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() +DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() using TypeParameterAccessors: parenttype function interface(a::Type{<:AbstractArray}) parenttype(a) === a && return DefaultArrayInterface() return interface(parenttype(a)) end +function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N} + parenttype(a) === a && return DefaultArrayInterface{N}() + return interface(parenttype(a)) +end + +function combine_interface_rule( + interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end +function combine_interface_rule( + interface1::DefaultArrayInterface, interface2::DefaultArrayInterface +) + return DefaultArrayInterface{Any}() +end @interface ::DefaultArrayInterface function Base.getindex( a::AbstractArray{<:Any,N}, I::Vararg{Int,N} @@ -31,6 +49,6 @@ end return Base.mapreduce(f, op, as...; kwargs...) end -function arraytype(::DefaultArrayInterface, T::Type) - return Array{T} +function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple) + return similar(Array{T}, ax) end diff --git a/test/Project.toml b/test/Project.toml index 87525d0..2c88278 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8" ArrayLayouts = "1" -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" +LinearAlgebra = "1" SafeTestsets = "0.1" Suppressor = "0.2" -LinearAlgebra = "1" Test = "1" +TestExtras = "0.3" diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 0f5c68c..9f80c3f 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -40,7 +40,9 @@ using DerivableInterfaces: using LinearAlgebra: LinearAlgebra # Define an interface. -struct SparseArrayInterface <: AbstractArrayInterface end +struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end +SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}() +SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}() # Define interface functions. @interface ::SparseArrayInterface function Base.getindex( @@ -66,11 +68,15 @@ end struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}() -DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface() +function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N} + return SparseArrayInterface{N}() +end @derive SparseArrayStyle AbstractArrayStyleOps -DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T} +function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple) + return similar(SparseArrayDOK{T}, ax) +end # Interface functions. @interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type) @@ -260,7 +266,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK) end # Specify the interface the type adheres to. -DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() +function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK}) + SparseArrayInterface{ndims(arrayt)}() +end # Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. @array_aliases SparseArrayDOK diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index d12bade..6d2cb23 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,5 +1,6 @@ -using Test: @inferred, @testset, @test using DerivableInterfaces: @interface, DefaultArrayInterface, interface +using Test: @testset, @test +using TestExtras: @constinferred # function wrappers to test type-stability _getindex(A, i...) = @interface DefaultArrayInterface() A[i...] @@ -11,28 +12,50 @@ end @testset "indexing" begin for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3))) - a = @inferred _getindex(A, i...) + a = @constinferred _getindex(A, i...) @test a == A[i...] v = 1.1 - A′ = @inferred _setindex!(A, v, i...) + A′ = @constinferred _setindex!(A, v, i...) @test A′ == (A[i...] = v) end end @testset "map!" begin A = zeros(3) - a = @inferred _map!(Returns(2), copy(A), A) + a = @constinferred _map!(Returns(2), copy(A), A) @test a == map!(Returns(2), copy(A), A) end @testset "mapreduce" begin A = zeros(3) - a = @inferred _mapreduce(Returns(2), +, A) + a = @constinferred _mapreduce(Returns(2), +, A) @test a == mapreduce(Returns(2), +, A) end +@testset "DefaultArrayInterface" begin + @test interface(Array) === DefaultArrayInterface{Any}() + @test interface(Array{Float32}) === DefaultArrayInterface{Any}() + @test interface(Matrix) === DefaultArrayInterface{2}() + @test interface(Matrix{Float32}) === DefaultArrayInterface{2}() + @test DefaultArrayInterface() === DefaultArrayInterface{Any}() + @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() + @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() + @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() +end + +@testset "similar(::DefaultArrayInterface, ...)" begin + a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) + + a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) +end + @testset "Broadcast.DefaultArrayStyle" begin @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() + @test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}() @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface() + DefaultArrayInterface{1}() end