From c5c4904094347b827bc4a50d541d40d64f58bb69 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 11:00:12 -0400 Subject: [PATCH 1/3] Add ArrayInterface --- Project.toml | 2 +- src/abstractinterface.jl | 2 + src/defaultarrayinterface.jl | 85 +++++++++++++++++++++++++++--- test/Project.toml | 2 + test/test_defaultarrayinterface.jl | 69 +++++++++++++++++++++++- 5 files changed, 151 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 9206411..6d271cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DerivableInterfaces" uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" authors = ["ITensor developers and contributors"] -version = "0.5.1" +version = "0.5.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractinterface.jl b/src/abstractinterface.jl index 8223c6d..dd89129 100644 --- a/src/abstractinterface.jl +++ b/src/abstractinterface.jl @@ -6,6 +6,8 @@ interface(x1, x_rest...) = combine_interfaces(interface(x1), interface.(x_rest). abstract type AbstractInterface end +interface(x::AbstractInterface) = x + (interface::AbstractInterface)(f) = InterfaceFunction(interface, f) # Adapted from `Base.Broadcast.combine_styles`. diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 3861a3e..059e0c0 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -4,14 +4,11 @@ DefaultArrayInterface() = DefaultArrayInterface{Any}() DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() -using TypeParameterAccessors: parenttype -function interface(a::Type{<:AbstractArray}) - parenttype(a) === a && return DefaultArrayInterface() - return interface(parenttype(a)) +function DerivableInterfaces.interface(arrayt::Type{<:Array{<:Any,N}}) where {N} + return DefaultArrayInterface{N}() end -function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N} - parenttype(a) === a && return DefaultArrayInterface{N}() - return interface(parenttype(a)) +function DerivableInterfaces.interface(arrayt::Type{<:Array}) + return DefaultArrayInterface() end function combine_interface_rule( @@ -52,3 +49,77 @@ end function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple) return similar(Array{T}, ax) end + +struct ArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end +ArrayInterface{M,A}(::Val{N}) where {M,A,N} = ArrayInterface{N,A}() + +function Base.similar( + interface::ArrayInterface{A}, elt::Type, ax::Tuple +) where {A<:AbstractArray} + return similar(set_eltype(A, elt), ax) +end + +using TypeParameterAccessors: parenttype, unspecify_type_parameters +function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} + arrayt′ = parenttype(arrayt) + if arrayt′ === arrayt + if arrayt <: Array || isabstracttype(arrayt) + return DefaultArrayInterface{N}() + else + return ArrayInterface{N,unspecify_type_parameters(arrayt)}() + end + end + return _interface(Val(N), arrayt′) +end + +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} + return _interface(Val(N), arrayt) +end +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray}) + return _interface(Val(Any), arrayt) +end + +using TypeParameterAccessors: set_eltype +function Base.similar(::ArrayInterface{<:Any,A}, T::Type, ax::Tuple) where {A} + return similar(set_eltype(A, T), ax) +end + +function combine_interface_rule( + interface1::ArrayInterface{N,A}, interface2::ArrayInterface{N,A} +) where {N,A<:AbstractArray} + return ArrayInterface{N,A}() +end + +function combine_interface_rule( + interface1::ArrayInterface{<:Any,A}, interface2::ArrayInterface{<:Any,A} +) where {A<:AbstractArray} + return ArrayInterface{Any,A}() +end +function combine_interface_rule( + interface1::ArrayInterface{N}, interface2::ArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end +function combine_interface_rule(interface1::ArrayInterface, interface2::ArrayInterface) + return DefaultArrayInterface() +end +function DerivableInterfaces.combine_interface_rule( + inter1::ArrayInterface, inter2::DefaultArrayInterface +) + return DefaultArrayInterface() +end +function DerivableInterfaces.combine_interface_rule( + inter1::DefaultArrayInterface, inter2::ArrayInterface +) + return DefaultArrayInterface() +end +function DerivableInterfaces.combine_interface_rule( + inter1::ArrayInterface{N}, inter2::DefaultArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end +function DerivableInterfaces.combine_interface_rule( + inter1::DefaultArrayInterface{N}, inter2::ArrayInterface{N} +) where {N} + return DefaultArrayInterface{N}() +end diff --git a/test/Project.toml b/test/Project.toml index 2c88278..03a08a0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" @@ -12,6 +13,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Aqua = "0.8" ArrayLayouts = "1" DerivableInterfaces = "0.5" +JLArrays = "0.2" LinearAlgebra = "1" SafeTestsets = "0.1" Suppressor = "0.2" diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index 6d2cb23..81e97f5 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,4 +1,5 @@ -using DerivableInterfaces: @interface, DefaultArrayInterface, interface +using DerivableInterfaces: @interface, ArrayInterface, DefaultArrayInterface, interface +using JLArrays: JLArray, jl using Test: @testset, @test using TestExtras: @constinferred @@ -41,6 +42,30 @@ end @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() + + # DefaultArrayInterface + @test interface(AbstractArray) === DefaultArrayInterface{Any}() + @test interface(AbstractArray{<:Any,3}) === DefaultArrayInterface{3}() + @test interface(Array{Float32}) === DefaultArrayInterface{Any}() + @test interface(Array{Float32,3}) === DefaultArrayInterface{3}() + @test interface(SubArray{<:Any,<:Any,Array}) === DefaultArrayInterface{Any}() + @test interface(SubArray{<:Any,<:Any,AbstractArray}) === DefaultArrayInterface{Any}() + @test interface(SubArray{<:Any,2,Array}) === DefaultArrayInterface{2}() + @test interface(randn(2, 2)) === DefaultArrayInterface{2}() + @test interface(view(randn(2, 2), 1:2, 1)) === DefaultArrayInterface{1}() + + # Combining DefaultArrayInterface + @test interface(DefaultArrayInterface(), DefaultArrayInterface()) === + DefaultArrayInterface() + @test interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) === + DefaultArrayInterface{2}() + @test interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) === + DefaultArrayInterface() + @test interface(DefaultArrayInterface(), DefaultArrayInterface{3}()) === + DefaultArrayInterface() + @test interface(randn(2, 2), randn(2, 2)) === DefaultArrayInterface{2}() + @test interface(randn(2, 2), randn(2)) === DefaultArrayInterface() + @test interface(randn(2, 2), randn(2, 2)') === DefaultArrayInterface{2}() end @testset "similar(::DefaultArrayInterface, ...)" begin @@ -59,3 +84,45 @@ end @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == DefaultArrayInterface{1}() end + +@testset "ArrayInterface" begin + # ArrayInterface + a = jl(randn(2, 2)) + @test interface(JLArray{Float32}) === ArrayInterface{Any,JLArray}() + @test interface(SubArray{<:Any,2,JLArray{Float32}}) === ArrayInterface{2,JLArray}() + @test interface(a) === ArrayInterface{2,JLArray}() + @test interface(a') === ArrayInterface{2,JLArray}() + @test interface(view(a, 1:2, 1)) === ArrayInterface{1,JLArray}() + a′ = similar(a, Float32, (2, 3, 3)) + @test a′ isa JLArray{Float32,3} + @test size(a′) == (2, 3, 3) + + # Combining ArrayInterface + @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{2,JLArray}()) === + ArrayInterface{2,JLArray}() + @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{3,JLArray}()) === + ArrayInterface{Any,JLArray}() + @test interface(ArrayInterface{2,JLArray}(), DefaultArrayInterface{2}()) === + DefaultArrayInterface{2}() + @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{2,Array}()) === + DefaultArrayInterface{2}() + @test interface(DefaultArrayInterface{2}(), ArrayInterface{2,JLArray}()) === + DefaultArrayInterface{2}() + @test interface(ArrayInterface{2,Array}(), ArrayInterface{2,JLArray}()) === + DefaultArrayInterface{2}() + @test interface(ArrayInterface{2,JLArray}(), DefaultArrayInterface{3}()) === + DefaultArrayInterface() + @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{3,Array}()) === + DefaultArrayInterface() + @test interface(DefaultArrayInterface{3}(), ArrayInterface{2,JLArray}()) === + DefaultArrayInterface() + @test interface(ArrayInterface{3,Array}(), ArrayInterface{2,JLArray}()) === + DefaultArrayInterface() + @test interface(jl(randn(2, 2)), jl(randn(2, 2))) === ArrayInterface{2,JLArray}() + @test interface(jl(randn(2, 2)), jl(randn(2, 2))') === ArrayInterface{2,JLArray}() + @test interface(jl(randn(2, 2)), jl(randn(2, 2, 2))) === ArrayInterface{Any,JLArray}() + @test interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2))) === + ArrayInterface{1,JLArray}() + @test interface(randn(2, 2), jl(randn(2, 2))) === DefaultArrayInterface{2}() + @test interface(randn(2, 2), jl(randn(2))) === DefaultArrayInterface() +end From cf56abd73e51144cc26edf9772c87db015c2f1f4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 11:37:13 -0400 Subject: [PATCH 2/3] Fix tests --- src/defaultarrayinterface.jl | 2 +- test/SparseArrayDOKs.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 059e0c0..3eb64d4 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -69,7 +69,7 @@ function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} return ArrayInterface{N,unspecify_type_parameters(arrayt)}() end end - return _interface(Val(N), arrayt′) + return typeof(interface(arrayt′))(Val(N)) end function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} diff --git a/test/SparseArrayDOKs.jl b/test/SparseArrayDOKs.jl index 9f80c3f..192cf64 100644 --- a/test/SparseArrayDOKs.jl +++ b/test/SparseArrayDOKs.jl @@ -267,7 +267,7 @@ end # Specify the interface the type adheres to. function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK}) - SparseArrayInterface{ndims(arrayt)}() + return SparseArrayInterface{ndims(arrayt)}() end # Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc. From c53cf32849cb38e4fb3738c52ff0db720316c173 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 12:59:06 -0400 Subject: [PATCH 3/3] Merge ArrayInterface into DefaultArrayInterface --- src/defaultarrayinterface.jl | 124 ++++++++-------------- test/test_defaultarrayinterface.jl | 158 +++++++++++++++++------------ 2 files changed, 134 insertions(+), 148 deletions(-) diff --git a/src/defaultarrayinterface.jl b/src/defaultarrayinterface.jl index 3eb64d4..d0d7a2e 100644 --- a/src/defaultarrayinterface.jl +++ b/src/defaultarrayinterface.jl @@ -1,16 +1,50 @@ -struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end +using TypeParameterAccessors: parenttype, set_eltype, unspecify_type_parameters +struct DefaultArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end + +DefaultArrayInterface{N}() where {N} = DefaultArrayInterface{N,AbstractArray}() DefaultArrayInterface() = DefaultArrayInterface{Any}() DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}() DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}() +DefaultArrayInterface{M,A}(::Val{N}) where {M,A,N} = DefaultArrayInterface{N,A}() -function DerivableInterfaces.interface(arrayt::Type{<:Array{<:Any,N}}) where {N} - return DefaultArrayInterface{N}() +# This version remembers the `ndims` of the wrapper type. +function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} + arrayt′ = parenttype(arrayt) + if arrayt′ === arrayt + return DefaultArrayInterface{N,unspecify_type_parameters(arrayt)}() + end + return typeof(interface(arrayt′))(Val(N)) end -function DerivableInterfaces.interface(arrayt::Type{<:Array}) - return DefaultArrayInterface() + +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} + return _interface(Val(N), arrayt) +end +function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray}) + return _interface(Val(Any), arrayt) +end + +function Base.similar( + ::DefaultArrayInterface{<:Any,A}, T::Type, ax::Tuple +) where {A<:AbstractArray} + if isabstracttype(A) + # If the type is abstract, default to constructing the array on CPU. + return similar(Array{T}, ax) + else + return similar(set_eltype(A, T), ax) + end end +function combine_interface_rule( + interface1::DefaultArrayInterface{N,A}, interface2::DefaultArrayInterface{N,A} +) where {N,A<:AbstractArray} + return DefaultArrayInterface{N,A}() +end +function combine_interface_rule( + interface1::DefaultArrayInterface{<:Any,A}, interface2::DefaultArrayInterface{<:Any,A} +) where {A<:AbstractArray} + return DefaultArrayInterface{Any,A}() +end function combine_interface_rule( interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N} ) where {N} @@ -19,7 +53,7 @@ end function combine_interface_rule( interface1::DefaultArrayInterface, interface2::DefaultArrayInterface ) - return DefaultArrayInterface{Any}() + return DefaultArrayInterface() end @interface ::DefaultArrayInterface function Base.getindex( @@ -45,81 +79,3 @@ end ) return Base.mapreduce(f, op, as...; kwargs...) end - -function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple) - return similar(Array{T}, ax) -end - -struct ArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end -ArrayInterface{M,A}(::Val{N}) where {M,A,N} = ArrayInterface{N,A}() - -function Base.similar( - interface::ArrayInterface{A}, elt::Type, ax::Tuple -) where {A<:AbstractArray} - return similar(set_eltype(A, elt), ax) -end - -using TypeParameterAccessors: parenttype, unspecify_type_parameters -function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N} - arrayt′ = parenttype(arrayt) - if arrayt′ === arrayt - if arrayt <: Array || isabstracttype(arrayt) - return DefaultArrayInterface{N}() - else - return ArrayInterface{N,unspecify_type_parameters(arrayt)}() - end - end - return typeof(interface(arrayt′))(Val(N)) -end - -function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N} - return _interface(Val(N), arrayt) -end -function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray}) - return _interface(Val(Any), arrayt) -end - -using TypeParameterAccessors: set_eltype -function Base.similar(::ArrayInterface{<:Any,A}, T::Type, ax::Tuple) where {A} - return similar(set_eltype(A, T), ax) -end - -function combine_interface_rule( - interface1::ArrayInterface{N,A}, interface2::ArrayInterface{N,A} -) where {N,A<:AbstractArray} - return ArrayInterface{N,A}() -end - -function combine_interface_rule( - interface1::ArrayInterface{<:Any,A}, interface2::ArrayInterface{<:Any,A} -) where {A<:AbstractArray} - return ArrayInterface{Any,A}() -end -function combine_interface_rule( - interface1::ArrayInterface{N}, interface2::ArrayInterface{N} -) where {N} - return DefaultArrayInterface{N}() -end -function combine_interface_rule(interface1::ArrayInterface, interface2::ArrayInterface) - return DefaultArrayInterface() -end -function DerivableInterfaces.combine_interface_rule( - inter1::ArrayInterface, inter2::DefaultArrayInterface -) - return DefaultArrayInterface() -end -function DerivableInterfaces.combine_interface_rule( - inter1::DefaultArrayInterface, inter2::ArrayInterface -) - return DefaultArrayInterface() -end -function DerivableInterfaces.combine_interface_rule( - inter1::ArrayInterface{N}, inter2::DefaultArrayInterface{N} -) where {N} - return DefaultArrayInterface{N}() -end -function DerivableInterfaces.combine_interface_rule( - inter1::DefaultArrayInterface{N}, inter2::ArrayInterface{N} -) where {N} - return DefaultArrayInterface{N}() -end diff --git a/test/test_defaultarrayinterface.jl b/test/test_defaultarrayinterface.jl index 81e97f5..6820882 100644 --- a/test/test_defaultarrayinterface.jl +++ b/test/test_defaultarrayinterface.jl @@ -1,4 +1,4 @@ -using DerivableInterfaces: @interface, ArrayInterface, DefaultArrayInterface, interface +using DerivableInterfaces: @interface, DefaultArrayInterface, interface using JLArrays: JLArray, jl using Test: @testset, @test using TestExtras: @constinferred @@ -34,38 +34,47 @@ end end @testset "DefaultArrayInterface" begin - @test interface(Array) === DefaultArrayInterface{Any}() - @test interface(Array{Float32}) === DefaultArrayInterface{Any}() - @test interface(Matrix) === DefaultArrayInterface{2}() - @test interface(Matrix{Float32}) === DefaultArrayInterface{2}() - @test DefaultArrayInterface() === DefaultArrayInterface{Any}() - @test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}() - @test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}() - @test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}() + @test @constinferred(interface(Array)) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Matrix)) === DefaultArrayInterface{2,Array}() + @test @constinferred(interface(Matrix{Float32})) === DefaultArrayInterface{2,Array}() + @test @constinferred(DefaultArrayInterface()) === DefaultArrayInterface{Any}() + @test @constinferred(DefaultArrayInterface(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{Any}(Val(2))) === DefaultArrayInterface{2}() + @test @constinferred(DefaultArrayInterface{3}(Val(2))) === DefaultArrayInterface{2}() # DefaultArrayInterface - @test interface(AbstractArray) === DefaultArrayInterface{Any}() - @test interface(AbstractArray{<:Any,3}) === DefaultArrayInterface{3}() - @test interface(Array{Float32}) === DefaultArrayInterface{Any}() - @test interface(Array{Float32,3}) === DefaultArrayInterface{3}() - @test interface(SubArray{<:Any,<:Any,Array}) === DefaultArrayInterface{Any}() - @test interface(SubArray{<:Any,<:Any,AbstractArray}) === DefaultArrayInterface{Any}() - @test interface(SubArray{<:Any,2,Array}) === DefaultArrayInterface{2}() - @test interface(randn(2, 2)) === DefaultArrayInterface{2}() - @test interface(view(randn(2, 2), 1:2, 1)) === DefaultArrayInterface{1}() + @test @constinferred(interface(AbstractArray)) === DefaultArrayInterface{Any}() + @test @constinferred(interface(AbstractArray{<:Any,3})) === DefaultArrayInterface{3}() + @test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(Array{Float32,3})) === DefaultArrayInterface{3,Array}() + @test @constinferred(interface(SubArray{<:Any,<:Any,Array})) === + DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(SubArray{<:Any,<:Any,AbstractArray})) === + DefaultArrayInterface{Any}() + @test @constinferred(interface(SubArray{<:Any,2,Array})) === + DefaultArrayInterface{2,Array}() + @test @constinferred(interface(randn(2, 2))) === DefaultArrayInterface{2,Array}() + @test @constinferred(interface(view(randn(2, 2), 1:2, 1))) === + DefaultArrayInterface{1,Array}() # Combining DefaultArrayInterface - @test interface(DefaultArrayInterface(), DefaultArrayInterface()) === + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface())) === DefaultArrayInterface() - @test interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) === - DefaultArrayInterface{2}() - @test interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) === - DefaultArrayInterface() - @test interface(DefaultArrayInterface(), DefaultArrayInterface{3}()) === + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface{3}())) === DefaultArrayInterface() - @test interface(randn(2, 2), randn(2, 2)) === DefaultArrayInterface{2}() - @test interface(randn(2, 2), randn(2)) === DefaultArrayInterface() - @test interface(randn(2, 2), randn(2, 2)') === DefaultArrayInterface{2}() + @test @constinferred(interface(randn(2, 2), randn(2, 2))) === + DefaultArrayInterface{2,Array}() + @test @constinferred(interface(randn(2, 2), randn(2))) === + DefaultArrayInterface{Any,Array}() + @test @constinferred(interface(randn(2, 2), randn(2, 2)')) === + DefaultArrayInterface{2,Array}() end @testset "similar(::DefaultArrayInterface, ...)" begin @@ -73,56 +82,77 @@ end @test typeof(a) === Matrix{Float32} @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{Any,Array}(), Float32, (2, 2)) + @test typeof(a) === Matrix{Float32} + @test size(a) == (2, 2) + a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2)) @test typeof(a) === Matrix{Float32} @test size(a) == (2, 2) end @testset "Broadcast.DefaultArrayStyle" begin - @test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface() - @test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}() - @test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) == - DefaultArrayInterface{1}() + @test @constinferred(interface(Broadcast.DefaultArrayStyle)) == DefaultArrayInterface() + @test @constinferred(interface(Broadcast.DefaultArrayStyle{2})) == + DefaultArrayInterface{2}() + @test @constinferred( + interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) + ) == DefaultArrayInterface{1}() end -@testset "ArrayInterface" begin +@testset "DefaultArrayInterface with custom array type" begin # ArrayInterface a = jl(randn(2, 2)) - @test interface(JLArray{Float32}) === ArrayInterface{Any,JLArray}() - @test interface(SubArray{<:Any,2,JLArray{Float32}}) === ArrayInterface{2,JLArray}() - @test interface(a) === ArrayInterface{2,JLArray}() - @test interface(a') === ArrayInterface{2,JLArray}() - @test interface(view(a, 1:2, 1)) === ArrayInterface{1,JLArray}() - a′ = similar(a, Float32, (2, 3, 3)) + @test @constinferred(interface(JLArray{Float32})) === DefaultArrayInterface{Any,JLArray}() + @test @constinferred(interface(SubArray{<:Any,2,JLArray{Float32}})) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(a)) === DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(a')) === DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(view(a, 1:2, 1))) === DefaultArrayInterface{1,JLArray}() + a′ = @constinferred similar(a, Float32, (2, 3, 3)) @test a′ isa JLArray{Float32,3} @test size(a′) == (2, 3, 3) # Combining ArrayInterface - @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{2,JLArray}()) === - ArrayInterface{2,JLArray}() - @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{3,JLArray}()) === - ArrayInterface{Any,JLArray}() - @test interface(ArrayInterface{2,JLArray}(), DefaultArrayInterface{2}()) === - DefaultArrayInterface{2}() - @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{2,Array}()) === + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2,JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,JLArray}()) + ) === DefaultArrayInterface{Any,JLArray}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,Array}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,Array}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface{2}() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,Array}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred( + interface(DefaultArrayInterface{3,Array}(), DefaultArrayInterface{2,JLArray}()) + ) === DefaultArrayInterface() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2)))) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2))')) === + DefaultArrayInterface{2,JLArray}() + @test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2, 2)))) === + DefaultArrayInterface{Any,JLArray}() + @test @constinferred(interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2)))) === + DefaultArrayInterface{1,JLArray}() + @test @constinferred(interface(randn(2, 2), jl(randn(2, 2)))) === DefaultArrayInterface{2}() - @test interface(DefaultArrayInterface{2}(), ArrayInterface{2,JLArray}()) === - DefaultArrayInterface{2}() - @test interface(ArrayInterface{2,Array}(), ArrayInterface{2,JLArray}()) === - DefaultArrayInterface{2}() - @test interface(ArrayInterface{2,JLArray}(), DefaultArrayInterface{3}()) === - DefaultArrayInterface() - @test interface(ArrayInterface{2,JLArray}(), ArrayInterface{3,Array}()) === - DefaultArrayInterface() - @test interface(DefaultArrayInterface{3}(), ArrayInterface{2,JLArray}()) === - DefaultArrayInterface() - @test interface(ArrayInterface{3,Array}(), ArrayInterface{2,JLArray}()) === - DefaultArrayInterface() - @test interface(jl(randn(2, 2)), jl(randn(2, 2))) === ArrayInterface{2,JLArray}() - @test interface(jl(randn(2, 2)), jl(randn(2, 2))') === ArrayInterface{2,JLArray}() - @test interface(jl(randn(2, 2)), jl(randn(2, 2, 2))) === ArrayInterface{Any,JLArray}() - @test interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2))) === - ArrayInterface{1,JLArray}() - @test interface(randn(2, 2), jl(randn(2, 2))) === DefaultArrayInterface{2}() - @test interface(randn(2, 2), jl(randn(2))) === DefaultArrayInterface() + @test @constinferred(interface(randn(2, 2), jl(randn(2)))) === DefaultArrayInterface() end