Skip to content

Commit 14cb954

Browse files
authored
Rework cat implementation + add zero! (#18)
This PR enables a proper implementation/specialization point for implementing concatenations. Simultaneously, we define an overloadable `zero!` function.
1 parent f691e77 commit 14cb954

14 files changed

+194
-124
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: "Integration Test Request"
2+
3+
on:
4+
issue_comment:
5+
types: [created]
6+
7+
jobs:
8+
integrationrequest:
9+
if: |
10+
github.event.issue.pull_request &&
11+
contains(fromJSON('["OWNER", "COLLABORATOR", "MEMBER"]'), github.event.comment.author_association)
12+
uses: ITensor/ITensorActions/.github/workflows/IntegrationTestRequest.yml@main
13+
with:
14+
localregistry: https://github.com/ITensor/ITensorRegistry.git

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.17"
4+
version = "0.4.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
9+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
910
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -15,6 +16,7 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1516
[compat]
1617
Adapt = "4.1.1"
1718
ArrayLayouts = "1.11.0"
19+
Compat = "3.47,4.10"
1820
ExproniconLite = "0.10.13"
1921
LinearAlgebra = "1.10"
2022
MLStyle = "0.4.17"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
7-
DerivableInterfaces = "0.3"
7+
DerivableInterfaces = "0.4"
88
Documenter = "1"
99
Literate = "2"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ makedocs(;
1616
edit_link="main",
1717
assets=String[],
1818
),
19-
pages=["Home" => "index.md"],
19+
pages=["Home" => "index.md", "Reference" => "reference.md"],
2020
)
2121

2222
deploydocs(;

docs/src/reference.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Reference
2+
3+
```@autodocs
4+
Modules = [DerivableInterfaces, DerivableInterfaces.Concatenate]
5+
```

src/DerivableInterfaces.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
module DerivableInterfaces
22

3+
export Concatenate
4+
35
include("interface_function.jl")
46
include("abstractinterface.jl")
57
include("derive_macro.jl")
68
include("interface_macro.jl")
79
include("wrappedarrays.jl")
10+
11+
include("zero.jl")
812
include("abstractarrayinterface.jl")
13+
include("concatenate.jl")
914
include("defaultarrayinterface.jl")
1015
include("traits.jl")
1116

src/abstractarrayinterface.jl

Lines changed: 5 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,12 @@ end
153153
@interface interface map!(Returns(value), a, a)
154154
end
155155

156-
using ArrayLayouts: zero!
156+
# TODO: should this be recursive? `map!(zero!, A, A)` might also work?
157+
@interface ::AbstractArrayInterface DerivableInterfaces.zero!(A::AbstractArray) =
158+
fill!(A, zero(eltype(A)))
157159

158-
# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
159-
# and is useful for sparse array logic, since it can be used to empty
160-
# the sparse array storage.
161-
# We use a single function definition to minimize method ambiguities.
162-
@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray)
163-
# More generally, the first codepath could be taking if `zero(eltype(a))`
164-
# is defined and the elements are immutable.
165-
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
166-
return @interface interface map!(f, a, a)
167-
end
168-
169-
# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`.
170-
# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier
160+
# Specialized version of `Base.zero` written in terms of `zero!`.
161+
# This is friendlier for sparse arrays since `zero!` makes it easier
171162
# to handle the logic of dropping all elements of the sparse array when possible.
172163
# We use a single function definition to minimize method ambiguities.
173164
@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray)
@@ -250,102 +241,3 @@ end
250241
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
251242
## return ArrayLayouts.mul(a1, a2)
252243
## end
253-
254-
# Concatenation
255-
256-
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
257-
function axis_cat(
258-
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
259-
)
260-
return axis_cat(axis_cat(a1, a2), a_rest...)
261-
end
262-
263-
unval(x) = x
264-
unval(::Val{x}) where {x} = x
265-
266-
function cat_axes(as::AbstractArray...; dims)
267-
return ntuple(length(first(axes.(as)))) do dim
268-
return if dim in unval(dims)
269-
axis_cat(map(axes -> axes[dim], axes.(as))...)
270-
else
271-
axes(first(as))[dim]
272-
end
273-
end
274-
end
275-
276-
function cat! end
277-
278-
# Represents concatenating `args` over `dims`.
279-
struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims}
280-
args::Args
281-
end
282-
to_cat_dims(dim::Integer) = Int(dim)
283-
to_cat_dims(dim::Int) = (dim,)
284-
to_cat_dims(dims::Val) = to_cat_dims(unval(dims))
285-
to_cat_dims(dims::Tuple) = dims
286-
Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args)
287-
cat_dims(::Cat{<:Any,dims}) where {dims} = dims
288-
289-
function Base.axes(a::Cat)
290-
return cat_axes(a.args...; dims=cat_dims(a))
291-
end
292-
Base.eltype(a::Cat) = promote_type(eltype.(a.args)...)
293-
function Base.similar(a::Cat)
294-
ax = axes(a)
295-
elt = eltype(a)
296-
# TODO: This drops GPU information, maybe use MemoryLayout?
297-
return similar(arraytype(interface(a.args...), elt), ax)
298-
end
299-
300-
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
301-
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
302-
# This is very similar to the `Base.cat` implementation but handles zero values better.
303-
function cat_offset!(
304-
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
305-
)
306-
inds = ntuple(ndims(a_dest)) do dim
307-
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
308-
end
309-
a_dest[inds...] = a1
310-
new_offsets = ntuple(ndims(a_dest)) do dim
311-
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
312-
end
313-
cat_offset!(a_dest, new_offsets, a_rest...; dims)
314-
return a_dest
315-
end
316-
function cat_offset!(a_dest::AbstractArray, offsets; dims)
317-
return a_dest
318-
end
319-
320-
@interface ::AbstractArrayInterface function cat!(
321-
a_dest::AbstractArray, as::AbstractArray...; dims
322-
)
323-
offsets = ntuple(zero, ndims(a_dest))
324-
# TODO: Fill `a_dest` with zeros if needed using `zero!`.
325-
cat_offset!(a_dest, offsets, as...; dims)
326-
return a_dest
327-
end
328-
329-
function cat_along(dims, as::AbstractArray...)
330-
return @interface interface(as...) cat_along(dims, as...)
331-
end
332-
333-
@interface interface::AbstractArrayInterface function cat_along(dims, as::AbstractArray...)
334-
a_dest = similar(Cat(as...; dims))
335-
@interface interface cat!(a_dest, as...; dims)
336-
return a_dest
337-
end
338-
339-
@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims)
340-
return @interface interface cat_along(dims, as...)
341-
end
342-
343-
# TODO: Use `@derive` instead:
344-
# ```julia
345-
# @derive (T=AbstractArray,) begin
346-
# cat!(a_dest::AbstractArray, as::T...; dims)
347-
# end
348-
# ```
349-
function cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
350-
return @interface interface(as...) cat!(a_dest, as...; dims)
351-
end

src/concatenate.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
module Concatenate
3+
4+
Alternative implementation for `Base.cat` through [`cat(!)`](@ref cat).
5+
6+
This is mostly a copy of the Base implementation, with the main difference being
7+
that the destination is chosen based on all inputs instead of just the first.
8+
9+
Additionally, we have an intermediate representation in terms of a Concatenated object,
10+
reminiscent of how Broadcast works.
11+
12+
The various entry points for specializing behavior are:
13+
14+
* Destination selection can be achieved through
15+
16+
Base.similar(concat::Concatenated{Interface}, ::Type{T}, axes) where {Interface}
17+
18+
* Custom implementations:
19+
20+
Base.copy(concat::Concatenated{Interface}) # custom implementation of cat
21+
Base.copyto!(dest, concat::Concatenated{Interface}) # custom implementation of cat! based on interface
22+
Base.copyto!(dest, concat::Concatenated{Nothing}) # custom implementation of cat! based on typeof(dest)
23+
"""
24+
module Concatenate
25+
26+
using Compat: @compat
27+
export concatenate
28+
@compat public Concatenated, cat, cat!, concatenated
29+
30+
using Base: promote_eltypeof
31+
using ..DerivableInterfaces:
32+
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
33+
34+
"""
35+
Concatenated{Interface,Dims,Args<:Tuple}
36+
37+
Lazy representation of the concatenation of various `Args` along `Dims`, in order to provide
38+
hooks to customize the implementation.
39+
"""
40+
struct Concatenated{Interface,Dims,Args<:Tuple}
41+
interface::Interface
42+
dims::Val{Dims}
43+
args::Args
44+
45+
function Concatenated(
46+
interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple
47+
) where {Dims}
48+
return new{typeof(interface),Dims,typeof(args)}(interface, dims, args)
49+
end
50+
function Concatenated(dims, args::Tuple)
51+
return Concatenated(interface(args...), dims, args)
52+
end
53+
function Concatenated{Interface}(dims, args) where {Interface}
54+
return Concatenated(Interface(), dims, args)
55+
end
56+
function Concatenated{Interface,Dims}(args) where {Interface,Dims}
57+
return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args)
58+
end
59+
end
60+
61+
dims(::Concatenated{A,D}) where {A,D} = D
62+
DerivableInterfaces.interface(concat::Concatenated) = concat.interface
63+
64+
concatenated(dims, args...) = concatenated(Val(dims), args...)
65+
concatenated(dims::Val, args...) = Concatenated(dims, args)
66+
67+
function Base.convert(
68+
::Type{Concatenated{NewInterface}}, concat::Concatenated{<:Any,Dims,Args}
69+
) where {NewInterface,Dims,Args}
70+
return Concatenated{NewInterface}(
71+
concat.dims, concat.args
72+
)::Concatenated{NewInterface,Dims,Args}
73+
end
74+
75+
# allocating the destination container
76+
# ------------------------------------
77+
Base.similar(concat::Concatenated) = similar(concat, eltype(concat))
78+
Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat))
79+
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
80+
return similar(arraytype(interface(concat), T), ax)
81+
end
82+
83+
Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
84+
85+
# For now, simply couple back to base implementation
86+
function Base.axes(concat::Concatenated)
87+
catdims = Base.dims2cat(dims(concat))
88+
return Base.cat_size_shape(catdims, concat.args...)
89+
end
90+
91+
# Main logic
92+
# ----------
93+
"""
94+
concatenate(dims, args...)
95+
96+
Concatenate the supplied `args` along dimensions `dims`.
97+
98+
See also [`cat`] and [`cat!`](@ref).
99+
"""
100+
concatenate(dims, args...) = Base.materialize(concatenated(dims, args...))
101+
102+
"""
103+
Concatenate.cat(args...; dims)
104+
105+
Concatenate the supplied `args` along dimensions `dims`.
106+
107+
See also [`concatenate`] and [`cat!`](@ref).
108+
"""
109+
cat(args...; dims) = concatenate(dims, args...)
110+
Base.materialize(concat::Concatenated) = copy(concat)
111+
112+
"""
113+
Concatenate.cat!(dest, args...; dims)
114+
115+
Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`.
116+
"""
117+
function cat!(dest, args...; dims)
118+
Base.materialize!(dest, concatenated(dims, args...))
119+
return dest
120+
end
121+
Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)
122+
123+
Base.copy(concat::Concatenated) = copyto!(similar(concat), concat)
124+
125+
# default falls back to replacing interface with Nothing
126+
# this permits specializing on typeof(dest) without ambiguities
127+
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
128+
@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) =
129+
copyto!(dest, convert(Concatenated{Nothing}, concat))
130+
131+
# couple back to Base implementation if no specialization exists:
132+
# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
133+
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
134+
catdims = Base.dims2cat(dims(concat))
135+
shape = Base.cat_size_shape(catdims, concat.args...)
136+
count(!iszero, catdims)::Int > 1 && zero!(dest)
137+
return Base.__cat(dest, shape, catdims, concat.args...)
138+
end
139+
140+
end

src/traits.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,11 @@ function derive(::Val{:AbstractArrayOps}, type)
4040
Base.iszero(::$type)
4141
Base.real(::$type)
4242
Base.fill!(::$type, ::Any)
43-
ArrayLayouts.zero!(::$type)
43+
DerivableInterfaces.zero!(::$type)
4444
Base.zero(::$type)
4545
Base.permutedims!(::Any, ::$type, ::Any)
4646
Broadcast.BroadcastStyle(::Type{<:$type})
4747
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
48-
Base.cat(::$type...; kwargs...)
4948
ArrayLayouts.MemoryLayout(::Type{<:$type})
5049
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
5150
end

src/zero.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""
2+
zero!(x::AbstractArray)
3+
4+
In-place version of `Base.zero`.
5+
"""
6+
function zero! end

0 commit comments

Comments
 (0)