36
36
37
37
# const ID = Threads.Atomic{UInt}(0)
38
38
39
+ # TODO : can some computations be cached in the operations?
39
40
"""
40
41
if ooperation_type == memstore || operation_type == memstore# || operation_type == compute_new || operation_type == compute_update
41
42
symbolic metadata contains info on direct dependencies / placement within loop.
@@ -77,6 +78,8 @@ struct Operation
77
78
end
78
79
end
79
80
81
+
82
+
80
83
"""
    isreduction(op::Operation)

Return `true` when `op` is a `memstore` node whose `symbolic_metadata` has
fewer entries than its `dependencies`; return `false` otherwise.
"""
function isreduction(op::Operation)
    # Only stores are candidates; bail out early for everything else.
    op.node_type == memstore || return false
    # A subset check (`issubset(op.symbolic_metadata, op.dependencies)`) was
    # considered here but is intentionally not applied.
    return length(op.symbolic_metadata) < length(op.dependencies)
end
@@ -92,6 +95,7 @@ identifier(op::Operation) = op.identifier
92
95
"""
    name(op::Operation)

Return the `variable` field of `op`.
"""
function name(op::Operation)
    return op.variable
end
93
96
"""
    instruction(op::Operation)

Return the `instruction` field of `op`.
"""
function instruction(op::Operation)
    return op.instruction
end
94
97
98
+
95
99
"""
    symposition(op::Operation, sym::Symbol)

Return the index of the first entry of `op.symbolic_metadata` that is
identical (`===`) to `sym`, or `nothing` when there is no such entry.
"""
function symposition(op::Operation, sym::Symbol)
    matches_sym = Base.Fix2(===, sym)  # x -> x === sym
    return findfirst(matches_sym, op.symbolic_metadata)
end
@@ -139,7 +143,7 @@ struct Loop
139
143
hintexact:: Bool # if true, rangesym ignored and rangehint used for final lowering
140
144
end
141
145
# Construct a `Loop` from an iteration symbol and an integer range hint only.
# `hintexact` is set to `true`, in which case `rangesym` is ignored and
# `rangehint` is used for final lowering, so a placeholder symbol is stored
# in the `rangesym` slot.
function Loop(itersymbol::Symbol, rangehint::Int)
    placeholder = Symbol("##UNDEFINED##")
    return Loop(itersymbol, rangehint, placeholder, true)
end
144
148
function Loop (itersymbol:: Symbol , rangesym:: Symbol , rangehint:: Int = 1_024 )
145
149
Loop ( itersymbol, rangehint, rangesym, false )
@@ -152,8 +156,9 @@ struct LoopSet
152
156
loadops:: Vector{Operation} # Split them to make it easier to iterate over just a subset
153
157
computeops:: Vector{Operation}
154
158
storeops:: Vector{Operation}
155
- reductions:: Set{UInt} # IDs of reduction operations that need to be reduced at end.
156
- strideset:: Vector{}
159
+ inner_reductions:: Set{UInt} # IDs of reduction operations nested within loops and stored.
160
+ outer_reductions:: Set{UInt} # IDs of reduction operations that need to be reduced at end.
161
+ # strideset::Vector{}
157
162
end
158
163
"""
    num_loops(ls::LoopSet)

Return the number of entries in `ls.loops`.
"""
function num_loops(ls::LoopSet)
    return length(ls.loops)
end
159
164
"""
    isstaticloop(ls::LoopSet, s::Symbol)

Look up the loop stored under `s` in `ls.loops` and return its `hintexact`
flag.
"""
function isstaticloop(ls::LoopSet, s::Symbol)
    loop = ls.loops[s]
    return loop.hintexact
end
511
516
# Using sentinel values (e.g., T = -1 for no tiling), in part to avoid recompilation.
512
517
function lower_load! (
513
518
q:: Expr , op:: Operation , W:: Int , unrolled:: Symbol ,
514
- U:: Int , T:: Int = - 1 , tiled:: Symbol = :undef
519
+ U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol ( " ##UNDEFINED## " )
515
520
)
516
521
loopdeps = loopdependencies (op)
517
522
var = op. variable
@@ -526,15 +531,15 @@ function lower_load!(
526
531
push! (q. args, Expr (:(= ), var, Expr (:call ,:vload ,ptr,memoff)))
527
532
elseif T == - 1
528
533
for u ∈ 0 : U- 1
529
- push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), Expr (:call ,:vload , Val (W ), ptr, u == 0 ? memoff : push! (copy (memoff), W* u))))
534
+ push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), Expr (:call ,:vload , Val {W} ( ), ptr, u == 0 ? memoff : push! (copy (memoff), W* u))))
530
535
end
531
536
else # tiling
532
537
for t ∈ 0 : T- 1
533
538
replace_ind_inoffset! (memoff, op, tind, t)
534
539
for u ∈ 0 : U- 1
535
540
memoff2 = copy (memoff)
536
541
u > 0 && push! (memoff2, W* u)
537
- push! (q. args, Expr (:(= ), Symbol (var, :_ , u, :_ , t), Expr (:call , :vload , Val (W ), ptr, memoff2)))
542
+ push! (q. args, Expr (:(= ), Symbol (var, :_ , u, :_ , t), Expr (:call , :vload , Val {W} ( ), ptr, memoff2)))
538
543
end
539
544
end
540
545
end
@@ -571,9 +576,9 @@ function lower_load!(
571
576
push! (q. args, Expr (:(= ), var, Expr (:call , :load , ptr, memoff)))
572
577
end
573
578
end
574
- function lower_store! (q :: Expr , op :: Operation , unrolled :: Symbol , U, T = 1 )
579
+ function lower_store! (
575
580
q:: Expr , op:: Operation , W:: Int , unrolled:: Symbol ,
576
- U:: Int , T:: Int = - 1 , tiled:: Symbol = :undef
581
+ U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol ( " ##UNDEFINED## " )
577
582
)
578
583
loopdeps = loopdependencies (op)
579
584
var = first (parents (op)). variable
@@ -588,15 +593,15 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
588
593
push! (q. args, Expr (:(= ), var, Expr (:call ,:vload ,ptr,memoff)))
589
594
elseif T == - 1
590
595
for u ∈ 0 : U- 1
591
- push! (q. args, Expr (:( = ), Symbol (var,:_ ,u), Expr ( :call , :vstore , Val (W), ptr, u == 0 ? memoff : push! (copy (memoff), W* u) )))
596
+ push! (q. args, Expr (:call , :vstore! , ptr, Symbol (var,:_ ,u), u == 0 ? memoff : push! (copy (memoff), W* u)))
592
597
end
593
598
else # tiling
594
599
for t ∈ 0 : T- 1
595
600
replace_ind_inoffset! (memoff, op, tind, t)
596
601
for u ∈ 0 : U- 1
597
602
memoff2 = copy (memoff)
598
603
u > 0 && push! (memoff2, W* u)
599
- push! (q. args, Expr (:( = ), Symbol (var, :_ , u, :_ , t), Expr ( :call , :vload , Val (W), ptr, memoff2) ))
604
+ push! (q. args, Expr (:call , :vstore! , ptr, Symbol (var, :_ , u, :_ , t), memoff2))
600
605
end
601
606
end
602
607
end
@@ -609,16 +614,16 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
609
614
for u ∈ 0 : U- 1
610
615
memoff2 = copy (memoff)
611
616
u > 0 && push! (memoff2, ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
612
- push! (q. args, Expr (:( = ), Symbol (var,:_ ,u,:_ ,t), Expr (:call , :gather , ptr, Expr ( :call , : vadd , memoff2, ustrides) )))
617
+ push! (q. args, Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u,:_ ,t), Expr (:call , :vadd , memoff2, ustrides)))
613
618
end
614
619
end
615
620
# elseif unitstride(op, tiled) # TODO : we load tiled, and then shuffle
616
621
elseif U == 1 # we gather, no tile, no extra unroll
617
- push! (q. args, Expr (:( = ), var, Expr ( : call ,:gather ,ptr,Expr (:call ,:vadd ,memoff,ustrides) )))
622
+ push! (q. args, Expr (:call ,:scatter! ,ptr, var, Expr (:call ,:vadd ,memoff,ustrides)))
618
623
else # we gather, no tile, but extra unroll
619
624
for u ∈ 0 : U- 1
620
625
memoff2 = u == 0 ? memoff : push! (copy (memoff), ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
621
- push! (q. args, Expr (:( = ), Symbol (var,:_ ,u), Expr (:call , :gather , ptr, Expr ( :call , : vadd ,memoff2,ustrides) )))
626
+ push! (q. args, Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u), Expr (:call ,: vadd ,memoff2,ustrides)))
622
627
end
623
628
end
624
629
end
@@ -627,13 +632,43 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
627
632
# memoff2 = copy(memoff)
628
633
for t ∈ 0 : T- 1
629
634
replace_ind_inoffset! (memoff, op, tind, t)
630
- push! (q. args, Expr (:( = ), Symbol (var,:_ ,t), Expr ( :call , :load , ptr, copy (memoff) )))
635
+ push! (q. args, Expr (:call , :store! , ptr, Symbol (var,:_ ,t), copy (memoff)))
631
636
end
632
637
else # load scalar; promotion should broadcast as/when neccesary
633
- push! (q. args, Expr (:( = ), var, Expr ( :call , :load , ptr, memoff) ))
638
+ push! (q. args, Expr (:call , :store! , var , ptr, memoff))
634
639
end
635
640
end
636
- function lower_compute! (q:: Expr , op:: Operation , unrolled:: Symbol , U, T = 1 )
641
# A compute op needs to know the unrolling and tiling status of each of its parents.
#
# Arguments:
#   q        - expression the lowered code is meant to be appended to
#   op       - the compute operation being lowered
#   W        - integer width parameter (same role as in the load/store lowering)
#   unrolled - symbol of the unrolled loop
#   U        - unroll factor
#   T        - tile factor; the sentinel -1 means "not tiling"
#   tiled    - symbol of the tiled loop; defaults to a placeholder symbol
#
# NOTE: this is currently a skeleton. It classifies `op` and its parents by
# whether they depend on the unrolled / tiled loops, but every branch body is
# still empty, so nothing is pushed onto `q` yet.
function lower_compute!(
    q::Expr, op::Operation, W::Int, unrolled::Symbol,
    U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
)
    loopdeps = loopdependencies(op)
    opunrolled = unrolled ∈ loopdeps
    optiled = tiled ∈ loopdeps
    var = op.variable
    # Cache the unroll and tiling check of each parent.
    # Not broadcasted, because we use frequent checks of individual bools,
    # making BitArrays inefficient.
    # The iteration variable must match the name used in the body; the
    # previous `for oppp ∈ ...` with `opp` in the body raised UndefVarError.
    parentsunrolled = [unrolled ∈ loopdependencies(opp) for opp ∈ parents(op)]
    parentstiled = [tiled ∈ loopdependencies(opp) for opp ∈ parents(op)]
    if opunrolled
        if optiled
            for t ∈ 0:T-1
                for u ∈ 0:U-1
                    # TODO: emit the compute expression for tile t, unroll u
                end
            end
        else # not tiled
            # TODO: emit the unrolled-only compute expressions
        end
    else # not unrolled
        if optiled # but not unrolled
            # TODO: emit the tiled-only compute expressions
        else # not tiled and not unrolled
            # TODO: emit the single compute expression
        end
    end
    return nothing
end
638
673
function lower! (q:: Expr , op:: Operation , unrolled:: Symbol , U, T = 1 )
639
674
if isload (op)
0 commit comments