@@ -525,15 +525,15 @@ function add_broadcast_loops!(ls::LoopSet, loopsyms::Vector{Symbol}, destsym::Sy
525
525
end
526
526
end
527
527
528
- # size of dest determines loops
529
- # function vmaterialize!(
530
- @generated function vmaterialize! (
531
- dest :: AbstractArray{T,N} ,
532
- bc :: BC ,
533
- :: Val{Mod} ,
534
- :: Val{UNROLL} ,
535
- :: Val{dontbc} ,
536
- ) where {T <: NativeTypes ,N,BC <: Union{Broadcasted,Product} ,Mod,UNROLL,dontbc }
528
+ function vmaterialize_fun (
529
+ sizeofT :: Int ,
530
+ N,
531
+ @nospecialize (_ :: Type{BC} ) ,
532
+ Mod ,
533
+ UNROLL ,
534
+ dontbc ,
535
+ transpose :: Bool ,
536
+ ) where {BC }
537
537
# 2 + 1
538
538
# we have an N dimensional loop.
539
539
# need to construct the LoopSet
@@ -542,17 +542,20 @@ end
542
542
set_hw! (ls, rs, rc, cls)
543
543
ls. isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
544
544
loopsyms = [gensym! (ls, " n" ) for _ ∈ 1 : N]
545
- add_broadcast_loops! (ls, loopsyms, :dest )
546
- elementbytes = sizeof (T)
545
+ transpose && pushprepreamble! (ls, Expr (:(= ), :dest , Expr (:call , :parent , :dest′ )))
546
+ ret = transpose ? :dest′ : :dest
547
+ add_broadcast_loops! (ls, loopsyms, ret)
548
+ elementbytes = sizeofT
547
549
add_broadcast! (ls, :destination , :bc , loopsyms, BC, dontbc, elementbytes)
550
+ transpose && reverse! (loopsyms)
548
551
storeop =
549
552
add_simple_store! (ls, :destination , ArrayReference (:dest , loopsyms), elementbytes)
550
553
doaddref! (ls, storeop)
551
554
resize! (ls. loop_order, num_loops (ls)) # num_loops may be greater than N, eg Product
552
555
# return ls
553
556
sc = setup_call (
554
557
ls,
555
- :(Base. Broadcast. materialize! (dest , bc)),
558
+ :(Base. Broadcast. materialize! ($ ret , bc)),
556
559
LineNumberNode (0 ),
557
560
inline,
558
561
false ,
563
566
warncheckarg,
564
567
safe,
565
568
)
566
- Expr (:block , Expr (:meta , :inline ), sc, :dest )
569
+ Expr (:block , Expr (:meta , :inline ), sc, ret)
570
+ end
571
+
572
+ # size of dest determines loops
573
+ # function vmaterialize!(
574
+ @generated function vmaterialize! (
575
+ dest:: AbstractArray{T,N} ,
576
+ bc:: BC ,
577
+ :: Val{Mod} ,
578
+ :: Val{UNROLL} ,
579
+ :: Val{dontbc} ,
580
+ ) where {T<: NativeTypes ,N,BC<: Union{Broadcasted,Product} ,Mod,UNROLL,dontbc}
581
+ vmaterialize_fun (sizeof (T), N, BC, Mod, UNROLL, dontbc, false )
567
582
end
568
583
@generated function vmaterialize! (
569
584
dest′:: Union{Adjoint{T,A},Transpose{T,A}} ,
580
595
UNROLL,
581
596
dontbc,
582
597
}
583
- # we have an N dimensional loop.
584
- # need to construct the LoopSet
585
- ls = LoopSet (Mod)
586
- inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
587
- set_hw! (ls, rs, rc, cls)
588
- ls. isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
589
- loopsyms = [gensym! (ls, " n" ) for _ ∈ 1 : N]
590
- pushprepreamble! (ls, Expr (:(= ), :dest , Expr (:call , :parent , :dest′ )))
591
- add_broadcast_loops! (ls, loopsyms, :dest′ )
592
- elementbytes = sizeof (T)
593
- add_broadcast! (ls, :destination , :bc , loopsyms, BC, dontbc, elementbytes)
594
- storeop = add_simple_store! (
595
- ls,
596
- :destination ,
597
- ArrayReference (:dest , reverse (loopsyms)),
598
- elementbytes,
599
- )
600
- doaddref! (ls, storeop)
601
- resize! (ls. loop_order, num_loops (ls)) # num_loops may be greater than N, eg Product
602
- Expr (
603
- :block ,
604
- Expr (:meta , :inline ),
605
- setup_call (
606
- ls,
607
- :(Base. Broadcast. materialize! (dest′, bc)),
608
- LineNumberNode (0 ),
609
- inline,
610
- false ,
611
- u₁,
612
- u₂,
613
- v,
614
- threads % Int,
615
- warncheckarg,
616
- safe,
617
- ),
618
- :dest′ ,
619
- )
598
+ vmaterialize_fun (sizeof (T), N, BC, Mod, UNROLL, dontbc, true )
620
599
end
621
600
# these are marked `@inline` so the `@turbo` itself can choose whether or not to inline.
622
601
@generated function vmaterialize! (
0 commit comments