@@ -56,40 +56,94 @@ function set_upstream_family!(adal::Vector{T}, op::Operation, val::T, ld::Vector
56
56
set_upstream_family! (adal, opp, val, ld, id)
57
57
end
58
58
end
59
-
59
+ function search_for_reductinit! (op:: Operation , opswap:: Operation , var:: Symbol , loopdeps:: Vector{Symbol} )
60
+ for (i,opp) ∈ enumerate (parents (op))
61
+ if (name (opp) === var) && (length (reduceddependencies (opp)) == 0 ) && (length (loopdependencies (opp)) == length (loopdeps)) && (length (children (opp)) == 1 )
62
+ if all (in (loopdeps), loopdependencies (opp))
63
+ parents (op)[i] = opswap
64
+ return opp
65
+ end
66
+ end
67
+ opcheck = search_for_reductinit! (opp, opswap, var, loopdeps)
68
+ opcheck === opp || return opcheck
69
+ end
70
+ return op
71
+ end
60
72
function addoptoorder! (
61
73
ls:: LoopSet , included_vars:: Vector{Bool} , place_after_loop:: Vector{Bool} , op:: Operation ,
62
74
loopsym:: Symbol , _n:: Int , u₁loop:: Symbol , u₂loop:: Symbol , vectorized:: Symbol , u₂max:: Int
63
75
)
64
- lo = ls. loop_order
65
- id = identifier (op)
66
- included_vars[id] || return nothing
67
- loopsym ∈ loopdependencies (op) || return nothing
68
- for opp ∈ parents (op) # ensure parents are added first
69
- addoptoorder! (ls, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, vectorized, u₂max)
70
- end
71
- included_vars[id] || return nothing
72
- included_vars[id] = false
73
- isunrolled = (isu₁unrolled (op)) + 1
74
- istiled = isu₂unrolled (op) + 1
75
- # optype = Int(op.node_type) + 1
76
- after_loop = place_after_loop[id] + 1
77
- if ! isloopvalue (op)
78
- isnopidentity (ls, op, u₁loop, u₂loop, vectorized, u₂max) || push! (lo[isunrolled,istiled,after_loop,_n], op)
79
- # if istiled
80
- # isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, u₂max) || push!(lo[isunrolled,2,after_loop,_n], op)
81
- # else
82
- # isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
83
- # end
76
+ lo = ls. loop_order
77
+ id = identifier (op)
78
+ included_vars[id] || return nothing
79
+ loopsym ∈ loopdependencies (op) || return nothing
80
+ for opp ∈ parents (op) # ensure parents are added first
81
+ addoptoorder! (ls, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, vectorized, u₂max)
82
+ end
83
+ included_vars[id] || return nothing
84
+ included_vars[id] = false
85
+ isunrolled = (isu₁unrolled (op)) + 1
86
+ istiled = isu₂unrolled (op) + 1
87
+ # optype = Int(op.node_type) + 1
88
+ after_loop = place_after_loop[id] + 1
89
+ if ! isloopvalue (op)
90
+ isnopidentity (ls, op, u₁loop, u₂loop, vectorized, u₂max) || push! (lo[isunrolled,istiled,after_loop,_n], op)
91
+ # if istiled
92
+ # isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, u₂max) || push!(lo[isunrolled,2,after_loop,_n], op)
93
+ # else
94
+ # isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
95
+ # end
96
+ end
97
+ # @show op, after_loop
98
+ # isloopvalue(op) || push!(lo[isunrolled,istiled,after_loop,_n], op)
99
+ # all(opp -> iszero(length(reduceddependencies(opp))), parents(op)) &&
100
+ set_upstream_family! (place_after_loop, op, false , loopdependencies (op), identifier (op)) # parents that have already been included are not moved, so no need to check included_vars to filter
101
+ nothing
102
+ end
103
+ function replace_reduct_init! (ls:: LoopSet , op:: Operation , opsub:: Operation , opcheck:: Operation )
104
+ deleteat! (parents (op), 2 )
105
+ op. variable = opcheck. variable
106
+ opsub. variable = opcheck. variable
107
+ op. mangledvariable = opcheck. mangledvariable
108
+ opsub. mangledvariable = opcheck. mangledvariable
109
+ op. instruction = instruction (:identity )
110
+ fill_children! (ls)
111
+ end
112
+ function nounrollreduction (op:: Operation , u₁loop:: Symbol , u₂loop:: Symbol , vectorized:: Symbol )
113
+ reduceddeps = reduceddependencies (op)
114
+ (vectorized ∉ reduceddeps) &&
115
+ (u₁loop ∉ reduceddeps) &&
116
+ (u₂loop ∉ reduceddeps)
117
+ end
118
+ function load_short_static_reduction_first! (ls:: LoopSet , u₁loop:: Symbol , u₂loop:: Symbol , vectorized:: Symbol )
119
+ for op ∈ operations (ls)
120
+ iscompute (op) || continue
121
+ length (reduceddependencies (op)) == 0 && continue
122
+ if (instruction (op). instr === :reduced_add )
123
+ vecloop = getloop (ls, vectorized)
124
+ if isstaticloop (vecloop) && (length (vecloop) ≤ 16 ) && nounrollreduction (op, u₁loop, u₂loop, vectorized)
125
+ opsub = parents (op)[2 ]
126
+ length (children (opsub)) == 1 || continue
127
+ opsearch = parents (op)[1 ]
128
+ opcheck = search_for_reductinit! (opsearch, opsub, name (opsearch), loopdependencies (op))
129
+ opcheck === opsearch || replace_reduct_init! (ls, op, opsub, opcheck)
130
+
131
+ end
132
+ elseif (instruction (op). instr === :add_fast ) && (instruction (first (parents (op))). instr === :identity )
133
+ vecloop = getloop (ls, vectorized)
134
+ if isstaticloop (vecloop) && (length (vecloop) ≤ 16 ) && nounrollreduction (op, u₁loop, u₂loop, vectorized)
135
+ opsub = parents (op)[2 ]
136
+ ((length (reduceddependencies (opsub)) == 0 ) & (length (children (opsub)) == 1 )) || continue
137
+ opsearch = parents (op)[1 ]
138
+ opcheck = search_for_reductinit! (opsearch, opsub, name (opsearch), loopdependencies (op))
139
+ opcheck === opsearch || replace_reduct_init! (ls, op, opsub, opcheck)
140
+ end
84
141
end
85
- # @show op, after_loop
86
- # isloopvalue(op) || push!(lo[isunrolled,istiled,after_loop,_n], op)
87
- # all(opp -> iszero(length(reduceddependencies(opp))), parents(op)) &&
88
- set_upstream_family! (place_after_loop, op, false , loopdependencies (op), identifier (op)) # parents that have already been included are not moved, so no need to check included_vars to filter
89
- nothing
142
+ end
90
143
end
91
144
92
145
function fillorder! (ls:: LoopSet , order:: Vector{Symbol} , u₁loop:: Symbol , u₂loop:: Symbol , u₂max:: Int , vectorized:: Symbol )
146
+ load_short_static_reduction_first! (ls, u₁loop, u₂loop, vectorized)
93
147
lo = ls. loop_order
94
148
resize! (lo, length (ls. loopsymbols))
95
149
ro = lo. loopnames # reverse order; will have same order as lo
0 commit comments