Skip to content

Commit 1ac177d

Browse files
committed
Add option for manually specifying unroll or tile, tweaked unrolling in absence of reductions.
1 parent 81b41c9 commit 1ac177d

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

src/constructors.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ macro avx(q)
9898
end
9999
esc(q2)
100100
end
101-
101+
"""
    @avx unroll=U for ... end
    @avx tile=(U, T) for ... end

Lower the `for` loop `q` with a manually specified strategy instead of the
automatically chosen one.

`arg` must be an assignment expression whose left-hand side is either
`unroll` (right-hand side: the unroll factor `U`) or `tile` (right-hand side:
a 2-tuple `(U, T)`). With `unroll`, `T` is set to `-1`. The values are passed
to `lower(LoopSet(q), U, T)` and the resulting expression is escaped.

Throws an `ArgumentError` at macro-expansion time when `q` is not a `for`
loop or when `arg` is not one of the two supported forms.
"""
macro avx(arg, q)
    # Validate the syntax explicitly: these are user-facing mistakes, and both
    # arguments may be non-`Expr` values (e.g. a bare symbol or literal).
    (q isa Expr && q.head === :for) ||
        throw(ArgumentError("@avx expects a `for` loop as its final argument."))
    (arg isa Expr && arg.head === :(=) && length(arg.args) == 2) ||
        throw(ArgumentError("@avx option must have the form `unroll=U` or `tile=(U,T)`."))
    option, value = arg.args[1], arg.args[2]
    # The `::Int` declarations convert the supplied literals on assignment.
    local U::Int, T::Int
    if option === :unroll
        U = value
        T = -1
    elseif option === :tile
        (value isa Expr && value.head === :tuple && length(value.args) == 2) ||
            throw(ArgumentError("@avx `tile` option expects a 2-tuple, e.g. `tile=(4,4)`."))
        U = value.args[1]
        T = value.args[2]
    else
        # Without this branch `U` and `T` would be left unassigned and the
        # call below would fail with an unhelpful `UndefVarError`.
        throw(ArgumentError(
            "unrecognized @avx option `$option`; expected `unroll` or `tile`."
        ))
    end
    esc(lower(LoopSet(q), U, T))
end
102116

103117

src/determinestrategy.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,10 @@ function unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
131131
end
132132
end
133133
# heuristic guess
134-
# @show compute_rt, load_rt
135-
roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
134+
@show compute_rt, load_rt
135+
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
136+
rt = max(compute_rt, load_rt)
137+
roundpow2( min( 4, round(Int, 16 / rt) ) )
136138
end
137139
function determine_unroll_factor(
138140
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vectorized::Symbol = first(order)
@@ -151,7 +153,7 @@ function determine_unroll_factor(
151153
if iszero(num_reductions)
152154
# if only 1 loop, no need to unroll
153155
# if more than 1 loop, there is some cost. Picking 2 here as a heuristic.
154-
return length(order) == 1 ? 1 : unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
156+
return unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
155157
end
156158
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
157159
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))

src/lowering.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,13 @@ function lower(ls::LoopSet)
888888
fillorder!(ls, order, istiled)
889889
istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
890890
end
891+
"""
    lower(ls::LoopSet, U, T = -1)

Lower `ls` with a caller-supplied unroll factor `U` and, optionally, a tile
factor `T`.

`T == -1` means "do not tile": the loop set is lowered with
`lower_unrolled(ls, vectorized, U)`. Any other `T` selects
`lower_tiled(ls, vectorized, U, T)`. The loop order and the vectorized loop
still come from `choose_order(ls)`; the factors it suggests are ignored.

Throws an `ArgumentError` if tiling is requested for a loop set with a
single loop.
"""
function lower(ls::LoopSet, U, T = -1)
    istiled = T != -1
    # This check is reachable from user input (`@avx tile=(U,T) ...`), so it
    # raises an explicit error rather than relying on `@assert`.
    if istiled && num_loops(ls) == 1
        throw(ArgumentError("cannot tile a loop set with a single loop (got T = $T)."))
    end
    # Only the first two values returned by `choose_order` are needed here.
    order, vectorized = choose_order(ls)
    fillorder!(ls, order, istiled)
    return istiled ? lower_tiled(ls, vectorized, U, T) : lower_unrolled(ls, vectorized, U)
end
891898

892899
# Converting a `LoopSet` to an `Expr` yields the result of the one-argument `lower`.
function Base.convert(::Type{Expr}, ls::LoopSet)
    return lower(ls)
end
893900
# Showing a `LoopSet` prints its lowered expression to `io`, followed by a newline.
function Base.show(io::IO, ls::LoopSet)
    lowered = lower(ls)
    return println(io, lowered)
end

0 commit comments

Comments
 (0)