116
116
117
117
# # copyto!
118
118
# based on Base/array.jl, Base/abstractarray.jl
119
-
120
- function copyto! (dest:: AbstractMatrix , V:: Vcat{<:Any,2} )
121
- arrays = V. args
119
+ copyto! (dest:: AbstractArray , V:: Vcat ) = vcat_copyto! (dest, arguments (V)... )
120
+ function vcat_copyto! (dest:: AbstractMatrix , arrays... )
122
121
nargs = length (arrays)
123
122
nrows = size (dest,1 )
124
123
nrows == sum (a-> size (a, 1 ), arrays) || throw (DimensionMismatch (" sum of rows each matrix must equal $nrows " ))
@@ -131,35 +130,13 @@ function copyto!(dest::AbstractMatrix, V::Vcat{<:Any,2})
131
130
pos = 1
132
131
for a in arrays
133
132
p1 = pos+ size (a,1 )- 1
134
- dest[ pos: p1, :] . = a
133
+ copyto! ( view ( dest, pos: p1, :), a)
135
134
pos = p1+ 1
136
135
end
137
136
return dest
138
137
end
139
138
140
- # this is repeated to avoid allocation in .=
141
- function copyto! (dest:: AbstractMatrix , V:: Vcat{<:Any,2,<:Tuple{Vararg{<:AbstractMatrix}}} )
142
- arrays = V. args
143
- nargs = length (arrays)
144
- nrows = size (dest,1 )
145
- nrows == sum (a-> size (a, 1 ), arrays) || throw (DimensionMismatch (" sum of rows each matrix must equal $nrows " ))
146
- ncols = size (dest, 2 )
147
- for a in arrays
148
- if size (a, 2 ) != ncols
149
- throw (DimensionMismatch (" number of columns of each array must match (got $(map (x-> size (x,2 ), A)) )" ))
150
- end
151
- end
152
- pos = 1
153
- for a in arrays
154
- p1 = pos+ size (a,1 )- 1
155
- dest[pos: p1, :] = a
156
- pos = p1+ 1
157
- end
158
- return dest
159
- end
160
-
161
- function copyto! (arr:: AbstractVector , A:: Vcat{<:Any,1,<:Tuple{Vararg{<:AbstractVector}}} )
162
- arrays = A. args
139
+ function vcat_copyto! (arr:: AbstractVector , arrays... )
163
140
n = 0
164
141
for a in arrays
165
142
n += length (a)
@@ -168,16 +145,14 @@ function copyto!(arr::AbstractVector, A::Vcat{<:Any,1,<:Tuple{Vararg{<:AbstractV
168
145
169
146
i = 0
170
147
@inbounds for a in arrays
171
- for ai in a
172
- i += 1
173
- arr[i] = ai
174
- end
148
+ m = length (a)
149
+ copyto! (view (arr,i+ 1 : i+ m), a)
150
+ i += m
175
151
end
176
152
arr
177
153
end
178
154
179
- function copyto! (arr:: Vector{T} , A:: Vcat {T,1 ,<: Tuple{Vararg{<:Vector{T}}} }) where T
180
- arrays = A. args
155
+ function vcat_copyto! (arr:: Vector{T} , arrays:: Vector{T} ...) where T
181
156
n = 0
182
157
for a in arrays
183
158
n += length (a)
@@ -217,8 +192,8 @@ function copyto!(arr::Vector{T}, A::Vcat{T,1,<:Tuple{Vararg{<:Vector{T}}}}) wher
217
192
return arr
218
193
end
219
194
220
- function copyto! (dest:: AbstractMatrix , H:: Hcat )
221
- arrays = H . args
195
+ copyto! (dest:: AbstractMatrix , H:: Hcat ) = hcat_copyto! (dest, arguments (H) ... )
196
+ function hcat_copyto! (dest :: AbstractMatrix , arrays... )
222
197
nargs = length (arrays)
223
198
nrows = size (dest, 1 )
224
199
ncols = 0
@@ -229,7 +204,7 @@ function copyto!(dest::AbstractMatrix, H::Hcat)
229
204
ncols += (nd== 2 ? size (a,2 ) : 1 )
230
205
end
231
206
232
- nrows == size (H ,1 ) || throw (DimensionMismatch (" Destination rows must match" ))
207
+ nrows == size (first (arrays) ,1 ) || throw (DimensionMismatch (" Destination rows must match" ))
233
208
ncols == size (dest,2 ) || throw (DimensionMismatch (" Destination columns must match" ))
234
209
235
210
pos = 1
@@ -242,22 +217,22 @@ function copyto!(dest::AbstractMatrix, H::Hcat)
242
217
else
243
218
for a in arrays
244
219
p1 = pos+ (isa (a,AbstractMatrix) ? size (a, 2 ) : 1 )- 1
245
- dest[ :, pos: p1] . = a
220
+ copyto! ( view ( dest, :, pos: p1), a)
246
221
pos = p1+ 1
247
222
end
248
223
end
249
224
return dest
250
225
end
251
226
252
- function copyto ! (dest:: AbstractMatrix , H :: Hcat{<:Any,Tuple{Vararg{<: AbstractVector}}} )
227
+ function hcat_copyto ! (dest:: AbstractMatrix , arrays :: AbstractVector... )
253
228
height = size (dest, 1 )
254
- for j = 1 : length (H )
255
- if length (H [j]) != height
229
+ for j = 1 : length (arrays )
230
+ if length (arrays [j]) != height
256
231
throw (DimensionMismatch (" vectors must have same lengths" ))
257
232
end
258
233
end
259
- for j= 1 : length (H )
260
- dest[i,:] . = H [j]
234
+ for j= 1 : length (arrays )
235
+ copyto! ( view ( dest,:,j), arrays [j])
261
236
end
262
237
263
238
dest
@@ -502,27 +477,36 @@ applylayout(::Type{typeof(vcat)}, ::A, ::ZerosLayout) where A = PaddedLayout{A}(
502
477
cachedlayout (:: A , :: ZerosLayout ) where A = PaddedLayout {A} ()
503
478
504
479
505
- paddeddata (A:: CachedArray ) = A. data
480
+ paddeddata (A:: CachedArray ) = view ( A. data, OneTo .(A . datasize) ... )
506
481
paddeddata (A:: Vcat ) = A. args[1 ]
507
482
508
483
function == (A:: CachedVector{<:Any,<:Any,<:Zeros} , B:: CachedVector{<:Any,<:Any,<:Zeros} )
509
484
length (A) == length (B) || return false
510
- n = max (length (A . data), length (B . data) )
485
+ n = max (A . datasize[ 1 ], B . datasize[ 1 ] )
511
486
resizedata! (A,n); resizedata! (B,n)
512
- A. data == B. data
487
+ view ( A. data, OneTo (n)) == view ( B. data, OneTo (n))
513
488
end
514
489
515
490
# special copyto! since `similar` of a padded returns a cached
516
491
for Typ in (:Number , :AbstractVector )
517
492
@eval function copyto! (dest:: CachedVector{T,Vector{T},<:Zeros{T,1}} , src:: Vcat{<:Any,1,<:Tuple{<:$Typ,<:Zeros}} ) where T
518
493
length (src) ≤ length (dest) || throw (BoundsError ())
519
494
a,_ = src. args
520
- resizedata! (dest, length (a)) # make sure we are padded enough
521
- copyto! (dest. data, a)
495
+ n = length (a)
496
+ resizedata! (dest, n) # make sure we are padded enough
497
+ copyto! (view (dest. data,OneTo (n)), a)
522
498
dest
523
499
end
524
500
end
525
501
502
+ function copyto! (dest:: CachedVector{T,Vector{T},<:Zeros{T,1}} , src:: CachedVector{V,Vector{V},<:Zeros{V,1}} ) where {T,V}
503
+ length (src) ≤ length (dest) || throw (BoundsError ())
504
+ n = src. datasize[1 ]
505
+ resizedata! (dest, n)
506
+ copyto! (view (dest. data,OneTo (n)), view (src. data,OneTo (n)))
507
+ dest
508
+ end
509
+
526
510
struct Dot{StyleA,StyleB,ATyp,BTyp}
527
511
A:: ATyp
528
512
B:: BTyp
@@ -556,16 +540,17 @@ end
556
540
# subarrays
557
541
# ##
558
542
559
- subarraylayout (:: ApplyLayout{typeof(vcat)} , _) =
560
- ApplyLayout {typeof(vcat)} ()
561
- subarraylayout (:: ApplyLayout{typeof(hcat)} , _) =
562
- ApplyLayout {typeof(hcat)} ()
543
+ subarraylayout (:: ApplyLayout{typeof(vcat)} , _) = ApplyLayout {typeof(vcat)} ()
544
+ subarraylayout (:: ApplyLayout{typeof(hcat)} , _) = ApplyLayout {typeof(hcat)} ()
563
545
564
546
arguments (:: ApplyLayout{typeof(vcat)} , V:: SubArray{<:Any,2,<:Any,<:Tuple{<:Slice,<:Any}} ) =
565
547
view .(arguments (parent (V)), Ref (:), Ref (parentindices (V)[2 ]))
566
548
arguments (:: ApplyLayout{typeof(hcat)} , V:: SubArray{<:Any,2,<:Any,<:Tuple{<:Any,<:Slice}} ) =
567
549
view .(arguments (parent (V)), Ref (parentindices (V)[1 ]), Ref (:))
568
550
551
+ copyto! (dest:: AbstractArray{T,N} , src:: SubArray{T,N,<:Vcat{T,N}} ) where {T,N} = vcat_copyto! (dest, arguments (src)... )
552
+ copyto! (dest:: AbstractMatrix{T} , src:: SubArray{T,2,<:Hcat{T}} ) where T = hcat_copyto! (dest, arguments (src)... )
553
+
569
554
570
555
_vcat_lastinds (sz) = _vcat_cumsum (sz... )
571
556
_vcat_firstinds (sz) = (1 , (1 .+ most (_vcat_lastinds (sz))). .. )
@@ -612,7 +597,7 @@ function sub_materialize(::ApplyLayout{typeof(vcat)}, V)
612
597
_,jr = parentindices (V)
613
598
for a in arguments (V)
614
599
m = size (a,1 )
615
- view (ret,n+ 1 : n+ m,:) . = a
600
+ copyto! ( view (ret,n+ 1 : n+ m,:), a)
616
601
n += m
617
602
end
618
603
ret
@@ -624,9 +609,13 @@ function sub_materialize(::ApplyLayout{typeof(hcat)}, V)
624
609
kr,_ = parentindices (V)
625
610
for a in arguments (V)
626
611
m = size (a,2 )
627
- view (ret,:,n+ 1 : n+ m) . = a
612
+ copyto! ( view (ret,:,n+ 1 : n+ m), a)
628
613
n += m
629
614
end
630
615
ret
631
616
end
632
-
617
+ # temporarily allocate. In the future, we add a loop over arguments
618
+ materialize! (M:: MatMulMatAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}} ) =
619
+ materialize! (MulAdd (M. α,M. A,Array (M. B),M. β,M. C))
620
+ materialize! (M:: MatMulVecAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}} ) =
621
+ materialize! (MulAdd (M. α,M. A,Array (M. B),M. β,M. C))
0 commit comments