Skip to content

Commit 3763330

Browse files
authored
Add definitions for AbstractArrayInterface (#7)
1 parent 54bfc9d commit 3763330

File tree

11 files changed

+403
-17
lines changed

11 files changed

+403
-17
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
name = "Derive"
22
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.3.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
9+
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
810
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
912
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1013
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1114

1215
[compat]
1316
Adapt = "4.1.1"
1417
Aqua = "0.8.9"
18+
ArrayLayouts = "1.11.0"
19+
BroadcastMapConversion = "0.1.0"
1520
ExproniconLite = "0.10.13"
21+
LinearAlgebra = "1.10"
1622
MLStyle = "0.4.17"
1723
SafeTestsets = "0.1"
1824
Suppressor = "0.2"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ struct SparseArrayInterface end
6060
Define interface functions.
6161

6262
````julia
63-
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
63+
@interface ::SparseArrayInterface function Base.getindex(a, I::Int...)
6464
checkbounds(a, I...)
6565
!isstored(a, I...) && return getunstoredindex(a, I...)
6666
return getstoredindex(a, I...)
6767
end
68-
@interface SparseArrayInterface function Base.setindex!(a, value, I::Int...)
68+
@interface ::SparseArrayInterface function Base.setindex!(a, value, I::Int...)
6969
checkbounds(a, I...)
7070
iszero(value) && return a
7171
if !isstored(a, I...)

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[deps]
2+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
23
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"

examples/README.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ using Test: @test
5959
struct SparseArrayInterface end
6060

6161
# Define interface functions.
62-
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
62+
@interface ::SparseArrayInterface function Base.getindex(a, I::Int...)
6363
checkbounds(a, I...)
6464
!isstored(a, I...) && return getunstoredindex(a, I...)
6565
return getstoredindex(a, I...)
6666
end
67-
@interface SparseArrayInterface function Base.setindex!(a, value, I::Int...)
67+
@interface ::SparseArrayInterface function Base.setindex!(a, value, I::Int...)
6868
checkbounds(a, I...)
6969
iszero(value) && return a
7070
if !isstored(a, I...)

src/abstractarrayinterface.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,123 @@
11
# TODO: Add `ndims` type parameter.
22
abstract type AbstractArrayInterface <: AbstractInterface end
3+
4+
function interface(::Type{<:Broadcast.AbstractArrayStyle})
5+
return error("Not defined.")
6+
end
7+
8+
function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
9+
return interface(Style)
10+
end
11+
12+
# TODO: Define as `Array{T}`.
13+
arraytype(::AbstractArrayInterface, T::Type) = error("Not implemented.")
14+
15+
using ArrayLayouts: ArrayLayouts
16+
17+
@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I...)
18+
return ArrayLayouts.layout_getindex(a, I...)
19+
end
20+
21+
@interface ::AbstractArrayInterface function Base.getindex(a::AbstractArray, I::Int...)
22+
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
23+
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
24+
# TODO: Use `MethodError`?
25+
return error("Not implemented.")
26+
end
27+
28+
@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
29+
return Broadcast.DefaultArrayStyle{ndims(type)}()
30+
end
31+
32+
@interface interface::AbstractArrayInterface function Base.similar(
33+
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
34+
)
35+
# TODO: Maybe define as `Array{T}(undef, size...)` or
36+
# `invoke(Base.similar, Tuple{AbstractArray,Type,Vararg{Int}}, a, T, size)`.
37+
# TODO: Use `MethodError`?
38+
return similar(arraytype(interface, T), size)
39+
end
40+
41+
@interface ::AbstractArrayInterface function Base.copy(a::AbstractArray)
42+
a_dest = similar(a)
43+
return a_dest .= a
44+
end
45+
46+
# TODO: Make this more general, handle mixtures of integers and ranges (`Union{Integer,Base.OneTo}`).
47+
@interface interface::AbstractArrayInterface function Base.similar(
48+
a::AbstractArray, T::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
49+
)
50+
# TODO: Use `Base.to_shape(axes)` or
51+
# `Base.invoke(similar, Tuple{AbstractArray,Type,Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}}}, a, T, axes)`.
52+
return @interface interface similar(a, T, Base.to_shape(axes))
53+
end
54+
55+
@interface interface::AbstractArrayInterface function Base.similar(
56+
bc::Broadcast.Broadcasted, T::Type, axes::Tuple
57+
)
58+
# `arraytype(::AbstractInterface)` determines the default array type associated with the interface.
59+
return similar(arraytype(interface, T), axes)
60+
end
61+
62+
using BroadcastMapConversion: map_function, map_args
63+
# TODO: Turn this into an `@interface AbstractArrayInterface` function?
64+
# TODO: Look into `SparseArrays.capturescalars`:
65+
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
66+
@interface interface::AbstractArrayInterface function Base.copyto!(
67+
dest::AbstractArray, bc::Broadcast.Broadcasted
68+
)
69+
@interface interface map!(map_function(bc), dest, map_args(bc)...)
70+
return dest
71+
end
72+
73+
# This is defined in this way so we can rely on the Broadcast logic
74+
# for determining the destination of the operation (element type, shape, etc.).
75+
@interface ::AbstractArrayInterface function Base.map(f, as::AbstractArray...)
76+
# TODO: Should this be `@interface interface ...`? That doesn't support
77+
# broadcasting yet.
78+
# Broadcasting is used here to determine the destination array but that
79+
# could be done manually here.
80+
return f.(as...)
81+
end
82+
83+
@interface ::AbstractArrayInterface function Base.map!(
84+
f, dest::AbstractArray, as::AbstractArray...
85+
)
86+
# TODO: Maybe define as
87+
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
88+
# TODO: Use `MethodError`?
89+
return error("Not implemented.")
90+
end
91+
92+
@interface ::AbstractArrayInterface function Base.permutedims!(
93+
a_dest::AbstractArray, a_src::AbstractArray, perm
94+
)
95+
# TODO: Should this be `@interface interface ...`?
96+
a_dest .= PermutedDimsArray(a_src, perm)
97+
return a_dest
98+
end
99+
100+
using LinearAlgebra: LinearAlgebra
101+
# This then requires overloading:
102+
# function ArrayLayouts.materialize!(
103+
# m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
104+
# )
105+
# # Matmul implementation.
106+
# end
107+
@interface ::AbstractArrayInterface function LinearAlgebra.mul!(
108+
a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number
109+
)
110+
return ArrayLayouts.mul!(a_dest, a1, a2, α, β)
111+
end
112+
113+
@interface ::AbstractArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
114+
# TODO: Define as `UnknownLayout()`?
115+
# TODO: Use `MethodError`?
116+
return error("Not implemented.")
117+
end
118+
119+
## TODO: Define `const AbstractMatrixInterface = AbstractArrayInterface{2}`,
120+
## requires adding `ndims` type parameter to `AbstractArrayInterface`.
121+
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
122+
## return ArrayLayouts.mul(a1, a2)
123+
## end

src/derive_macro.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,10 @@ function replace_typevars(types::Expr, func::Expr)
115115
typevar, type = @match type_expr begin
116116
:($x = $y) => (x, y)
117117
end
118+
# TODO: Handle type parameters in other positions besides the first one.
118119
new_args = map(args) do arg
119120
return @match arg begin
120-
:(::Type{<:$T}) => T == typevar ? :(::Type{<:$type}) : :(::Type{<:$T})
121+
:(::$Type{<:$T}) => T == typevar ? :(::$Type{<:$type}) : :(::$Type{<:$T})
121122
:(::$T...) => T == typevar ? :(::$type...) : :(::$T...)
122123
:(::$T) => T == typevar ? :(::$type) : :(::$T)
123124
end

src/interface_macro.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,14 @@ end
7878
#=
7979
Rewrite:
8080
```julia
81-
@interface SparseArrayInterface function Base.getindex(a, I::Int...)
81+
@interface interface::SparseArrayInterface function Base.getindex(a, I::Int...)
8282
!isstored(a, I...) && return getunstoredindex(a, I...)
8383
return getstoredindex(a, I...)
8484
end
8585
```
8686
to:
8787
```julia
88-
function Derive.call(::SparseArrayInterface, Base.getindex, a, I::Int...)
88+
function Derive.call(interface::SparseArrayInterface, Base.getindex, a, I::Int...)
8989
!isstored(a, I...) && return getunstoredindex(a, I...)
9090
return getstoredindex(a, I...)
9191
end
@@ -98,7 +98,7 @@ function interface_definition(interface::Union{Symbol,Expr}, func::Expr)
9898
# We use `Core.Typeof` here because `name` can either be a function or type,
9999
# and `typeof(T::Type)` outputs things like `DataType`, `UnionAll`, etc.
100100
# while `Core.Typeof(T::Type)` returns `Type{T}`.
101-
new_args = [:(::$interface); :(::Core.Typeof($name)); args]
101+
new_args = [:($interface); :(::Core.Typeof($name)); args]
102102
return globalref_derive(
103103
codegen_ast(
104104
JLFunction(; name=new_name, args=new_args, kwargs, rettype, whereparams, body)

src/traits.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# using ArrayLayouts: ArrayLayouts
2-
# using LinearAlgebra: LinearAlgebra
1+
using ArrayLayouts: ArrayLayouts
2+
using LinearAlgebra: LinearAlgebra
33

4+
# TODO: Define an `AbstractMatrixOps` trait, which is where
5+
# matrix multiplication should be defined (both `mul!` and `*`).
46
#=
57
```julia
68
@derive SparseArrayDOK AbstractArrayOps
@@ -9,13 +11,24 @@
911
=#
1012
function derive(::Val{:AbstractArrayOps}, type)
1113
return quote
14+
Base.getindex(::$type, ::Any...)
1215
Base.getindex(::$type, ::Int...)
1316
Base.setindex!(::$type, ::Any, ::Int...)
1417
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
18+
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
19+
Base.copy(::$type)
1520
Base.map(::Any, ::$type...)
16-
Base.map!(::Any, ::Any, ::$type...)
21+
Base.map!(::Any, ::AbstractArray, ::$type...)
22+
Base.permutedims!(::Any, ::$type, ::Any)
1723
Broadcast.BroadcastStyle(::Type{<:$type})
18-
# ArrayLayouts.MemoryLayout(::Type{<:$type})
19-
# LinearAlgebra.mul!(::Any, ::$type, ::$type, ::Number, ::Number)
24+
ArrayLayouts.MemoryLayout(::Type{<:$type})
25+
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
26+
end
27+
end
28+
29+
function derive(::Val{:AbstractArrayStyleOps}, type)
30+
return quote
31+
Base.similar(::Broadcast.Broadcasted{<:$type}, ::Type, ::Tuple)
32+
Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:$type})
2033
end
2134
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
34
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
5+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
46
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
57
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
68
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)