@@ -9,15 +9,16 @@ struct UnrollArgs{T <: Union{Nothing,Int}}
9
9
u₁loopsym:: Symbol
10
10
u₂loopsym:: Symbol
11
11
vectorized:: Symbol
12
+ u₂max:: Int
12
13
suffix:: T
13
14
end
14
"""
    UnrollArgs(u₁::Int, unrollsyms::UnrollSymbols, u₂max::Int, suffix)

Construct an `UnrollArgs` by unpacking `u₁loopsym`, `u₂loopsym` and `vectorized`
from `unrollsyms` and forwarding them, together with `u₁`, `u₂max` and `suffix`,
to the positional constructor.
"""
function UnrollArgs(u₁::Int, unrollsyms::UnrollSymbols, u₂max::Int, suffix)
    @unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
    UnrollArgs(u₁, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix)
end
18
19
"""
    UnrollArgs(ua::UnrollArgs, u::Int)

Build a new `UnrollArgs` that reuses `u₁loopsym`, `u₂loopsym`, `vectorized`,
`u₂max` and `suffix` from `ua`, passing `u` as the first constructor argument.
"""
function UnrollArgs(ua::UnrollArgs, u::Int)
    @unpack u₁loopsym, u₂loopsym, vectorized, u₂max, suffix = ua
    UnrollArgs(u, u₁loopsym, u₂loopsym, vectorized, u₂max, suffix)
end
22
23
# UnrollSymbols(ua::UnrollArgs) = UnrollSymbols(ua.u₁loopsym, ua.u₂loopsym, ua.vectorized)
23
24
@@ -69,14 +70,8 @@ function Loop(itersymbol::Symbol, start::Union{Int,Symbol}, stop::Union{Int,Symb
69
70
end
70
71
# Length of a loop computed from its hints as `1 + stophint - starthint`.
Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
# `true` only when both `startexact` and `stopexact` are set on the loop.
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
72
- function startloop (loop:: Loop , isvectorized, itersymbol)
73
+ function startloop (loop:: Loop , itersymbol)
73
74
startexact = loop. startexact
74
- # if isvectorized
75
- # if startexact
76
- # Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.starthint))
77
- # else
78
- # Expr(:(=), itersymbol, Expr(:call, lv(:_MM), VECTORWIDTHSYMBOL, loop.startsym))
79
- # end
80
75
if startexact
81
76
Expr (:(= ), itersymbol, loop. starthint - 1 )
82
77
else
@@ -97,31 +92,44 @@ addexpr(ex::Number, incr::Number) = ex + incr
97
92
# Generic case: emit a call expression to `vsub` (resolved through `lv`) computing `ex - incr`.
subexpr(ex, incr) = Expr(:call, lv(:vsub), ex, incr)
# Both operands are plain numbers: subtract immediately instead of emitting an expression.
subexpr(ex::Number, incr::Number) = ex - incr
# Numeric decrement of a non-numeric operand: delegate to `addexpr` with the negated increment.
subexpr(ex, incr::Number) = addexpr(ex, -incr)
100
# Expression calling `staticmul` (resolved through `lv`) on `eltype(ptr)` and `incr`.
# NOTE(review): presumably scales `incr` by the element size for pointer arithmetic — confirm
# against the definition of `staticmul`.
staticmulincr(ptr, incr) = Expr(:call, lv(:staticmul), Expr(:call, :eltype, ptr), incr)

# Expression `pointer(sym)`.
callpointer(sym) = Expr(:call, :pointer, sym)

"""
    vec_looprange(loopmax, UF::Int, mangledname::Symbol, ptrcomp::Bool)

Build the `<` comparison expression used as the continuation test of a vectorized loop.

The right-hand side is `subexpr(loopmax, incr)`, where `incr` is a call to `valsub`
(when `UF == 1`) or `valmulsub` (otherwise) on `VECTORWIDTHSYMBOL`. When `ptrcomp` is
`true`, `incr` is additionally wrapped by `staticmulincr` and the left-hand side is
`pointer(mangledname)`; otherwise the left-hand side is `mangledname` itself.
"""
function vec_looprange(loopmax, UF::Int, mangledname::Symbol, ptrcomp::Bool)
    step = if isone(UF)
        Expr(:call, lv(:valsub), VECTORWIDTHSYMBOL, 1)
    else
        Expr(:call, lv(:valmulsub), VECTORWIDTHSYMBOL, UF, 1)
    end
    if ptrcomp
        # Pointer comparison: scale the step before subtracting it from the bound.
        bound = subexpr(loopmax, staticmulincr(mangledname, step))
        Expr(:call, :<, callpointer(mangledname), bound)
    else
        bound = subexpr(loopmax, step)
        Expr(:call, :<, mangledname, bound)
    end
end
112
112
113
"""
    looprange(stopcon, incr::Int, mangledname::Symbol, ptrcomp::Bool)

Build the comparison expression that decides whether a loop stepping by `incr`
continues, given the stop condition `stopcon`.

With `offset = 1 - incr`:

- `ptrcomp == true`: `pointer(mangledname) < stopcon` when `offset == 0`, otherwise
  `pointer(mangledname) < addexpr(stopcon, staticmulincr(mangledname, offset))`.
- `ptrcomp == false`: `mangledname < stopcon` when `offset == 0`,
  `mangledname ≤ stopcon` when `offset == 1`, otherwise
  `mangledname < addexpr(stopcon, offset)`.
"""
function looprange(stopcon, incr::Int, mangledname::Symbol, ptrcomp::Bool)
    offset = 1 - incr
    if ptrcomp
        # Pointer comparisons always use `<`; a non-zero offset is scaled via `staticmulincr`.
        bound = iszero(offset) ? stopcon : addexpr(stopcon, staticmulincr(mangledname, offset))
        Expr(:call, :<, callpointer(mangledname), bound)
    elseif iszero(offset)
        Expr(:call, :<, mangledname, stopcon)
    elseif isone(offset)
        Expr(:call, :≤, mangledname, stopcon)
    else
        Expr(:call, :<, mangledname, addexpr(stopcon, offset))
    end
end
123
131
"""
    looprange(loop::Loop, incr::Int, mangledname::Symbol)

Build the non-pointer loop continuation test for `loop`, using `loop.stophint` as the
stop condition when `loop.stopexact` is set and `loop.stopsym` otherwise.
"""
function looprange(loop::Loop, incr::Int, mangledname::Symbol)
    # The two calls are kept in separate branches so each sees a single stop-condition type.
    if loop.stopexact
        looprange(loop.stophint, incr, mangledname, false)
    else
        looprange(loop.stopsym, incr, mangledname, false)
    end
end
126
134
function terminatecondition (
127
135
loop:: Loop , us:: UnrollSpecification , n:: Int , mangledname:: Symbol , inclmask:: Bool , UF:: Int = unrollfactor (us, n)
@@ -130,22 +138,36 @@ function terminatecondition(
130
138
looprange (loop, UF, mangledname)
131
139
elseif inclmask
132
140
looprange (loop, 1 , mangledname)
141
+ elseif loop. stopexact
142
+ vec_looprange (loop. stophint, UF, mangledname, false ) # may not be u₂loop
133
143
else
134
- vec_looprange (loop, UF, mangledname) # may not be u₂loop
144
+ vec_looprange (loop. stopsym , UF, mangledname, false ) # may not be u₂loop
135
145
end
136
146
end
137
147
"""
    incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symbol, UF::Int = unrollfactor(us, n))

Return the assignment expression that advances the loop counter `mangledname`.

For a vectorized loop (`isvectorized(us, n)`) the right-hand side is a call to `valadd`
(when `UF == 1`) or `valmuladd` (otherwise) involving `VECTORWIDTHSYMBOL`; for any other
loop it is `vadd(mangledname, UF)`. All callee names are resolved through `lv`.
"""
function incrementloopcounter(us::UnrollSpecification, n::Int, mangledname::Symbol, UF::Int = unrollfactor(us, n))
    rhs = if !isvectorized(us, n)
        Expr(:call, lv(:vadd), mangledname, UF)
    elseif UF == 1
        Expr(:call, lv(:valadd), VECTORWIDTHSYMBOL, mangledname)
    else
        Expr(:call, lv(:valmuladd), VECTORWIDTHSYMBOL, UF, mangledname)
    end
    Expr(:(=), mangledname, rhs)
end
158
"""
    incrementloopcounter!(q, us::UnrollSpecification, n::Int, UF::Int = unrollfactor(us, n))

Append the increment for loop `n` to `q.args` and return the result of that `push!`.

The pushed value is:

- `unwrap(VECTORWIDTHSYMBOL)` for a vectorized loop with `UF == 1`,
- `valmul(VECTORWIDTHSYMBOL, UF)` for a vectorized loop otherwise,
- a call constructing `Static{UF}` for a non-vectorized loop with `UF == 1`,
- the integer `UF` itself in the remaining case.

Callee names are resolved through `lv`.
"""
function incrementloopcounter!(q, us::UnrollSpecification, n::Int, UF::Int = unrollfactor(us, n))
    increment = if isvectorized(us, n)
        if UF == 1
            Expr(:call, lv(:unwrap), VECTORWIDTHSYMBOL)
        else
            Expr(:call, lv(:valmul), VECTORWIDTHSYMBOL, UF)
        end
    elseif isone(UF)
        Expr(:call, Expr(:curly, lv(:Static), UF))
    else
        UF
    end
    push!(q.args, increment)
end
149
171
150
172
# load/compute/store × isunrolled × istiled × pre/post loop × Loop number
151
173
struct LoopOrder <: AbstractArray{Vector{Operation},5}
@@ -294,19 +316,22 @@ function LoopSet(mod::Symbol)
294
316
)
295
317
end
296
318
319
# Call `setunrolled!(op, u₁loop, u₂loop, vectorized)` on every operation of `ls`.
function cacheunrolled!(ls::LoopSet, u₁loop, u₂loop, vectorized)
    foreach(operations(ls)) do op
        setunrolled!(op, u₁loop, u₂loop, vectorized)
    end
end
320
+
297
321
# Number of entries in `ls.loops`.
num_loops(ls::LoopSet) = length(ls.loops)
298
322
"""
    oporder(ls::LoopSet)

Return `ls.loop_order.oporder` reshaped to size `(2, 2, 2, N)`, where `N` is the number
of entries in `ls.loop_order.loopnames`.
"""
function oporder(ls::LoopSet)
    order = ls.loop_order
    N = length(order.loopnames)
    reshape(order.oporder, (2, 2, 2, N))
end
302
326
# Loop names stored in the loop order (`ls.loop_order.loopnames`).
names(ls::LoopSet) = ls.loop_order.loopnames
# Returns `ls.loop_order.bestorder`.
# NOTE(review): the name suggests this is the reverse of `names(ls)` — confirm where `bestorder` is filled.
reversenames(ls::LoopSet) = ls.loop_order.bestorder
303
328
"""
    getloopid(ls::LoopSet, s::Symbol)::Int

Return the position of the first entry of `ls.loopsymbols` that is identical (`===`) to `s`.

If no entry matches, the loop falls through and the function yields `nothing`; the
declared `::Int` return type then attempts to convert it and raises an error, so a
missing symbol surfaces as an exception rather than a sentinel value.
"""
function getloopid(ls::LoopSet, s::Symbol)::Int
    for (loopnum, sym) ∈ enumerate(ls.loopsymbols)
        if s === sym
            return loopnum
        end
    end
end
308
333
# Look up a loop by its symbol via `getloopid`.
getloop(ls::LoopSet, s::Symbol) = ls.loops[getloopid(ls, s)]
# getloop(ls::LoopSet, i::Integer) = ls.loops[i]
# Symbol stored at position `i` of `ls.loopsymbols`.
getloopsym(ls::LoopSet, i::Integer) = ls.loopsymbols[i]
# Length of the loop identified by symbol `s`.
Base.length(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
312
337
0 commit comments