Skip to content

Commit 1f15687

Browse files
authored
More customization points in Concatenate (#32)
1 parent 14cb954 commit 1f15687

File tree

6 files changed

+151
-30
lines changed

6 files changed

+151
-30
lines changed

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,9 +13,16 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1313
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1414
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1515

16+
[weakdeps]
17+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
18+
19+
[extensions]
20+
DerivableInterfacesBlockArraysExt = "BlockArrays"
21+
1622
[compat]
1723
Adapt = "4.1.1"
18-
ArrayLayouts = "1.11.0"
24+
ArrayLayouts = "1.11"
25+
BlockArrays = "1.4"
1926
Compat = "3.47,4.10"
2027
ExproniconLite = "0.10.13"
2128
LinearAlgebra = "1.10"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module DerivableInterfacesBlockArraysExt
2+
3+
using BlockArrays: BlockedOneTo, blockedrange, blocklengths
4+
using DerivableInterfaces.Concatenate: Concatenate
5+
6+
function Concatenate.cat_axis(a1::BlockedOneTo, a2::BlockedOneTo)
7+
return blockedrange([blocklengths(a1); blocklengths(a2)])
8+
end
9+
10+
end

src/concatenate.jl

Lines changed: 93 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ using Base: promote_eltypeof
3131
using ..DerivableInterfaces:
3232
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
3333

34+
unval(x) = x
35+
unval(::Val{x}) where {x} = x
36+
37+
function _Concatenated end
38+
3439
"""
3540
Concatenated{Interface,Dims,Args<:Tuple}
3641
@@ -41,25 +46,25 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
4146
interface::Interface
4247
dims::Val{Dims}
4348
args::Args
44-
45-
function Concatenated(
46-
interface::Union{Nothing,AbstractInterface}, dims::Val{Dims}, args::Tuple
47-
) where {Dims}
48-
return new{typeof(interface),Dims,typeof(args)}(interface, dims, args)
49-
end
50-
function Concatenated(dims, args::Tuple)
51-
return Concatenated(interface(args...), dims, args)
52-
end
53-
function Concatenated{Interface}(dims, args) where {Interface}
54-
return Concatenated(Interface(), dims, args)
55-
end
56-
function Concatenated{Interface,Dims}(args) where {Interface,Dims}
57-
return new{Interface,Dims,typeof(args)}(Interface(), Val(Dims), args)
49+
global @inline function _Concatenated(
50+
interface::Interface, dims::Val{Dims}, args::Args
51+
) where {Interface,Dims,Args<:Tuple}
52+
return new{Interface,Dims,Args}(interface, dims, args)
5853
end
5954
end
6055

61-
dims(::Concatenated{A,D}) where {A,D} = D
62-
DerivableInterfaces.interface(concat::Concatenated) = concat.interface
56+
function Concatenated(interface::Union{Nothing,AbstractInterface}, dims::Val, args::Tuple)
57+
return _Concatenated(interface, dims, args)
58+
end
59+
function Concatenated(dims::Val, args::Tuple)
60+
return Concatenated(interface(args...), dims, args)
61+
end
62+
function Concatenated{Interface}(dims::Val, args::Tuple) where {Interface}
63+
return Concatenated(Interface(), dims, args)
64+
end
65+
66+
dims(::Concatenated{<:Any,D}) where {D} = D
67+
DerivableInterfaces.interface(concat::Concatenated) = getfield(concat, :interface)
6368

6469
concatenated(dims, args...) = concatenated(Val(dims), args...)
6570
concatenated(dims::Val, args...) = Concatenated(dims, args)
@@ -80,13 +85,33 @@ function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
8085
return similar(arraytype(interface(concat), T), ax)
8186
end
8287

83-
Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
88+
function cat_axis(
89+
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
90+
)
91+
return cat_axis(cat_axis(a1, a2), a_rest...)
92+
end
93+
cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
8494

85-
# For now, simply couple back to base implementation
86-
function Base.axes(concat::Concatenated)
87-
catdims = Base.dims2cat(dims(concat))
88-
return Base.cat_size_shape(catdims, concat.args...)
95+
function cat_ndims(dims, as::AbstractArray...)
96+
return max(maximum(dims), maximum(ndims, as))
97+
end
98+
function cat_ndims(dims::Val, as::AbstractArray...)
99+
return cat_ndims(unval(dims), as...)
100+
end
101+
102+
function cat_axes(dims, a::AbstractArray, as::AbstractArray...)
103+
return ntuple(cat_ndims(dims, a, as...)) do dim
104+
return dim in dims ? cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) : axes(a, dim)
105+
end
89106
end
107+
function cat_axes(dims::Val, as::AbstractArray...)
108+
return cat_axes(unval(dims), as...)
109+
end
110+
111+
Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...)
112+
Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...)
113+
Base.size(concat::Concatenated) = length.(axes(concat))
114+
Base.ndims(concat::Concatenated) = length(axes(concat))
90115

91116
# Main logic
92117
# ----------
@@ -122,19 +147,59 @@ Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)
122147

123148
Base.copy(concat::Concatenated) = copyto!(similar(concat), concat)
124149

150+
# The following is largely copied from the Base implementation of `Base.cat`, see:
151+
# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
152+
_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x)
153+
_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x)
154+
155+
cat_size(A) = (1,)
156+
cat_size(A::AbstractArray) = size(A)
157+
cat_size(A, d) = 1
158+
cat_size(A::AbstractArray, d) = size(A, d)
159+
160+
cat_indices(A, d) = Base.OneTo(1)
161+
cat_indices(A::AbstractArray, d) = axes(A, d)
162+
163+
function __cat!(A, shape, catdims, X...)
164+
return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
165+
end
166+
function __cat_offset!(A, shape, catdims, offsets, x, X...)
167+
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
168+
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
169+
return __cat_offset!(A, shape, catdims, newoffsets, X...)
170+
end
171+
__cat_offset!(A, shape, catdims, offsets) = A
172+
function __cat_offset1!(A, shape, catdims, offsets, x)
173+
inds = ntuple(length(offsets)) do i
174+
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
175+
end
176+
_copy_or_fill!(A, inds, x)
177+
newoffsets = ntuple(length(offsets)) do i
178+
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
179+
end
180+
return newoffsets
181+
end
182+
183+
dims2cat(dims::Val) = dims2cat(unval(dims))
184+
function dims2cat(dims)
185+
if any((0), dims)
186+
throw(ArgumentError("All cat dimensions must be positive integers, but got $dims"))
187+
end
188+
return ntuple(in(dims), maximum(dims))
189+
end
190+
125191
# default falls back to replacing interface with Nothing
126192
# this permits specializing on typeof(dest) without ambiguities
127193
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
128-
@inline Base.copyto!(dest::AbstractArray, concat::Concatenated) =
129-
copyto!(dest, convert(Concatenated{Nothing}, concat))
194+
@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated)
195+
return copyto!(dest, convert(Concatenated{Nothing}, concat))
196+
end
130197

131-
# couple back to Base implementation if no specialization exists:
132-
# https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
133198
function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing})
134-
catdims = Base.dims2cat(dims(concat))
135-
shape = Base.cat_size_shape(catdims, concat.args...)
199+
catdims = dims2cat(dims(concat))
200+
shape = size(concat)
136201
count(!iszero, catdims)::Int > 1 && zero!(dest)
137-
return Base.__cat(dest, shape, catdims, concat.args...)
202+
return __cat!(dest, shape, catdims, concat.args...)
138203
end
139204

140205
end

src/defaultarrayinterface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ end
3030
)
3131
return Base.mapreduce(f, op, as...; kwargs...)
3232
end
33+
34+
function arraytype(::DefaultArrayInterface, T::Type)
35+
return Array{T}
36+
end

src/zero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@
44
In-place version of `Base.zero`.
55
"""
66
function zero! end
7+
8+
@derive (T=AbstractArray,) begin
9+
DerivableInterfaces.zero!(::T)
10+
end

test/test_concatenate.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using DerivableInterfaces.Concatenate: concatenated
2+
using Test: @test, @testset
3+
4+
@testset "Concatenated" begin
5+
a = randn(Float32, 2, 2)
6+
b = randn(Float64, 2, 2)
7+
8+
concat = concatenated((1, 2), a, b)
9+
@test axes(concat) == Base.OneTo.((4, 4))
10+
@test size(concat) == (4, 4)
11+
@test eltype(concat) === Float64
12+
@test copy(concat) == cat(a, b; dims=(1, 2))
13+
14+
concat = concatenated(1, a, b)
15+
@test axes(concat) == Base.OneTo.((4, 2))
16+
@test size(concat) == (4, 2)
17+
@test eltype(concat) === Float64
18+
@test copy(concat) == cat(a, b; dims=1)
19+
20+
concat = concatenated(3, a, b)
21+
@test axes(concat) == Base.OneTo.((2, 2, 2))
22+
@test size(concat) == (2, 2, 2)
23+
@test eltype(concat) === Float64
24+
@test copy(concat) == cat(a, b; dims=3)
25+
26+
concat = concatenated(4, a, b)
27+
@test axes(concat) == Base.OneTo.((2, 2, 1, 2))
28+
@test size(concat) == (2, 2, 1, 2)
29+
@test eltype(concat) === Float64
30+
@test copy(concat) == cat(a, b; dims=4)
31+
end

0 commit comments

Comments
 (0)