@@ -26,7 +26,7 @@ isdense(::Type{<:DenseArray}) = true
26
26
27
27
28
28
29
- @enum NodeType begin
29
+ @enum OperationType begin
30
30
memload
31
31
memstore
32
32
compute_new
37
37
# const ID = Threads.Atomic{UInt}(0)
38
38
39
39
"""
    Operation

One node of the operation graph.

For operations that access memory (the original note names `memstore`;
presumably `memload` as well), the metadata vectors describe the array access:

- `Symbol(:vptr_, op.variable)` is the name through which the memory is accessed.
- `symbolic_metadata[i]` is a loop index, carrying information on direct
  dependencies / placement within the loop.
- `numerical_metadata[i]` is the stride for loop index `symbolic_metadata[i]`.
  A value of `-1` marks a dynamic stride, in which case
  `Symbol(:stride_, op.variable, :_, op.symbolic_metadata[i])` names the stride.

The inner constructor takes `identifier`, `elementbytes`, `instruction`,
`node_type` and an optional `variable` (a fresh `gensym()` by default); all
collection fields start out empty.
"""
struct Operation
    identifier::UInt
    variable::Symbol
    elementbytes::Int
    instruction::Symbol
    node_type::OperationType
    dependencies::Set{Symbol}
    parents::Vector{Operation}
    children::Vector{Operation}
    numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
    symbolic_metadata::Vector{Symbol}

    function Operation(
        identifier,
        elementbytes,
        instruction,
        node_type,
        variable = gensym()
    )
        # Every collection field starts empty; it is filled in as the graph is built.
        dependencies = Set{Symbol}()
        parents = Operation[]
        children = Operation[]
        numerical_metadata = Int[]
        symbolic_metadata = Symbol[]
        return new(
            identifier, variable, elementbytes, instruction, node_type,
            dependencies, parents, children, numerical_metadata, symbolic_metadata
        )
    end
end
@@ -85,16 +92,45 @@ identifier(op::Operation) = op.identifier
85
92
"""
    name(op::Operation)

Return the symbol stored in `op.variable`.
"""
function name(op::Operation)
    return op.variable
end
86
93
"""
    instruction(op::Operation)

Return the symbol stored in `op.instruction`.
"""
function instruction(op::Operation)
    return op.instruction
end
87
94
95
"""
    symposition(op::Operation, sym::Symbol)

Return the index of the first entry of `op.symbolic_metadata` that is `sym`,
or `nothing` when `sym` does not occur there.
"""
function symposition(op::Operation, sym::Symbol)
    # For `Symbol`s, `isequal` and `===` agree.
    return findfirst(isequal(sym), op.symbolic_metadata)
end
88
98
"""
    stride(op::Operation, sym::Symbol)

Return the entry of `op.numerical_metadata` stored at the position that `sym`
occupies in `op.symbolic_metadata`.

# Throws
- `AssertionError` if `op` does not access memory.
- `ArgumentError` if `sym` does not occur in `op.symbolic_metadata`.
"""
function stride(op::Operation, sym::Symbol)
    @assert accesses_memory(op) " This operation does not access memory!"
    pos = symposition(op, sym)
    # `symposition` yields `nothing` for an unknown symbol; indexing with it would
    # fail with an opaque "invalid index" error, so report the actual problem.
    if pos === nothing
        throw(ArgumentError(
            "symbol $sym not found in the symbolic metadata of $(op.variable)"
        ))
    end
    return op.numerical_metadata[pos]
end
93
103
# function
94
104
"""
    unitstride(op::Operation, sym::Symbol)

Return `true` when `sym` is the first entry of `op.symbolic_metadata` and the
first entry of `op.numerical_metadata` equals `1`; otherwise return `false`.
"""
function unitstride(op::Operation, sym::Symbol)
    leading = first(op.symbolic_metadata)
    leading === sym || return false
    return first(op.numerical_metadata) == 1
end
97
-
107
"""
    mem_offset(op::Operation, incr::Int = 0)

Build the expression for the memory offset of `op`.

Each index `i` contributes one term, derived from `op.numerical_metadata[i]`
(the stride) and `op.symbolic_metadata[i]` (the loop symbol `sym`):

- stride `1`: the bare `sym`;
- stride `-1` (dynamic): `stride_<variable>_<sym> * sym`;
- any other stride `n`: `n * sym`.

With exactly one index and `incr == 0` that single term is returned directly.
Otherwise the terms, followed by a non-zero `incr`, are combined in one `+` call.
An operation without indices and `incr == 0` yields `+(0)`.

# Throws
- `AssertionError` if `op` does not access memory.
"""
function mem_offset(op::Operation, incr::Int = 0)::Union{Symbol,Expr}
    @assert accesses_memory(op) " Computing memory offset only makes sense for operations that access memory."
    strides = op.numerical_metadata
    loopsyms = op.symbolic_metadata

    # Offset contribution of index `i`.
    function term(i)
        sym = loopsyms[i]
        num = strides[i]
        num == 1 && return sym
        # The dynamic-stride symbol is built from the *current* loop symbol.
        scale = num == -1 ? Symbol(:stride_, op.variable, :_, sym) : num
        return Expr(:call, :*, scale, sym)
    end

    if incr == 0 && length(strides) == 1
        return term(firstindex(strides))
    end

    ret = Expr(:call, :+)
    for i ∈ eachindex(strides)
        push!(ret.args, term(i))
    end
    incr == 0 || push!(ret.args, incr)
    # Only the `:+` head is present: no indices and no increment. A zero-argument
    # `+()` call is not valid, so make the zero offset explicit.
    length(ret.args) == 1 && push!(ret.args, 0)
    return ret
end
98
134
99
135
struct Loop
100
136
itersymbol:: Symbol
@@ -457,25 +493,147 @@ function depends_on_assigned(op::Operation, assigned::Vector{Bool})
457
493
end
458
494
false
459
495
end
460
- function lower_load! (q:: Expr , op:: Operation , unrolled:: Symbol , U, Umax, T = nothing , Tmax = nothing )
496
"""
    replace_ind_in_offset!(offset::Vector, op::Operation, ind::Int, dynamic::Bool, t)

Overwrite `offset[ind]` with the term for tile step `t` of index `ind` of `op`.
Does nothing when `t == 0`. Always returns `nothing`.

The new term is `(stride + t) * loopsym`, where `loopsym` is
`op.symbolic_metadata[ind]` and `stride` is `op.numerical_metadata[ind]`, or the
symbol `stride_<variable>_<loopsym>` when the stored stride is `-1` (dynamic).

`dynamic` is accepted but not used.
"""
function replace_ind_in_offset!(offset::Vector, op::Operation, ind::Int, dynamic::Bool, t)
    t == 0 && return nothing
    var = op.variable
    siter = op.symbolic_metadata[ind]
    striden = op.numerical_metadata[ind]
    # NOTE(review): `t` is added to the stride rather than to the loop index;
    # `stride * (siter + t)` looks like the intended shift. Confirm before relying on it.
    offset[ind] = if striden == -1
        # Same stride-symbol naming as the rest of the file: stride_<variable>_<loop>.
        Expr(:call, :*, Expr(:call, :+, Symbol(:stride_, var, :_, siter), t), siter)
    else
        Expr(:call, :*, striden + t, siter)
    end
    return nothing
end
509
+
510
+ # TODO : this code should be rewritten to be more "orthogonal", so that we're just combining separate pieces.
511
+ # Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
512
"""
    lower_load!(q, op, W, unrolled, U, T = -1, tiled = :undef)

Append to `q.args` the assignments that load the variable of `op`.

# Arguments
- `q::Expr`: expression whose `args` receive the generated statements.
- `op::Operation`: the load being lowered; memory is addressed through
  `Symbol(:vptr_, op.variable)` at the offset given by `mem_offset(op)`.
- `W::Int`: vector width.
- `unrolled::Symbol`: symbol of the unrolled (vectorized) loop.
- `U::Int`: unroll factor along `unrolled`.
- `T::Int`: tile factor along `tiled`; the sentinel `-1` means "no tiling".
- `tiled::Symbol`: symbol of the tiled loop (ignored when `T == -1`).

# Generated code
- `op` depends on `unrolled`, stride `1`: `vload(Val(W), ptr, offset)`.
- `op` depends on `unrolled`, other stride: `gather(ptr, vadd(offset, lanes))`,
  where `lanes` is a tuple of `W` per-lane offsets.
- `op` depends only on `tiled`: one `load(ptr, offset)` per tile.
- otherwise: a single scalar `load(ptr, offset)`.

Results are named `var` (no unrolling or tiling), `var_u` (unrolled),
`var_u_t` (unrolled and tiled) or `var_t` (tiled only), with `u ∈ 0:U-1` and
`t ∈ 0:T-1`.

Returns `nothing`. Throws `ArgumentError` when a stride is needed for a loop
symbol that does not occur in `op.symbolic_metadata`.
"""
function lower_load!(
    q::Expr, op::Operation, W::Int, unrolled::Symbol,
    U::Int, T::Int = -1, tiled::Symbol = :undef
)
    loopdeps = loopdependencies(op)
    var = op.variable
    ptr = Symbol(:vptr_, var)
    memoff = mem_offset(op)
    istiled = T != -1

    # `base + incr` as an expression; a literal zero increment leaves `base` as is.
    shifted(base, incr) = incr === 0 ? base : Expr(:call, :+, base, incr)

    # Stride recorded for loop `sym` (-1 marks a dynamic stride).
    function stride_of(sym::Symbol)
        pos = symposition(op, sym)
        if pos === nothing
            throw(ArgumentError(
                "loop symbol $sym not found in the symbolic metadata of $var"
            ))
        end
        return op.numerical_metadata[pos]
    end

    # Offset covered by `n` iterations of loop `sym`: an `Int` for a static stride,
    # an expression in the stride symbol for a dynamic one.
    function steps(sym::Symbol, n::Int)
        n == 0 && return 0
        s = stride_of(sym)
        return s == -1 ? Expr(:call, :*, Symbol(:stride_, var, :_, sym), n) : s * n
    end

    # Append `name = f(args...)` to `q`.
    emit!(name::Symbol, f::Symbol, args...) =
        push!(q.args, Expr(:(=), name, Expr(:call, f, args...)))

    if unrolled ∈ loopdeps # we need a vector
        # The stride is looked up only here: an operation that does not depend on
        # the unrolled loop has no such stride.
        ustride = stride_of(unrolled)
        if ustride == 1 # contiguous: vload
            if !istiled && U == 1
                # NOTE(review): `Val(W)` added to match the other vload branches.
                emit!(var, :vload, Val(W), ptr, memoff)
            elseif !istiled
                for u ∈ 0:U-1
                    emit!(Symbol(var, :_, u), :vload, Val(W), ptr, shifted(memoff, W * u))
                end
            else # tiling
                for t ∈ 0:T-1
                    # NOTE(review): tile `t` is taken to address loop index
                    # `tiled + t`, i.e. the base offset plus `t` strides of `tiled`.
                    tileoff = shifted(memoff, steps(tiled, t))
                    for u ∈ 0:U-1
                        emit!(
                            Symbol(var, :_, u, :_, t), :vload,
                            Val(W), ptr, shifted(tileoff, W * u)
                        )
                    end
                end
            end
        else # strided: gather
            # Per-lane offsets 0, s, 2s, …, (W-1)s for stride s.
            lanes = if ustride == -1
                stridesym = Symbol(:stride_, var, :_, unrolled)
                Expr(:tuple, (:(Core.VecElement{Int}($stridesym * $w)) for w ∈ 0:W-1)...)
            else
                Expr(:tuple, (Core.VecElement{Int}(ustride * w) for w ∈ 0:W-1)...)
            end
            if istiled # gather tile
                for t ∈ 0:T-1
                    tileoff = shifted(memoff, steps(tiled, t))
                    for u ∈ 0:U-1
                        off = shifted(tileoff, steps(unrolled, u * W))
                        emit!(
                            Symbol(var, :_, u, :_, t), :gather,
                            ptr, Expr(:call, :vadd, off, lanes)
                        )
                    end
                end
            # TODO: when `tiled` has unit stride, load along `tiled` and then shuffle.
            elseif U == 1 # gather, no tile, no extra unroll
                emit!(var, :gather, ptr, Expr(:call, :vadd, memoff, lanes))
            else # gather, no tile, but extra unroll
                for u ∈ 0:U-1
                    off = shifted(memoff, steps(unrolled, u * W))
                    emit!(Symbol(var, :_, u), :gather, ptr, Expr(:call, :vadd, off, lanes))
                end
            end
        end
    elseif istiled && tiled ∈ loopdeps # one scalar load per tile
        for t ∈ 0:T-1
            emit!(Symbol(var, :_, t), :load, ptr, shifted(memoff, steps(tiled, t)))
        end
    else # load scalar; promotion should broadcast as/when necessary
        emit!(var, :load, ptr, memoff)
    end
    return nothing
end
472
574
"""
    lower_store!(q, op, W, unrolled, U, T = -1, tiled = :undef)

Append to `q.args` the calls that store the value produced by the first parent
of `op` into the memory of `op`.

# Arguments
- `q::Expr`: expression whose `args` receive the generated statements.
- `op::Operation`: the store being lowered; memory is addressed through
  `Symbol(:vptr_, op.variable)` at the offset given by `mem_offset(op)`.
- `W::Int`: vector width.
- `unrolled::Symbol`: symbol of the unrolled (vectorized) loop.
- `U::Int`: unroll factor along `unrolled`.
- `T::Int`: tile factor along `tiled`; the sentinel `-1` means "no tiling".
- `tiled::Symbol`: symbol of the tiled loop (ignored when `T == -1`).

# Generated code
- `op` depends on `unrolled`, stride `1`: `vstore!(ptr, value, offset)`.
- `op` depends on `unrolled`, other stride:
  `scatter!(ptr, value, vadd(offset, lanes))`, where `lanes` is a tuple of `W`
  per-lane offsets.
- `op` depends only on `tiled`: one `store!(ptr, value, offset)` per tile.
- otherwise: a single scalar `store!(ptr, value, offset)`.

The stored values are named like the results of `lower_load!`: `var`, `var_u`,
`var_u_t` or `var_t`, where `var` is the variable of the first parent of `op`.

NOTE(review): the emitted names `vstore!` / `scatter!` / `store!` and their
argument order `(ptr, value, offset)` are assumptions; confirm them against the
SIMD backend that the generated code is evaluated with.

Returns `nothing`. Throws `ArgumentError` when a stride is needed for a loop
symbol that does not occur in `op.symbolic_metadata`.
"""
function lower_store!(
    q::Expr, op::Operation, W::Int, unrolled::Symbol,
    U::Int, T::Int = -1, tiled::Symbol = :undef
)
    loopdeps = loopdependencies(op)
    var = first(parents(op)).variable # the value being stored
    ptr = Symbol(:vptr_, op.variable)
    memoff = mem_offset(op)
    istiled = T != -1

    # `base + incr` as an expression; a literal zero increment leaves `base` as is.
    shifted(base, incr) = incr === 0 ? base : Expr(:call, :+, base, incr)

    # Stride recorded for loop `sym` (-1 marks a dynamic stride).
    function stride_of(sym::Symbol)
        pos = symposition(op, sym)
        if pos === nothing
            throw(ArgumentError(
                "loop symbol $sym not found in the symbolic metadata of $(op.variable)"
            ))
        end
        return op.numerical_metadata[pos]
    end

    # Offset covered by `n` iterations of loop `sym`: an `Int` for a static stride,
    # an expression in the stride symbol for a dynamic one.
    function steps(sym::Symbol, n::Int)
        n == 0 && return 0
        s = stride_of(sym)
        return s == -1 ? Expr(:call, :*, Symbol(:stride_, op.variable, :_, sym), n) : s * n
    end

    # Append the call `f(ptr, value, offset)` to `q`. A store is a statement, not
    # an assignment.
    emit!(f::Symbol, value::Symbol, offset) =
        push!(q.args, Expr(:call, f, ptr, value, offset))

    if unrolled ∈ loopdeps # we store a vector
        # The stride is looked up only here: an operation that does not depend on
        # the unrolled loop has no such stride.
        ustride = stride_of(unrolled)
        if ustride == 1 # contiguous: vstore!
            if !istiled && U == 1
                emit!(:vstore!, var, memoff)
            elseif !istiled
                for u ∈ 0:U-1
                    emit!(:vstore!, Symbol(var, :_, u), shifted(memoff, W * u))
                end
            else # tiling
                for t ∈ 0:T-1
                    # NOTE(review): tile `t` is taken to address loop index
                    # `tiled + t`, i.e. the base offset plus `t` strides of `tiled`.
                    tileoff = shifted(memoff, steps(tiled, t))
                    for u ∈ 0:U-1
                        emit!(:vstore!, Symbol(var, :_, u, :_, t), shifted(tileoff, W * u))
                    end
                end
            end
        else # strided: scatter!
            # Per-lane offsets 0, s, 2s, …, (W-1)s for stride s.
            lanes = if ustride == -1
                stridesym = Symbol(:stride_, op.variable, :_, unrolled)
                Expr(:tuple, (:(Core.VecElement{Int}($stridesym * $w)) for w ∈ 0:W-1)...)
            else
                Expr(:tuple, (Core.VecElement{Int}(ustride * w) for w ∈ 0:W-1)...)
            end
            if istiled # scatter tile
                for t ∈ 0:T-1
                    tileoff = shifted(memoff, steps(tiled, t))
                    for u ∈ 0:U-1
                        off = shifted(tileoff, steps(unrolled, u * W))
                        emit!(
                            :scatter!, Symbol(var, :_, u, :_, t),
                            Expr(:call, :vadd, off, lanes)
                        )
                    end
                end
            elseif U == 1 # scatter, no tile, no extra unroll
                emit!(:scatter!, var, Expr(:call, :vadd, memoff, lanes))
            else # scatter, no tile, but extra unroll
                for u ∈ 0:U-1
                    off = shifted(memoff, steps(unrolled, u * W))
                    emit!(:scatter!, Symbol(var, :_, u), Expr(:call, :vadd, off, lanes))
                end
            end
        end
    elseif istiled && tiled ∈ loopdeps # one scalar store per tile
        for t ∈ 0:T-1
            emit!(:store!, Symbol(var, :_, t), shifted(memoff, steps(tiled, t)))
        end
    else # scalar store
        emit!(:store!, var, memoff)
    end
    return nothing
end
475
636
"""
    lower_compute!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)

Placeholder for lowering a compute operation: the body is empty, so nothing is
added to `q` and `nothing` is returned.
"""
function lower_compute!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
    return nothing
end
480
638
function lower! (q:: Expr , op:: Operation , unrolled:: Symbol , U, T = 1 )
481
639
if isload (op)
0 commit comments