Skip to content

Commit 5c52fbf

Browse files
authored
Fix defining single line interface functions, add support for cat and slicing (#18)
* Fix defining single line interface functions, add support for `cat` and slicing * Bump to v0.3.5
1 parent f30e57d commit 5c52fbf

File tree

6 files changed

+182
-4
lines changed

6 files changed

+182
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Derive"
22
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractarrayinterface.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ using ArrayLayouts: ArrayLayouts
1919
return ArrayLayouts.layout_getindex(a, I...)
2020
end
2121

22+
@interface interface::AbstractArrayInterface function Base.setindex!(
23+
a::AbstractArray, value, I...
24+
)
25+
# TODO: Change to this once broadcasting in `@interface` is supported:
26+
# @interface interface a[I...] .= value
27+
@interface interface map!(identity, @view(a[I...]), value)
28+
return a
29+
end
30+
2231
# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
2332
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
2433
# TODO: Use `MethodError`?
@@ -28,6 +37,27 @@ end
2837
return error("Not implemented.")
2938
end
3039

40+
# TODO: Make this more general, use `Base.to_index`.
41+
@interface interface::AbstractArrayInterface function Base.getindex(
42+
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
43+
) where {N}
44+
return @interface interface getindex(a, Tuple(I)...)
45+
end
46+
47+
# TODO: Use `MethodError`?
48+
@interface ::AbstractArrayInterface function Base.setindex!(
49+
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
50+
) where {N}
51+
return error("Not implemented.")
52+
end
53+
54+
# TODO: Make this more general, use `Base.to_index`.
55+
@interface interface::AbstractArrayInterface function Base.setindex!(
56+
a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N}
57+
) where {N}
58+
return @interface interface setindex!(a, value, Tuple(I)...)
59+
end
60+
3161
@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
3262
return Broadcast.DefaultArrayStyle{ndims(type)}()
3363
end
@@ -203,3 +233,94 @@ end
203233
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
204234
## return ArrayLayouts.mul(a1, a2)
205235
## end
236+
237+
# Concatenation
238+
239+
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
240+
function axis_cat(
241+
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
242+
)
243+
return axis_cat(axis_cat(a1, a2), a_rest...)
244+
end
245+
246+
unval(x) = x
247+
unval(::Val{x}) where {x} = x
248+
249+
function cat_axes(as::AbstractArray...; dims)
250+
return ntuple(length(first(axes.(as)))) do dim
251+
return if dim in unval(dims)
252+
axis_cat(map(axes -> axes[dim], axes.(as))...)
253+
else
254+
axes(first(as))[dim]
255+
end
256+
end
257+
end
258+
259+
function cat! end
260+
261+
# Represents concatenating `args` over `dims`.
262+
struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims}
263+
args::Args
264+
end
265+
to_cat_dims(dim::Integer) = Int(dim)
266+
to_cat_dims(dim::Int) = (dim,)
267+
to_cat_dims(dims::Val) = to_cat_dims(unval(dims))
268+
to_cat_dims(dims::Tuple) = dims
269+
Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args)
270+
cat_dims(::Cat{<:Any,dims}) where {dims} = dims
271+
272+
function Base.axes(a::Cat)
273+
return cat_axes(a.args...; dims=cat_dims(a))
274+
end
275+
Base.eltype(a::Cat) = promote_type(eltype.(a.args)...)
276+
function Base.similar(a::Cat)
277+
ax = axes(a)
278+
elt = eltype(a)
279+
# TODO: This drops GPU information, maybe use MemoryLayout?
280+
return similar(arraytype(interface(a.args...), elt), ax)
281+
end
282+
283+
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
284+
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
285+
# This is very similar to the `Base.cat` implementation but handles zero values better.
286+
function cat_offset!(
287+
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
288+
)
289+
inds = ntuple(ndims(a_dest)) do dim
290+
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
291+
end
292+
a_dest[inds...] = a1
293+
new_offsets = ntuple(ndims(a_dest)) do dim
294+
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
295+
end
296+
cat_offset!(a_dest, new_offsets, a_rest...; dims)
297+
return a_dest
298+
end
299+
function cat_offset!(a_dest::AbstractArray, offsets; dims)
300+
return a_dest
301+
end
302+
303+
@interface ::AbstractArrayInterface function cat!(
304+
a_dest::AbstractArray, as::AbstractArray...; dims
305+
)
306+
offsets = ntuple(zero, ndims(a_dest))
307+
# TODO: Fill `a_dest` with zeros if needed using `zero!`.
308+
cat_offset!(a_dest, offsets, as...; dims)
309+
return a_dest
310+
end
311+
312+
@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims)
313+
a_dest = similar(Cat(as...; dims))
314+
@interface interface cat!(a_dest, as...; dims)
315+
return a_dest
316+
end
317+
318+
# TODO: Use `@derive` instead:
319+
# ```julia
320+
# @derive (T=AbstractArray,) begin
321+
# cat!(a_dest::AbstractArray, as::T...; dims)
322+
# end
323+
# ```
324+
function cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
325+
return @interface interface(as...) cat!(a_dest, as...; dims)
326+
end

src/abstractinterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
interface(x) = interface(typeof(x))
33
# TODO: Define as `DefaultInterface()`.
44
interface(::Type) = error("Interface unknown.")
5+
interface(x1, x_rest...) = combine_interfaces(x1, x_rest...)
56

67
# Adapted from `Base.Broadcast.combine_styles`.
78
# Get the combined interfaces of the input objects.

src/interface_macro.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@ macro interface(expr...)
55
return esc(interface_expr(expr...))
66
end
77

8+
# TODO: Use `MLStyle.@match`/`Moshi.@match`.
9+
# f(args...)
10+
iscallexpr(expr) = Meta.isexpr(expr, :call)
11+
# a[I...]
12+
isrefexpr(expr) = Meta.isexpr(expr, :ref)
13+
# a[I...] = value
14+
issetrefexpr(expr) = Meta.isexpr(expr, :(=)) && isrefexpr(expr.args[1])
15+
816
function interface_expr(interface::Union{Symbol,Expr}, func::Expr)
17+
# TODO: Use `MLStyle.@match`/`Moshi.@match`.
918
# f(args...)
10-
Meta.isexpr(func, :call) && return interface_call(interface, func)
19+
iscallexpr(func) && return interface_call(interface, func)
1120
# a[I...]
12-
Meta.isexpr(func, :ref) && return interface_ref(interface, func)
21+
isrefexpr(func) && return interface_ref(interface, func)
1322
# a[I...] = value
14-
Meta.isexpr(func, :(=)) && return interface_setref(interface, func)
23+
issetrefexpr(func) && return interface_setref(interface, func)
1524
# Assume it is a function definition.
1625
return interface_definition(interface, func)
1726
end

src/traits.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
using ArrayLayouts: ArrayLayouts
22
using LinearAlgebra: LinearAlgebra
33

4+
# TODO: Create a macro:
5+
#=
6+
```
7+
@derive_def AbstractArrayOps T begin
8+
Base.getindex(::T, ::Any...)
9+
Base.getindex(::T, ::Int...)
10+
Base.setindex!(::T, ::Any, ::Int...)
11+
Base.similar(::T, ::Type, ::Tuple{Vararg{Int}})
12+
end
13+
```
14+
=#
415
# TODO: Define an `AbstractMatrixOps` trait, which is where
516
# matrix multiplication should be defined (both `mul!` and `*`).
617
#=
@@ -13,6 +24,7 @@ function derive(::Val{:AbstractArrayOps}, type)
1324
return quote
1425
Base.getindex(::$type, ::Any...)
1526
Base.getindex(::$type, ::Int...)
27+
Base.setindex!(::$type, ::Any, ::Any...)
1628
Base.setindex!(::$type, ::Any, ::Int...)
1729
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
1830
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
@@ -33,6 +45,7 @@ function derive(::Val{:AbstractArrayOps}, type)
3345
Base.permutedims!(::Any, ::$type, ::Any)
3446
Broadcast.BroadcastStyle(::Type{<:$type})
3547
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
48+
Base.cat(::$type...; kwargs...)
3649
ArrayLayouts.MemoryLayout(::Type{<:$type})
3750
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
3851
end

test/basics/test_basics.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8787
a[1, 2] = 12
8888
b = similar(a)
8989
copyto!(b, a)
90+
@test b isa SparseArrayDOK{elt,2}
9091
@test b == a
9192
@test b[1, 2] == 12
9293
@test storedlength(b) == 1
@@ -114,6 +115,39 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
114115
a = SparseArrayDOK{elt}(2, 2)
115116
a[1, 2] = 12
116117
b = zero(a)
118+
@test b isa SparseArrayDOK{elt,2}
117119
@test iszero(b)
118120
@test iszero(storedlength(b))
121+
122+
a = SparseArrayDOK{elt}(2, 2)
123+
a[1, 2] = 12
124+
b = SparseArrayDOK{elt}(4, 4)
125+
b[2:3, 2:3] .= a
126+
@test isone(storedlength(b))
127+
@test b[2, 3] == 12
128+
129+
a = SparseArrayDOK{elt}(2, 2)
130+
a[1, 2] = 12
131+
b = SparseArrayDOK{elt}(4, 4)
132+
b[2:3, 2:3] = a
133+
@test isone(storedlength(b))
134+
@test b[2, 3] == 12
135+
136+
a = SparseArrayDOK{elt}(2, 2)
137+
a[1, 2] = 12
138+
b = SparseArrayDOK{elt}(4, 4)
139+
c = @view b[2:3, 2:3]
140+
c .= a
141+
@test isone(storedlength(b))
142+
@test b[2, 3] == 12
143+
144+
a1 = SparseArrayDOK{elt}(2, 2)
145+
a1[1, 2] = 12
146+
a2 = SparseArrayDOK{elt}(2, 2)
147+
a2[2, 1] = 21
148+
b = cat(a1, a2; dims=(1, 2))
149+
@test b isa SparseArrayDOK{elt,2}
150+
@test storedlength(b) == 2
151+
@test b[1, 2] == 12
152+
@test b[4, 3] == 21
119153
end

0 commit comments

Comments
 (0)