Skip to content

Commit 41f2ca2

Browse files
authored
Add array type parameter to DefaultArrayInterface (#44)
1 parent bb0a449 commit 41f2ca2

File tree

6 files changed

+155
-27
lines changed

6 files changed

+155
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ interface(x1, x_rest...) = combine_interfaces(interface(x1), interface.(x_rest).
66

77
abstract type AbstractInterface end
88

9+
interface(x::AbstractInterface) = x
10+
911
(interface::AbstractInterface)(f) = InterfaceFunction(interface, f)
1012

1113
# Adapted from `Base.Broadcast.combine_styles`.

src/defaultarrayinterface.jl

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,50 @@
1-
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end
1+
using TypeParameterAccessors: parenttype, set_eltype, unspecify_type_parameters
22

3+
struct DefaultArrayInterface{N,A<:AbstractArray} <: AbstractArrayInterface{N} end
4+
5+
DefaultArrayInterface{N}() where {N} = DefaultArrayInterface{N,AbstractArray}()
36
DefaultArrayInterface() = DefaultArrayInterface{Any}()
47
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
58
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()
9+
DefaultArrayInterface{M,A}(::Val{N}) where {M,A,N} = DefaultArrayInterface{N,A}()
10+
11+
# This version remembers the `ndims` of the wrapper type.
12+
function _interface(::Val{N}, arrayt::Type{<:AbstractArray}) where {N}
13+
arrayt′ = parenttype(arrayt)
14+
if arrayt′ === arrayt
15+
return DefaultArrayInterface{N,unspecify_type_parameters(arrayt)}()
16+
end
17+
return typeof(interface(arrayt′))(Val(N))
18+
end
619

7-
using TypeParameterAccessors: parenttype
8-
function interface(a::Type{<:AbstractArray})
9-
parenttype(a) === a && return DefaultArrayInterface()
10-
return interface(parenttype(a))
20+
function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray{<:Any,N}}) where {N}
21+
return _interface(Val(N), arrayt)
1122
end
12-
function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N}
13-
parenttype(a) === a && return DefaultArrayInterface{N}()
14-
return interface(parenttype(a))
23+
function DerivableInterfaces.interface(arrayt::Type{<:AbstractArray})
24+
return _interface(Val(Any), arrayt)
25+
end
26+
27+
function Base.similar(
28+
::DefaultArrayInterface{<:Any,A}, T::Type, ax::Tuple
29+
) where {A<:AbstractArray}
30+
if isabstracttype(A)
31+
# If the type is abstract, default to constructing the array on CPU.
32+
return similar(Array{T}, ax)
33+
else
34+
return similar(set_eltype(A, T), ax)
35+
end
1536
end
1637

38+
function combine_interface_rule(
39+
interface1::DefaultArrayInterface{N,A}, interface2::DefaultArrayInterface{N,A}
40+
) where {N,A<:AbstractArray}
41+
return DefaultArrayInterface{N,A}()
42+
end
43+
function combine_interface_rule(
44+
interface1::DefaultArrayInterface{<:Any,A}, interface2::DefaultArrayInterface{<:Any,A}
45+
) where {A<:AbstractArray}
46+
return DefaultArrayInterface{Any,A}()
47+
end
1748
function combine_interface_rule(
1849
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
1950
) where {N}
@@ -22,7 +53,7 @@ end
2253
function combine_interface_rule(
2354
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
2455
)
25-
return DefaultArrayInterface{Any}()
56+
return DefaultArrayInterface()
2657
end
2758

2859
@interface ::DefaultArrayInterface function Base.getindex(
@@ -48,7 +79,3 @@ end
4879
)
4980
return Base.mapreduce(f, op, as...; kwargs...)
5081
end
51-
52-
function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple)
53-
return similar(Array{T}, ax)
54-
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
44
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
5+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
78
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
@@ -12,6 +13,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1213
Aqua = "0.8"
1314
ArrayLayouts = "1"
1415
DerivableInterfaces = "0.5"
16+
JLArrays = "0.2"
1517
LinearAlgebra = "1"
1618
SafeTestsets = "0.1"
1719
Suppressor = "0.2"

test/SparseArrayDOKs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ end
267267

268268
# Specify the interface the type adheres to.
269269
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
270-
SparseArrayInterface{ndims(arrayt)}()
270+
return SparseArrayInterface{ndims(arrayt)}()
271271
end
272272

273273
# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.

test/test_defaultarrayinterface.jl

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
2+
using JLArrays: JLArray, jl
23
using Test: @testset, @test
34
using TestExtras: @constinferred
45

@@ -33,29 +34,125 @@ end
3334
end
3435

3536
@testset "DefaultArrayInterface" begin
36-
@test interface(Array) === DefaultArrayInterface{Any}()
37-
@test interface(Array{Float32}) === DefaultArrayInterface{Any}()
38-
@test interface(Matrix) === DefaultArrayInterface{2}()
39-
@test interface(Matrix{Float32}) === DefaultArrayInterface{2}()
40-
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
41-
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
42-
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
43-
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
37+
@test @constinferred(interface(Array)) === DefaultArrayInterface{Any,Array}()
38+
@test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}()
39+
@test @constinferred(interface(Matrix)) === DefaultArrayInterface{2,Array}()
40+
@test @constinferred(interface(Matrix{Float32})) === DefaultArrayInterface{2,Array}()
41+
@test @constinferred(DefaultArrayInterface()) === DefaultArrayInterface{Any}()
42+
@test @constinferred(DefaultArrayInterface(Val(2))) === DefaultArrayInterface{2}()
43+
@test @constinferred(DefaultArrayInterface{Any}(Val(2))) === DefaultArrayInterface{2}()
44+
@test @constinferred(DefaultArrayInterface{3}(Val(2))) === DefaultArrayInterface{2}()
45+
46+
# DefaultArrayInterface
47+
@test @constinferred(interface(AbstractArray)) === DefaultArrayInterface{Any}()
48+
@test @constinferred(interface(AbstractArray{<:Any,3})) === DefaultArrayInterface{3}()
49+
@test @constinferred(interface(Array{Float32})) === DefaultArrayInterface{Any,Array}()
50+
@test @constinferred(interface(Array{Float32,3})) === DefaultArrayInterface{3,Array}()
51+
@test @constinferred(interface(SubArray{<:Any,<:Any,Array})) ===
52+
DefaultArrayInterface{Any,Array}()
53+
@test @constinferred(interface(SubArray{<:Any,<:Any,AbstractArray})) ===
54+
DefaultArrayInterface{Any}()
55+
@test @constinferred(interface(SubArray{<:Any,2,Array})) ===
56+
DefaultArrayInterface{2,Array}()
57+
@test @constinferred(interface(randn(2, 2))) === DefaultArrayInterface{2,Array}()
58+
@test @constinferred(interface(view(randn(2, 2), 1:2, 1))) ===
59+
DefaultArrayInterface{1,Array}()
60+
61+
# Combining DefaultArrayInterface
62+
@test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface())) ===
63+
DefaultArrayInterface()
64+
@test @constinferred(
65+
interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2}())
66+
) === DefaultArrayInterface{2}()
67+
@test @constinferred(
68+
interface(DefaultArrayInterface{2}(), DefaultArrayInterface{3}())
69+
) === DefaultArrayInterface()
70+
@test @constinferred(interface(DefaultArrayInterface(), DefaultArrayInterface{3}())) ===
71+
DefaultArrayInterface()
72+
@test @constinferred(interface(randn(2, 2), randn(2, 2))) ===
73+
DefaultArrayInterface{2,Array}()
74+
@test @constinferred(interface(randn(2, 2), randn(2))) ===
75+
DefaultArrayInterface{Any,Array}()
76+
@test @constinferred(interface(randn(2, 2), randn(2, 2)')) ===
77+
DefaultArrayInterface{2,Array}()
4478
end
4579

4680
@testset "similar(::DefaultArrayInterface, ...)" begin
4781
a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2))
4882
@test typeof(a) === Matrix{Float32}
4983
@test size(a) == (2, 2)
5084

85+
a = @constinferred similar(DefaultArrayInterface{Any,Array}(), Float32, (2, 2))
86+
@test typeof(a) === Matrix{Float32}
87+
@test size(a) == (2, 2)
88+
5189
a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2))
5290
@test typeof(a) === Matrix{Float32}
5391
@test size(a) == (2, 2)
5492
end
5593

5694
@testset "Broadcast.DefaultArrayStyle" begin
57-
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
58-
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
59-
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
60-
DefaultArrayInterface{1}()
95+
@test @constinferred(interface(Broadcast.DefaultArrayStyle)) == DefaultArrayInterface()
96+
@test @constinferred(interface(Broadcast.DefaultArrayStyle{2})) ==
97+
DefaultArrayInterface{2}()
98+
@test @constinferred(
99+
interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2))))
100+
) == DefaultArrayInterface{1}()
101+
end
102+
103+
@testset "DefaultArrayInterface with custom array type" begin
104+
# ArrayInterface
105+
a = jl(randn(2, 2))
106+
@test @constinferred(interface(JLArray{Float32})) === DefaultArrayInterface{Any,JLArray}()
107+
@test @constinferred(interface(SubArray{<:Any,2,JLArray{Float32}})) ===
108+
DefaultArrayInterface{2,JLArray}()
109+
@test @constinferred(interface(a)) === DefaultArrayInterface{2,JLArray}()
110+
@test @constinferred(interface(a')) === DefaultArrayInterface{2,JLArray}()
111+
@test @constinferred(interface(view(a, 1:2, 1))) === DefaultArrayInterface{1,JLArray}()
112+
a′ = @constinferred similar(a, Float32, (2, 3, 3))
113+
@test a′ isa JLArray{Float32,3}
114+
@test size(a′) == (2, 3, 3)
115+
116+
# Combining ArrayInterface
117+
@test @constinferred(
118+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,JLArray}())
119+
) === DefaultArrayInterface{2,JLArray}()
120+
@test @constinferred(
121+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,JLArray}())
122+
) === DefaultArrayInterface{Any,JLArray}()
123+
@test @constinferred(
124+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2}())
125+
) === DefaultArrayInterface{2}()
126+
@test @constinferred(
127+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{2,Array}())
128+
) === DefaultArrayInterface{2}()
129+
@test @constinferred(
130+
interface(DefaultArrayInterface{2}(), DefaultArrayInterface{2,JLArray}())
131+
) === DefaultArrayInterface{2}()
132+
@test @constinferred(
133+
interface(DefaultArrayInterface{2,Array}(), DefaultArrayInterface{2,JLArray}())
134+
) === DefaultArrayInterface{2}()
135+
@test @constinferred(
136+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3}())
137+
) === DefaultArrayInterface()
138+
@test @constinferred(
139+
interface(DefaultArrayInterface{2,JLArray}(), DefaultArrayInterface{3,Array}())
140+
) === DefaultArrayInterface()
141+
@test @constinferred(
142+
interface(DefaultArrayInterface{3}(), DefaultArrayInterface{2,JLArray}())
143+
) === DefaultArrayInterface()
144+
@test @constinferred(
145+
interface(DefaultArrayInterface{3,Array}(), DefaultArrayInterface{2,JLArray}())
146+
) === DefaultArrayInterface()
147+
@test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2)))) ===
148+
DefaultArrayInterface{2,JLArray}()
149+
@test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2))')) ===
150+
DefaultArrayInterface{2,JLArray}()
151+
@test @constinferred(interface(jl(randn(2, 2)), jl(randn(2, 2, 2)))) ===
152+
DefaultArrayInterface{Any,JLArray}()
153+
@test @constinferred(interface(view(jl(randn(2, 2))', 1:2, 1), jl(randn(2)))) ===
154+
DefaultArrayInterface{1,JLArray}()
155+
@test @constinferred(interface(randn(2, 2), jl(randn(2, 2)))) ===
156+
DefaultArrayInterface{2}()
157+
@test @constinferred(interface(randn(2, 2), jl(randn(2)))) === DefaultArrayInterface()
61158
end

0 commit comments

Comments
 (0)