diff --git a/Project.toml b/Project.toml index 9206411..6d271cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.5.1" +version = "0.5.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractinterface.jl b/src/abstractinterface.jl index 8223c6d..dd89129 100644 --- a/src/abstractinterface.jl +++ b/src/abstractinterface.jl @@ -6,6 +6,8 @@ interface(x1, x_rest...) = combine_interfaces(interface(x1), interface.(x_rest). abstract type AbstractInterface end +interface(x::AbstractInterface) = x + (interface::AbstractInterface)(f) = InterfaceFunction(interface, f) # Adapted from `Base.Broadcast.combine_styles`. diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 3861a3e..d0d7a2e 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -1,19 +1,50 @@ -struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end +using TypeParameterAccessors: parenttype, set_eltype, unspecify_type_parameters +struct DefaultArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end + +DefaultArrayInterface{N}() where {N} = DefaultArrayInterface{N,AbstractArray}() DefaultArrayInterface() = DefaultArrayInterface{Any}() DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() +DefaultArrayInterface{M,A}(::Val{N}) where {M,A,N} = DefaultArrayInterface{N,A}() + +# This version remembers the `ndims` of the wrapper type. +function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} + arrayt′ = parenttype(arrayt) + if arrayt′ === arrayt + return DefaultArrayInterface{N,unspecify_type_parameters(arrayt)}() + end + return typeof(interface(arrayt′))(Val(N)) +end -using TypeParameterAccessors: parenttype -function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface() - return interface(parenttype(a)) +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} + return _interface(Val(N), arrayt) end -function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N} - parenttype(a) === a && return DefaultArrayInterface{N}() - return interface(parenttype(a)) +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray}) + return _interface(Val(Any), arrayt) +end + +function Base.similar( + ::DefaultArrayInterface{<:Any,A}, T::Type, ax::Tuple +) where {A<:AbstractArray} + if isabstracttype(A) + # If the type is abstract, default to constructing the array on CPU. + return similar(Array{T}, ax) + else + return similar(set_eltype(A, T), ax) + end end +function combine_interface_rule( + interface1::DefaultArrayInterface{N,A}, interface2::DefaultArrayInterface{N,A} +) where {N,A<:AbstractArray} + return DefaultArrayInterface{N,A}() +end +function combine_interface_rule( + interface1::DefaultArrayInterface{<:Any,A}, interface2::DefaultArrayInterface{<:Any,A} +) where {A<:AbstractArray} + return DefaultArrayInterface{Any,A}() +end function combine_interface_rule( interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} ) where {N} @@ -22,7 +53,7 @@ end function combine_interface_rule( interface1::DefaultArrayInterface, interface2::DefaultArrayInterface ) - return DefaultArrayInterface{Any}() + return DefaultArrayInterface() end @interface ::DefaultArrayInterface function Base.getindex( @@ -48,7 +79,3 @@ end ) return Base.mapreduce(f, op, as...; kwargs...) end - -function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple) - return similar(Array{T}, ax) -end diff --git a/test/Project.toml b/test/Project.toml index 2c88278..03a08a0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" @@ -12,6 +13,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Aqua = "0.8" ArrayLayouts = "1" DerivableInterfaces = "0.5" +JLArrays = "0.2" LinearAlgebra = "1" SafeTestsets = "0.1" Suppressor = "0.2" diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 9f80c3f..192cf64 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -267,7 +267,7 @@ end # Specify the interface the type adheres to. function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK}) - SparseArrayInterface{ndims(arrayt)}() + return SparseArrayInterface{ndims(arrayt)}() end # Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index 6d2cb23..6820882 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,4 +1,5 @@ using DerivableInterfaces: @interface, DefaultArrayInterface, interface +using JLArrays: JLArray, jl using Test: @testset, @test using TestExtras: @constinferred @@ -33,14 +34,47 @@ end end @testset "DefaultArrayInterface" begin - @test interface(Array) === DefaultArrayInterface{Any}() - @test interface(Array{Float32}) === DefaultArrayInterface{Any}() - @test interface(Matrix) === DefaultArrayInterface{2}() - @test interface(Matrix{Float32}) === DefaultArrayInterface{2}() - @test DefaultArrayInterface() === DefaultArrayInterface{Any}() - @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() - @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() - @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() + @test @constinferred(interface(Array)) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Matrix)) === DefaultArrayInterface{2,Array}() + @test @constinferred(interface(Matrix{Float32})) === DefaultArrayInterface{2,Array}() + @test @constinferred(DefaultArrayInterface()) === DefaultArrayInterface{Any}() + @test @constinferred(DefaultArrayInterface(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{Any}(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{3}(Val(2))) === DefaultArrayInterface{2}() + + # DefaultArrayInterface + @test @constinferred(interface(AbstractArray)) === DefaultArrayInterface{Any}() + @test @constinferred(interface(AbstractArray{<:Any,3})) === DefaultArrayInterface{3}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Array{Float32,3})) === DefaultArrayInterface{3,Array}() + @test @constinferred(interface(SubArray{<:Any,<:Any,Array})) === + DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(SubArray{<:Any,<:Any,AbstractArray})) === + DefaultArrayInterface{Any}() + @test @constinferred(interface(SubArray{<:Any,2,Array})) === + DefaultArrayInterface{2,Array}() + @test @constinferred(interface(randn(2, 2))) === DefaultArrayInterface{2,Array}() + @test @constinferred(interface(view(randn(2, 2), 1:2, 1))) === + DefaultArrayInterface{1,Array}() + + # Combining DefaultArrayInterface + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface())) === + DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface{3}())) === + DefaultArrayInterface() + @test @constinferred(interface(randn(2, 2), randn(2, 2))) === + DefaultArrayInterface{2,Array}() + @test @constinferred(interface(randn(2, 2), randn(2))) === + DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(randn(2, 2), randn(2, 2)')) === + DefaultArrayInterface{2,Array}() end @testset "similar(::DefaultArrayInterface, ...)" begin @@ -48,14 +82,77 @@ end @test typeof(a) === Matrix{Float32} @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{Any,Array}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) @test typeof(a) === Matrix{Float32} @test size(a) == (2, 2) end @testset "Broadcast.DefaultArrayStyle" begin - @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() - @test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}() - @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface{1}() + @test @constinferred(interface(Broadcast.DefaultArrayStyle)) == DefaultArrayInterface() + @test @constinferred(interface(Broadcast.DefaultArrayStyle{2})) == + DefaultArrayInterface{2}() + @test @constinferred( + interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) + ) == DefaultArrayInterface{1}() +end + +@testset "DefaultArrayInterface with custom array type" begin + # ArrayInterface + a = jl(randn(2, 2)) + @test @constinferred(interface(JLArray{Float32})) === DefaultArrayInterface{Any,JLArray}() + @test @constinferred(interface(SubArray{<:Any,2,JLArray{Float32}})) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(a)) === DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(a')) === DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(view(a, 1:2, 1))) === DefaultArrayInterface{1,JLArray}() + a′ = @constinferred similar(a, Float32, (2, 3, 3)) + @test a′ isa JLArray{Float32,3} + @test size(a′) == (2, 3, 3) + + # Combining ArrayInterface + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2,JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,JLArray}()) + ) === DefaultArrayInterface{Any,JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,Array}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,Array}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,Array}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3,Array}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2)))) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2))')) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2, 2)))) === + DefaultArrayInterface{Any,JLArray}() + @test @constinferred(interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2)))) === + DefaultArrayInterface{1,JLArray}() + @test @constinferred(interface(randn(2, 2), jl(randn(2, 2)))) === + DefaultArrayInterface{2}() + @test @constinferred(interface(randn(2, 2), jl(randn(2)))) === DefaultArrayInterface() end