48
48
"""
    lower_zero!(q, op, ls, ua, zerotyp = zerotype(ls, op))

Push onto `q.args` the assignment(s) that bind the lowered variable name(s) of
`op` to a zero-valued initializer expression.

# Arguments
- `q::Expr`: expression whose `args` receive the generated assignments.
- `op::Operation`: operation being lowered.
- `ls::LoopSet`: loop set used to look up the element type and register size.
- `ua::UnrollArgs`: unrolling parameters; `u₁`, `u₁loopsym`, `u₂loopsym`,
  `vloopsym`, `u₂max` and `suffix` are read from it.
- `zerotyp::NumberType`: forwarded to `typeof_sym`; defaults to `zerotype(ls, op)`.

# Returns
`nothing`. Nothing is emitted when `op` is not `u₂`-unrolled and `suffix > 0`.
"""
function lower_zero!(
    q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, zerotyp::NumberType = zerotype(ls, op)
)
    @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
    basename, opu₁, opu₂ = variable_name_and_unrolled(
        op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls
    )
    if suffix > 0 && !opu₂
        return nothing
    end
    # TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled
    #       reductions need to consider reduct-status
    # if !opu₁
    #     opu₁ = u₁loopsym ∈ reducedchildren(op)
    # end
    unrolled₁ = opu₁ && u₁ > 1
    # Numeric tag appended to every emitted variable name.
    u₁tag = opu₁ ? u₁ : 1
    elty = typeof_sym(ls, op, zerotyp)
    # TODO: make should_broadcast_op handle everything.
    needsvector = isvectorized(op) ||
        vloopsym ∈ reducedchildren(op) ||
        vloopsym ∈ reduceddependencies(op) ||
        should_broadcast_op(op)
    zeroexpr = if needsvector
        regsize = staticexpr(reg_size(ls))
        if unrolled₁
            Expr(:call, lv(:zero_vecunroll), staticexpr(u₁), VECTORWIDTHSYMBOL, elty, regsize)
        else
            Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, elty, regsize)
        end
    else
        scalarzero = Expr(:call, :zero, elty)
        if unrolled₁
            # The same `zero(T)` expression object is repeated `u₁` times in the tuple.
            tup = Expr(:tuple)
            append!(tup.args, fill(scalarzero, u₁))
            Expr(:call, lv(:VecUnroll), tup)
        else
            scalarzero
        end
    end
    if opu₂ && suffix == -1
        # One assignment per index in `0:u₂max-1`, named `<name><u>__<tag>`.
        for u ∈ 0:u₂max-1
            push!(q.args, Expr(:(=), Symbol(basename, u, "__", u₁tag), zeroexpr))
        end
    else
        push!(q.args, Expr(:(=), Symbol(basename, '_', u₁tag), zeroexpr))
    end
    return nothing
end
83
87
# Have to awkwardly search through `operations(ls)` to try and find op's child
84
88
function getparentsreductzero (ls:: LoopSet , op:: Operation ):: Float64
@@ -95,52 +99,65 @@ vecbasefunc(f) = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:Vectorizat
95
99
"""
    lower_constant!(q, op, ls, ua)

Push onto `q.args` the assignment(s) that bind the lowered variable name(s) of
`op` to an expression built from the constant symbol `op.instruction.instr`.

The right-hand side is, depending on `op` and the unroll arguments:
- a broadcast of the constant (`vbroadcast`), or — when the vectorized loop is a
  reduced child of `op` but not one of its loop dependencies — the constant
  combined with a vector of the parent reduction's initial value
  (`addscalar`/`mulscalar`/`maxscalar`/`minscalar`);
- a `VecUnroll` of `u₁` copies when `op` is `u₁`-unrolled with `u₁ > 1`;
- the bare constant symbol otherwise.

# Arguments
- `q::Expr`: expression whose `args` receive the generated assignments.
- `op::Operation`: constant operation being lowered.
- `ls::LoopSet`: loop set, passed to the name lookup and to
  `getparentsreductzero`.
- `ua::UnrollArgs`: unrolling parameters; `u₁`, `u₁loopsym`, `u₂loopsym`,
  `vloopsym`, `u₂max` and `suffix` are read from it.

# Returns
`nothing`. Nothing is emitted when `op` is not `u₂`-unrolled and `suffix > 0`.

# Throws
A `String` message when the parent reduction class is none of the four handled
classes.
"""
function lower_constant!(
    q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs
)
    @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
    mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
    !opu₂ && suffix > 0 && return nothing
    instruction = op.instruction
    constsym = instruction.instr
    reducedchildvectorized = vloopsym ∈ reducedchildren(op)
    u₁unrolled = opu₁ && u₁ > 1
    # Numeric tag appended to every emitted variable name.
    u₁tag = Core.ifelse(opu₁, u₁, 1)
    if reducedchildvectorized || isvectorized(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
        call = if reducedchildvectorized && vloopsym ∉ loopdependencies(op)
            instrclass = getparentsreductzero(ls, op)
            if instrclass == ADDITIVE_IN_REDUCTIONS
                Expr(:call, vecbasefunc(:addscalar), Expr(:call, lv(:vzero), VECTORWIDTHSYMBOL, ELTYPESYMBOL), constsym)
            elseif instrclass == MULTIPLICATIVE_IN_REDUCTIONS
                Expr(:call, vecbasefunc(:mulscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :one, ELTYPESYMBOL)), constsym)
            elseif instrclass == MAX
                Expr(:call, vecbasefunc(:maxscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemin, ELTYPESYMBOL)), constsym)
            elseif instrclass == MIN
                Expr(:call, vecbasefunc(:minscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemax, ELTYPESYMBOL)), constsym)
            else
                # Fixed: the message previously interpolated the undefined name
                # `reinstrclass`, so this path raised `UndefVarError` instead.
                throw("Reductions of type $(reduction_zero(instrclass)) not yet supported; please file an issue as a reminder to take care of this.")
            end
        else
            Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, constsym)
        end
        if u₁unrolled
            # The same expression object is repeated `u₁` times in the tuple.
            t = Expr(:tuple)
            for u ∈ 1:u₁
                push!(t.args, call)
            end
            call = Expr(:call, lv(:VecUnroll), t)
        end
    elseif u₁unrolled
        t = Expr(:tuple)
        for u ∈ 1:u₁
            push!(t.args, constsym)
        end
        call = Expr(:call, lv(:VecUnroll), t)
    else
        # Plain scalar constant: assign the symbol itself. The tag is pinned to 1
        # here, matching the names this branch has always emitted.
        call = constsym
        u₁tag = 1
    end
    if opu₂ && suffix == -1
        # One assignment per index in `0:u₂max-1`, named `<name><u>__<tag>`.
        for u ∈ 0:u₂max-1
            push!(q.args, Expr(:(=), Symbol(mvar, u, "__", u₁tag), call))
        end
    else
        push!(q.args, Expr(:(=), Symbol(mvar, '_', u₁tag), call))
    end
    return nothing
end
145
162
146
163
"""
    isconstantop(op::Operation)

Return `true` when `op`'s instruction is `LOOPCONSTANT`, or when `op` is
constant and has no loop dependencies; `false` otherwise.
"""
function isconstantop(op::Operation)
    instruction(op) === LOOPCONSTANT && return true
    return isconstant(op) && isempty(loopdependencies(op))
end
0 commit comments