@@ -31,6 +31,11 @@ using Base: promote_eltypeof
3131using .. DerivableInterfaces:
3232 DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
3333
34+ unval (x) = x
35+ unval (:: Val{x} ) where {x} = x
36+
37+ function _Concatenated end
38+
3439"""
3540 Concatenated{Interface,Dims,Args<:Tuple}
3641
@@ -41,25 +46,25 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
4146 interface:: Interface
4247 dims:: Val{Dims}
4348 args:: Args
44-
45- function Concatenated (
46- interface:: Union{Nothing,AbstractInterface} , dims:: Val{Dims} , args:: Tuple
47- ) where {Dims}
48- return new {typeof(interface),Dims,typeof(args)} (interface, dims, args)
49- end
50- function Concatenated (dims, args:: Tuple )
51- return Concatenated (interface (args... ), dims, args)
52- end
53- function Concatenated {Interface} (dims, args) where {Interface}
54- return Concatenated (Interface (), dims, args)
55- end
56- function Concatenated {Interface,Dims} (args) where {Interface,Dims}
57- return new {Interface,Dims,typeof(args)} (Interface (), Val (Dims), args)
49+ global @inline function _Concatenated (
50+ interface:: Interface , dims:: Val{Dims} , args:: Args
51+ ) where {Interface,Dims,Args<: Tuple }
52+ return new {Interface,Dims,Args} (interface, dims, args)
5853 end
5954end
6055
61- dims (:: Concatenated{A,D} ) where {A,D} = D
62- DerivableInterfaces. interface (concat:: Concatenated ) = concat. interface
56+ function Concatenated (interface:: Union{Nothing,AbstractInterface} , dims:: Val , args:: Tuple )
57+ return _Concatenated (interface, dims, args)
58+ end
59+ function Concatenated (dims:: Val , args:: Tuple )
60+ return Concatenated (interface (args... ), dims, args)
61+ end
62+ function Concatenated {Interface} (dims:: Val , args:: Tuple ) where {Interface}
63+ return Concatenated (Interface (), dims, args)
64+ end
65+
66+ dims (:: Concatenated{<:Any,D} ) where {D} = D
67+ DerivableInterfaces. interface (concat:: Concatenated ) = getfield (concat, :interface )
6368
6469concatenated (dims, args... ) = concatenated (Val (dims), args... )
6570concatenated (dims:: Val , args... ) = Concatenated (dims, args)
@@ -80,13 +85,33 @@ function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
8085 return similar (arraytype (interface (concat), T), ax)
8186end
8287
83- Base. eltype (concat:: Concatenated ) = promote_eltypeof (concat. args... )
88+ function cat_axis (
89+ a1:: AbstractUnitRange , a2:: AbstractUnitRange , a_rest:: AbstractUnitRange...
90+ )
91+ return cat_axis (cat_axis (a1, a2), a_rest... )
92+ end
93+ cat_axis (a1:: AbstractUnitRange , a2:: AbstractUnitRange ) = Base. OneTo (length (a1) + length (a2))
8494
85- # For now, simply couple back to base implementation
86- function Base. axes (concat:: Concatenated )
87- catdims = Base. dims2cat (dims (concat))
88- return Base. cat_size_shape (catdims, concat. args... )
95+ function cat_ndims (dims, as:: AbstractArray... )
96+ return max (maximum (dims), maximum (ndims, as))
97+ end
98+ function cat_ndims (dims:: Val , as:: AbstractArray... )
99+ return cat_ndims (unval (dims), as... )
100+ end
101+
102+ function cat_axes (dims, a:: AbstractArray , as:: AbstractArray... )
103+ return ntuple (cat_ndims (dims, a, as... )) do dim
104+ return dim in dims ? cat_axis (map (Base. Fix2 (axes, dim), (a, as... ))... ) : axes (a, dim)
105+ end
89106end
107+ function cat_axes (dims:: Val , as:: AbstractArray... )
108+ return cat_axes (unval (dims), as... )
109+ end
110+
111+ Base. eltype (concat:: Concatenated ) = promote_eltypeof (concat. args... )
112+ Base. axes (concat:: Concatenated ) = cat_axes (dims (concat), concat. args... )
113+ Base. size (concat:: Concatenated ) = length .(axes (concat))
114+ Base. ndims (concat:: Concatenated ) = length (axes (concat))
90115
91116# Main logic
92117# ----------
@@ -122,19 +147,59 @@ Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)
122147
123148Base. copy (concat:: Concatenated ) = copyto! (similar (concat), concat)
124149
150+ # The following is largely copied from the Base implementation of `Base.cat`, see:
151+ # https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
152+ _copy_or_fill! (A, inds, x) = fill! (view (A, inds... ), x)
153+ _copy_or_fill! (A, inds, x:: AbstractArray ) = (A[inds... ] = x)
154+
155+ cat_size (A) = (1 ,)
156+ cat_size (A:: AbstractArray ) = size (A)
157+ cat_size (A, d) = 1
158+ cat_size (A:: AbstractArray , d) = size (A, d)
159+
160+ cat_indices (A, d) = Base. OneTo (1 )
161+ cat_indices (A:: AbstractArray , d) = axes (A, d)
162+
163+ function __cat! (A, shape, catdims, X... )
164+ return __cat_offset! (A, shape, catdims, ntuple (zero, length (shape)), X... )
165+ end
166+ function __cat_offset! (A, shape, catdims, offsets, x, X... )
167+ # splitting the "work" on x from X... may reduce latency (fewer costly specializations)
168+ newoffsets = __cat_offset1! (A, shape, catdims, offsets, x)
169+ return __cat_offset! (A, shape, catdims, newoffsets, X... )
170+ end
171+ __cat_offset! (A, shape, catdims, offsets) = A
172+ function __cat_offset1! (A, shape, catdims, offsets, x)
173+ inds = ntuple (length (offsets)) do i
174+ (i <= length (catdims) && catdims[i]) ? offsets[i] .+ cat_indices (x, i) : 1 : shape[i]
175+ end
176+ _copy_or_fill! (A, inds, x)
177+ newoffsets = ntuple (length (offsets)) do i
178+ (i <= length (catdims) && catdims[i]) ? offsets[i] + cat_size (x, i) : offsets[i]
179+ end
180+ return newoffsets
181+ end
182+
183+ dims2cat (dims:: Val ) = dims2cat (unval (dims))
184+ function dims2cat (dims)
185+ if any (≤ (0 ), dims)
186+ throw (ArgumentError (" All cat dimensions must be positive integers, but got $dims " ))
187+ end
188+ return ntuple (in (dims), maximum (dims))
189+ end
190+
125191# default falls back to replacing interface with Nothing
126192# this permits specializing on typeof(dest) without ambiguities
127193# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
128- @inline Base. copyto! (dest:: AbstractArray , concat:: Concatenated ) =
129- copyto! (dest, convert (Concatenated{Nothing}, concat))
194+ @inline function Base. copyto! (dest:: AbstractArray , concat:: Concatenated )
195+ return copyto! (dest, convert (Concatenated{Nothing}, concat))
196+ end
130197
131- # couple back to Base implementation if no specialization exists:
132- # https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
133198function Base. copyto! (dest:: AbstractArray , concat:: Concatenated{Nothing} )
134- catdims = Base . dims2cat (dims (concat))
135- shape = Base . cat_size_shape (catdims, concat. args ... )
199+ catdims = dims2cat (dims (concat))
200+ shape = size ( concat)
136201 count (! iszero, catdims):: Int > 1 && zero! (dest)
137- return Base . __cat (dest, shape, catdims, concat. args... )
202+ return __cat! (dest, shape, catdims, concat. args... )
138203end
139204
140205end
0 commit comments