@@ -289,14 +289,18 @@ BroadcastStyle(::Type{<:Hcat{<:Any}}) where N = LazyArrayStyle{2}()
289
289
broadcasted (:: LazyArrayStyle , op, A:: Vcat ) =
290
290
Vcat (broadcast (x -> broadcast (op, x), A. args)... )
291
291
292
- broadcasted (:: LazyArrayStyle , op, A:: Vcat , c:: Number ) =
293
- Vcat (broadcast ((x,y) -> broadcast (op, x, y), A. args, c)... )
294
- broadcasted (:: LazyArrayStyle , op, c:: Number , A:: Vcat ) =
295
- Vcat (broadcast ((x,y) -> broadcast (op, x, y), c, A. args)... )
296
- broadcasted (:: LazyArrayStyle , op, A:: Vcat , c:: Ref ) =
297
- Vcat (broadcast ((x,y) -> broadcast (op, x, Ref (y)), A. args, c)... )
298
- broadcasted (:: LazyArrayStyle , op, c:: Ref , A:: Vcat ) =
299
- Vcat (broadcast ((x,y) -> broadcast (op, Ref (x), y), c, A. args)... )
292
+ for Cat in (:Vcat , :Hcat )
293
+ @eval begin
294
+ broadcasted (:: LazyArrayStyle , op, A:: $Cat , c:: Number ) =
295
+ $ Cat (broadcast ((x,y) -> broadcast (op, x, y), A. args, c)... )
296
+ broadcasted (:: LazyArrayStyle , op, c:: Number , A:: $Cat ) =
297
+ $ Cat (broadcast ((x,y) -> broadcast (op, x, y), c, A. args)... )
298
+ broadcasted (:: LazyArrayStyle , op, A:: $Cat , c:: Ref ) =
299
+ $ Cat (broadcast ((x,y) -> broadcast (op, x, Ref (y)), A. args, c)... )
300
+ broadcasted (:: LazyArrayStyle , op, c:: Ref , A:: $Cat ) =
301
+ $ Cat (broadcast ((x,y) -> broadcast (op, Ref (x), y), c, A. args)... )
302
+ end
303
+ end
300
304
301
305
302
306
# determine indices of components of a vcat
@@ -314,7 +318,7 @@ function broadcasted(::LazyArrayStyle, op, A::Vcat{<:Any,1}, B::AbstractVector)
314
318
B_arrays = _vcat_getindex_eval (B,kr... ) # evaluate B at same chunks as A
315
319
ApplyVector (vcat, broadcast ((a,b) -> broadcast (op,a,b), A. args, B_arrays)... )
316
320
end
317
-
321
+
318
322
function broadcasted (:: LazyArrayStyle , op, A:: AbstractVector , B:: Vcat{<:Any,1} )
319
323
kr = _vcat_axes (axes .(B. args)... )
320
324
A_arrays = _vcat_getindex_eval (A,kr... )
325
329
broadcasted (:: LazyArrayStyle , op, A:: Vcat{<:Any,1} , B:: Vcat{<:Any,1} ) =
326
330
Broadcasted {LazyArrayStyle} (op, (A, B))
327
331
332
+ # ambiguities
333
+ broadcasted (:: LazyArrayStyle , op, A:: Vcat{<:Any,1} , B:: CachedVector ) = cache_broadcast (op, A, B)
334
+ broadcasted (:: LazyArrayStyle , op, A:: CachedVector , B:: Vcat{<:Any,1} ) = cache_broadcast (op, A, B)
335
+
336
+ broadcasted (:: LazyArrayStyle{1} , :: typeof (* ), a:: Vcat{<:Any,1} , b:: Zeros{<:Any,1} )=
337
+ broadcast (DefaultArrayStyle {1} (), * , a, b)
338
+
339
+
328
340
329
341
function + (A:: Vcat , B:: Vcat )
330
342
size (A) == size (B) || throw (DimensionMismatch (" dimensions must match." ))
@@ -477,7 +489,7 @@ function materialize!(M::MatMulVecAdd{ApplyLayout{typeof(hcat)},ApplyLayout{type
477
489
# ###
478
490
479
491
480
- most (a) = reverse (tail (reverse (a)))
492
+ most (a) = reverse (tail (reverse (a)))
481
493
colsupport (M:: Vcat , j) = first (colsupport (first (M. args),j)): (size (Vcat (most (M. args)... ),1 )+ last (colsupport (last (M. args),j)))
482
494
483
495
556
568
557
569
sublayout (:: ApplyLayout{typeof(vcat)} , _) = ApplyLayout {typeof(vcat)} ()
558
570
sublayout (:: ApplyLayout{typeof(hcat)} , _) = ApplyLayout {typeof(hcat)} ()
571
+ # a row-slice of an Hcat is equivalent to a Vcat
572
+ sublayout (:: ApplyLayout{typeof(hcat)} , :: Type{<:Tuple{Number,AbstractVector}} ) = ApplyLayout {typeof(vcat)} ()
559
573
560
574
arguments (:: ApplyLayout{typeof(vcat)} , V:: SubArray{<:Any,2,<:Any,<:Tuple{<:Slice,<:Any}} ) =
561
575
view .(arguments (parent (V)), Ref (:), Ref (parentindices (V)[2 ]))
@@ -575,14 +589,17 @@ _view_vcat(a::Number, kr) = Fill(a,length(kr))
575
589
_view_vcat (a:: Number , kr, jr) = Fill (a,length (kr), length (jr))
576
590
_view_vcat (a, kr... ) = view (a, kr... )
577
591
578
- function arguments (:: ApplyLayout{typeof(vcat)} , V:: SubArray{<:Any,1} )
579
- A = parent (V)
592
+ function _vcat_sub_arguments (:: ApplyLayout{typeof(vcat)} , A, V)
580
593
kr = parentindices (V)[1 ]
581
594
sz = size .(arguments (A),1 )
582
595
skr = intersect .(_argsindices (sz), Ref (kr))
583
596
skr2 = broadcast ((a,b) -> a .- b .+ 1 , skr, _vcat_firstinds (sz))
584
597
_view_vcat .(arguments (A), skr2)
585
598
end
599
+ _vcat_sub_arguments (:: ApplyLayout{typeof(hcat)} , A, V) = arguments (ApplyLayout {typeof(hcat)} (), V)
600
+
601
+ _vcat_sub_arguments (A, V) = _vcat_sub_arguments (MemoryLayout (typeof (A)), A, V)
602
+ arguments (:: ApplyLayout{typeof(vcat)} , V:: SubArray{<:Any,1} ) = _vcat_sub_arguments (parent (V), V)
586
603
587
604
function arguments (:: ApplyLayout{typeof(vcat)} , V:: SubArray{<:Any,2} )
588
605
A = parent (V)
596
613
_view_hcat (a:: Number , kr, jr) = Fill (a,length (kr),length (jr))
597
614
_view_hcat (a, kr, jr) = view (a, kr, jr)
598
615
599
- function arguments (:: ApplyLayout{typeof(hcat)} , V:: SubArray{<:Any,2} )
616
+ function arguments (:: ApplyLayout{typeof(hcat)} , V:: SubArray )
600
617
A = parent (V)
601
618
kr,jr = parentindices (V)
602
619
sz = size .(arguments (A),2 )
@@ -628,6 +645,7 @@ function sub_materialize(::ApplyLayout{typeof(hcat)}, V)
628
645
end
629
646
ret
630
647
end
648
+
631
649
# temporarily allocate. In the future, we add a loop over arguments
632
650
materialize! (M:: MatMulMatAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}} ) =
633
651
materialize! (MulAdd (M. α,M. A,Array (M. B),M. β,M. C))
0 commit comments