Skip to content

Commit 4b9fed0

Browse files
authored
Add ndims type parameter to AbstractArrayInterface (#42)
1 parent fc8f02c commit 4b9fed0

File tree

9 files changed

+101
-35
lines changed

9 files changed

+101
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.5"
4+
version = "0.5.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
7-
DerivableInterfaces = "0.4"
7+
DerivableInterfaces = "0.5"
88
Documenter = "1"
99
Literate = "2"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
44

55
[compat]
66
ArrayLayouts = "1"
7-
DerivableInterfaces = "0.4"
7+
DerivableInterfaces = "0.5"

src/abstractarrayinterface.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
# TODO: Add `ndims` type parameter.
2-
abstract type AbstractArrayInterface <: AbstractInterface end
2+
abstract type AbstractArrayInterface{N} <: AbstractInterface end
33

4+
function interface(::Type{<:Broadcast.AbstractArrayStyle{N}}) where {N}
5+
return DefaultArrayInterface{N}()
6+
end
47
function interface(::Type{<:Broadcast.AbstractArrayStyle})
58
return DefaultArrayInterface()
69
end
710

8-
function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
9-
return DefaultArrayInterface()
11+
function interface(BC::Type{<:Broadcast.Broadcasted{Nothing}})
12+
return DefaultArrayInterface{ndims(BC)}()
1013
end
1114

1215
function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
1316
return interface(Style)
1417
end
1518

16-
# TODO: Define as `Array{T}`.
17-
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")
19+
# TODO: Define as `similar(Array{T}, ax)`.
20+
function Base.similar(interface::AbstractArrayInterface, T::Type, ax::Tuple)
21+
return error("Not implemented.")
22+
end
1823

1924
using ArrayLayouts: ArrayLayouts
2025

@@ -85,7 +90,7 @@ end
8590
@interface interface::AbstractArrayInterface function Base.similar(
8691
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
8792
)
88-
return similar(arraytype(interface, T), size)
93+
return similar(interface, T, size)
8994
end
9095

9196
@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
@@ -105,8 +110,7 @@ end
105110
@interface interface::AbstractArrayInterface function Base.similar(
106111
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
107112
)
108-
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
109-
return similar(arraytype(interface, T), axes)
113+
return similar(interface, T, axes)
110114
end
111115

112116
using MapBroadcast: Mapped

src/concatenate.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ export concatenate
2828
@compat public Concatenated, cat, cat!, concatenated
2929

3030
using Base: promote_eltypeof
31-
using ..DerivableInterfaces:
32-
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
31+
using ..DerivableInterfaces: DerivableInterfaces, AbstractArrayInterface, interface, zero!
3332

3433
unval(x) = x
3534
unval(::Val{x}) where {x} = x
@@ -53,13 +52,17 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
5352
end
5453
end
5554

56-
function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
55+
function Concatenated(
56+
interface::Union{AbstractArrayInterface,Nothing}, dims::Val, args::Tuple
57+
)
5758
return _Concatenated(interface, dims, args)
5859
end
5960
function Concatenated(dims::Val, args::Tuple)
60-
return Concatenated(interface(args...), dims, args)
61+
return Concatenated(cat_interface(dims, args...), dims, args)
6162
end
62-
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
63+
function Concatenated{Interface}(
64+
dims::Val, args::Tuple
65+
) where {Interface<:Union{AbstractArrayInterface,Nothing}}
6366
return Concatenated(Interface(), dims, args)
6467
end
6568

@@ -81,8 +84,11 @@ end
8184
# ------------------------------------
8285
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
8386
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
84-
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
85-
return similar(arraytype(interface(concat), T), ax)
87+
function Base.similar(concat::Concatenated, ax::Tuple)
88+
return similar(interface(concat), eltype(concat), ax)
89+
end
90+
function Base.similar(concat::Concatenated, ::Type{T}, ax::Tuple) where {T}
91+
return similar(interface(concat), T, ax)
8692
end
8793

8894
function cat_axis(
@@ -108,10 +114,15 @@ function cat_axes(dims::Val, as::AbstractArray...)
108114
return cat_axes(unval(dims), as...)
109115
end
110116

117+
function cat_interface(dims, as::AbstractArray...)
118+
N = cat_ndims(dims, as...)
119+
return typeof(interface(as...))(Val(N))
120+
end
121+
111122
Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
112123
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
113124
Base.size(concat::Concatenated) = length.(axes(concat))
114-
Base.ndims(concat::Concatenated) = length(axes(concat))
125+
Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...)
115126

116127
# Main logic
117128
# ----------

src/defaultarrayinterface.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
1-
# TODO: Add `ndims` type parameter.
2-
struct DefaultArrayInterface <: AbstractArrayInterface end
1+
struct DefaultArrayInterface{N} <: AbstractArrayInterface{N} end
2+
3+
DefaultArrayInterface() = DefaultArrayInterface{Any}()
4+
DefaultArrayInterface(::Val{N}) where {N} = DefaultArrayInterface{N}()
5+
DefaultArrayInterface{M}(::Val{N}) where {M,N} = DefaultArrayInterface{N}()
36

47
using TypeParameterAccessors: parenttype
58
function interface(a::Type{<:AbstractArray})
69
parenttype(a) === a && return DefaultArrayInterface()
710
return interface(parenttype(a))
811
end
12+
function interface(a::Type{<:AbstractArray{<:Any,N}}) where {N}
13+
parenttype(a) === a && return DefaultArrayInterface{N}()
14+
return interface(parenttype(a))
15+
end
16+
17+
function combine_interface_rule(
18+
interface1::DefaultArrayInterface{N}, interface2::DefaultArrayInterface{N}
19+
) where {N}
20+
return DefaultArrayInterface{N}()
21+
end
22+
function combine_interface_rule(
23+
interface1::DefaultArrayInterface, interface2::DefaultArrayInterface
24+
)
25+
return DefaultArrayInterface{Any}()
26+
end
927

1028
@interface ::DefaultArrayInterface function Base.getindex(
1129
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
@@ -31,6 +49,6 @@ end
3149
return Base.mapreduce(f, op, as...; kwargs...)
3250
end
3351

34-
function arraytype(::DefaultArrayInterface, T::Type)
35-
return Array{T}
52+
function Base.similar(::DefaultArrayInterface, T::Type, ax::Tuple)
53+
return similar(Array{T}, ax)
3654
end

test/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
77
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
910

1011
[compat]
1112
Aqua = "0.8"
1213
ArrayLayouts = "1"
13-
DerivableInterfaces = "0.4"
14+
DerivableInterfaces = "0.5"
15+
LinearAlgebra = "1"
1416
SafeTestsets = "0.1"
1517
Suppressor = "0.2"
16-
LinearAlgebra = "1"
1718
Test = "1"
19+
TestExtras = "0.3"

test/SparseArrayDOKs.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ using DerivableInterfaces:
4040
using LinearAlgebra: LinearAlgebra
4141

4242
# Define an interface.
43-
struct SparseArrayInterface <: AbstractArrayInterface end
43+
struct SparseArrayInterface{N} <: AbstractArrayInterface{N} end
44+
SparseArrayInterface(::Val{N}) where {N} = SparseArrayInterface{N}()
45+
SparseArrayInterface{M}(::Val{N}) where {M,N} = SparseArrayInterface{N}()
4446

4547
# Define interface functions.
4648
@interface ::SparseArrayInterface function Base.getindex(
@@ -66,11 +68,15 @@ end
6668
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
6769
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()
6870

69-
DerivableInterfaces.interface(::Type{<:SparseArrayStyle}) = SparseArrayInterface()
71+
function DerivableInterfaces.interface(::Type{<:SparseArrayStyle{N}}) where {N}
72+
return SparseArrayInterface{N}()
73+
end
7074

7175
@derive SparseArrayStyle AbstractArrayStyleOps
7276

73-
DerivableInterfaces.arraytype(::SparseArrayInterface, T::Type) = SparseArrayDOK{T}
77+
function Base.similar(::SparseArrayInterface, T::Type, ax::Tuple)
78+
return similar(SparseArrayDOK{T}, ax)
79+
end
7480

7581
# Interface functions.
7682
@interface ::SparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
@@ -260,7 +266,9 @@ function DerivableInterfaces.zero!(a::SparseArrayDOK)
260266
end
261267

262268
# Specify the interface the type adheres to.
263-
DerivableInterfaces.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()
269+
function DerivableInterfaces.interface(arrayt::Type{<:SparseArrayDOK})
270+
SparseArrayInterface{ndims(arrayt)}()
271+
end
264272

265273
# Define aliases like `SparseMatrixDOK`, `AnySparseArrayDOK`, etc.
266274
@array_aliases SparseArrayDOK

test/test_defaultarrayinterface.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using Test: @inferred, @testset, @test
21
using DerivableInterfaces: @interface, DefaultArrayInterface, interface
2+
using Test: @testset, @test
3+
using TestExtras: @constinferred
34

45
# function wrappers to test type-stability
56
_getindex(A, i...) = @interface DefaultArrayInterface() A[i...]
@@ -11,28 +12,50 @@ end
1112

1213
@testset "indexing" begin
1314
for (A, i) in ((zeros(2), 2), (zeros(2, 2), (2, 1)), (zeros(1, 2, 3), (1, 2, 3)))
14-
a = @inferred _getindex(A, i...)
15+
a = @constinferred _getindex(A, i...)
1516
@test a == A[i...]
1617
v = 1.1
17-
A′ = @inferred _setindex!(A, v, i...)
18+
A′ = @constinferred _setindex!(A, v, i...)
1819
@test A′ == (A[i...] = v)
1920
end
2021
end
2122

2223
@testset "map!" begin
2324
A = zeros(3)
24-
a = @inferred _map!(Returns(2), copy(A), A)
25+
a = @constinferred _map!(Returns(2), copy(A), A)
2526
@test a == map!(Returns(2), copy(A), A)
2627
end
2728

2829
@testset "mapreduce" begin
2930
A = zeros(3)
30-
a = @inferred _mapreduce(Returns(2), +, A)
31+
a = @constinferred _mapreduce(Returns(2), +, A)
3132
@test a == mapreduce(Returns(2), +, A)
3233
end
3334

35+
@testset "DefaultArrayInterface" begin
36+
@test interface(Array) === DefaultArrayInterface{Any}()
37+
@test interface(Array{Float32}) === DefaultArrayInterface{Any}()
38+
@test interface(Matrix) === DefaultArrayInterface{2}()
39+
@test interface(Matrix{Float32}) === DefaultArrayInterface{2}()
40+
@test DefaultArrayInterface() === DefaultArrayInterface{Any}()
41+
@test DefaultArrayInterface(Val(2)) === DefaultArrayInterface{2}()
42+
@test DefaultArrayInterface{Any}(Val(2)) === DefaultArrayInterface{2}()
43+
@test DefaultArrayInterface{3}(Val(2)) === DefaultArrayInterface{2}()
44+
end
45+
46+
@testset "similar(::DefaultArrayInterface, ...)" begin
47+
a = @constinferred similar(DefaultArrayInterface(), Float32, (2, 2))
48+
@test typeof(a) === Matrix{Float32}
49+
@test size(a) == (2, 2)
50+
51+
a = @constinferred similar(DefaultArrayInterface{1}(), Float32, (2, 2))
52+
@test typeof(a) === Matrix{Float32}
53+
@test size(a) == (2, 2)
54+
end
55+
3456
@testset "Broadcast.DefaultArrayStyle" begin
3557
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
58+
@test interface(Broadcast.DefaultArrayStyle{2}) == DefaultArrayInterface{2}()
3659
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
37-
DefaultArrayInterface()
60+
DefaultArrayInterface{1}()
3861
end

0 commit comments

Comments
 (0)