@@ -31,6 +31,11 @@ using Base: promote_eltypeof
31
31
using .. DerivableInterfaces:
32
32
DerivableInterfaces, AbstractInterface, interface, zero!, arraytype
33
33
34
+ unval (x) = x
35
+ unval (:: Val{x} ) where {x} = x
36
+
37
+ function _Concatenated end
38
+
34
39
"""
35
40
Concatenated{Interface,Dims,Args<:Tuple}
36
41
@@ -41,25 +46,25 @@ struct Concatenated{Interface,Dims,Args<:Tuple}
41
46
interface:: Interface
42
47
dims:: Val{Dims}
43
48
args:: Args
44
-
45
- function Concatenated (
46
- interface:: Union{Nothing,AbstractInterface} , dims:: Val{Dims} , args:: Tuple
47
- ) where {Dims}
48
- return new {typeof(interface),Dims,typeof(args)} (interface, dims, args)
49
- end
50
- function Concatenated (dims, args:: Tuple )
51
- return Concatenated (interface (args... ), dims, args)
52
- end
53
- function Concatenated {Interface} (dims, args) where {Interface}
54
- return Concatenated (Interface (), dims, args)
55
- end
56
- function Concatenated {Interface,Dims} (args) where {Interface,Dims}
57
- return new {Interface,Dims,typeof(args)} (Interface (), Val (Dims), args)
49
+ global @inline function _Concatenated (
50
+ interface:: Interface , dims:: Val{Dims} , args:: Args
51
+ ) where {Interface,Dims,Args<: Tuple }
52
+ return new {Interface,Dims,Args} (interface, dims, args)
58
53
end
59
54
end
60
55
61
- dims (:: Concatenated{A,D} ) where {A,D} = D
62
- DerivableInterfaces. interface (concat:: Concatenated ) = concat. interface
56
+ function Concatenated (interface:: Union{Nothing,AbstractInterface} , dims:: Val , args:: Tuple )
57
+ return _Concatenated (interface, dims, args)
58
+ end
59
+ function Concatenated (dims:: Val , args:: Tuple )
60
+ return Concatenated (interface (args... ), dims, args)
61
+ end
62
+ function Concatenated {Interface} (dims:: Val , args:: Tuple ) where {Interface}
63
+ return Concatenated (Interface (), dims, args)
64
+ end
65
+
66
+ dims (:: Concatenated{<:Any,D} ) where {D} = D
67
+ DerivableInterfaces. interface (concat:: Concatenated ) = getfield (concat, :interface )
63
68
64
69
concatenated (dims, args... ) = concatenated (Val (dims), args... )
65
70
concatenated (dims:: Val , args... ) = Concatenated (dims, args)
@@ -80,13 +85,33 @@ function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
80
85
return similar (arraytype (interface (concat), T), ax)
81
86
end
82
87
83
- Base. eltype (concat:: Concatenated ) = promote_eltypeof (concat. args... )
88
+ function cat_axis (
89
+ a1:: AbstractUnitRange , a2:: AbstractUnitRange , a_rest:: AbstractUnitRange...
90
+ )
91
+ return cat_axis (cat_axis (a1, a2), a_rest... )
92
+ end
93
+ cat_axis (a1:: AbstractUnitRange , a2:: AbstractUnitRange ) = Base. OneTo (length (a1) + length (a2))
84
94
85
- # For now, simply couple back to base implementation
86
- function Base. axes (concat:: Concatenated )
87
- catdims = Base. dims2cat (dims (concat))
88
- return Base. cat_size_shape (catdims, concat. args... )
95
+ function cat_ndims (dims, as:: AbstractArray... )
96
+ return max (maximum (dims), maximum (ndims, as))
97
+ end
98
+ function cat_ndims (dims:: Val , as:: AbstractArray... )
99
+ return cat_ndims (unval (dims), as... )
100
+ end
101
+
102
+ function cat_axes (dims, a:: AbstractArray , as:: AbstractArray... )
103
+ return ntuple (cat_ndims (dims, a, as... )) do dim
104
+ return dim in dims ? cat_axis (map (Base. Fix2 (axes, dim), (a, as... ))... ) : axes (a, dim)
105
+ end
89
106
end
107
+ function cat_axes (dims:: Val , as:: AbstractArray... )
108
+ return cat_axes (unval (dims), as... )
109
+ end
110
+
111
+ Base. eltype (concat:: Concatenated ) = promote_eltypeof (concat. args... )
112
+ Base. axes (concat:: Concatenated ) = cat_axes (dims (concat), concat. args... )
113
+ Base. size (concat:: Concatenated ) = length .(axes (concat))
114
+ Base. ndims (concat:: Concatenated ) = length (axes (concat))
90
115
91
116
# Main logic
92
117
# ----------
@@ -122,19 +147,59 @@ Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat)
122
147
123
148
Base. copy (concat:: Concatenated ) = copyto! (similar (concat), concat)
124
149
150
+ # The following is largely copied from the Base implementation of `Base.cat`, see:
151
+ # https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887
152
+ _copy_or_fill! (A, inds, x) = fill! (view (A, inds... ), x)
153
+ _copy_or_fill! (A, inds, x:: AbstractArray ) = (A[inds... ] = x)
154
+
155
+ cat_size (A) = (1 ,)
156
+ cat_size (A:: AbstractArray ) = size (A)
157
+ cat_size (A, d) = 1
158
+ cat_size (A:: AbstractArray , d) = size (A, d)
159
+
160
+ cat_indices (A, d) = Base. OneTo (1 )
161
+ cat_indices (A:: AbstractArray , d) = axes (A, d)
162
+
163
+ function __cat! (A, shape, catdims, X... )
164
+ return __cat_offset! (A, shape, catdims, ntuple (zero, length (shape)), X... )
165
+ end
166
+ function __cat_offset! (A, shape, catdims, offsets, x, X... )
167
+ # splitting the "work" on x from X... may reduce latency (fewer costly specializations)
168
+ newoffsets = __cat_offset1! (A, shape, catdims, offsets, x)
169
+ return __cat_offset! (A, shape, catdims, newoffsets, X... )
170
+ end
171
+ __cat_offset! (A, shape, catdims, offsets) = A
172
+ function __cat_offset1! (A, shape, catdims, offsets, x)
173
+ inds = ntuple (length (offsets)) do i
174
+ (i <= length (catdims) && catdims[i]) ? offsets[i] .+ cat_indices (x, i) : 1 : shape[i]
175
+ end
176
+ _copy_or_fill! (A, inds, x)
177
+ newoffsets = ntuple (length (offsets)) do i
178
+ (i <= length (catdims) && catdims[i]) ? offsets[i] + cat_size (x, i) : offsets[i]
179
+ end
180
+ return newoffsets
181
+ end
182
+
183
+ dims2cat (dims:: Val ) = dims2cat (unval (dims))
184
+ function dims2cat (dims)
185
+ if any (≤ (0 ), dims)
186
+ throw (ArgumentError (" All cat dimensions must be positive integers, but got $dims " ))
187
+ end
188
+ return ntuple (in (dims), maximum (dims))
189
+ end
190
+
125
191
# default falls back to replacing interface with Nothing
126
192
# this permits specializing on typeof(dest) without ambiguities
127
193
# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base.
128
- @inline Base. copyto! (dest:: AbstractArray , concat:: Concatenated ) =
129
- copyto! (dest, convert (Concatenated{Nothing}, concat))
194
+ @inline function Base. copyto! (dest:: AbstractArray , concat:: Concatenated )
195
+ return copyto! (dest, convert (Concatenated{Nothing}, concat))
196
+ end
130
197
131
- # couple back to Base implementation if no specialization exists:
132
- # https://github.com/JuliaLang/julia/blob/29da86bb983066dd076439c2c7bc5e28dbd611bb/base/abstractarray.jl#L1852
133
198
function Base. copyto! (dest:: AbstractArray , concat:: Concatenated{Nothing} )
134
- catdims = Base . dims2cat (dims (concat))
135
- shape = Base . cat_size_shape (catdims, concat. args ... )
199
+ catdims = dims2cat (dims (concat))
200
+ shape = size ( concat)
136
201
count (! iszero, catdims):: Int > 1 && zero! (dest)
137
- return Base . __cat (dest, shape, catdims, concat. args... )
202
+ return __cat! (dest, shape, catdims, concat. args... )
138
203
end
139
204
140
205
end
0 commit comments