Skip to content

Commit 1dafd32

Browse files
committed
2 parents 7a76bf1 + 75b4152 commit 1dafd32

File tree

4 files changed

+51
-8
lines changed

4 files changed

+51
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515
[compat]
1616
DocStringExtensions = "0.8"
1717
OffsetArrays = "1"
18-
SIMDPirates = "0.8.4"
18+
SIMDPirates = "0.8.6"
1919
SLEEFPirates = "0.5"
2020
UnPack = "0,1"
2121
VectorizationBase = "0.12.6"

src/determinestrategy.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function unroll_no_reductions(ls, order, unrolled, vectorized, Wshift, size_T)
186186
# @show compute_rt, load_rt
187187
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
188188
rt = max(compute_rt, load_rt)
189-
rt == 0.0 && return 4
189+
iszero(rt) && return 4
190190
max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))
191191
end
192192
function determine_unroll_factor(
@@ -286,9 +286,10 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
286286
discriminant < 0 && return -1,-1,Inf
287287
u₁float = max(float(u₁step), (sqrt(discriminant) + b) / (-2a)) # must be at least 1
288288
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
289-
if !(isfinite(u₂float) && isfinite(u₁float))
290-
return 4, 4, unroll_cost(X, 4, 4, u₁L, u₂L)
291-
# return itertilesize(X, u₁L, u₂L)
289+
if !(isfinite(u₂float) & isfinite(u₁float)) # brute force
290+
u₁low = u₂low = 1
291+
u₁high = u₂high = REGISTER_COUNT == 32 ? 10 : 6#8
292+
return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
292293
end
293294
u₁low = floor(Int, u₁float)
294295
u₂low = max(u₂step, floor(Int, u₂float)) # must be at least 1
@@ -564,6 +565,13 @@ function load_elimination_cost_factor!(
564565
false
565566
end
566567
end
568+
function loadintostore(ls::LoopSet, op::Operation)
569+
# isload(op) || return false # leads to bad behavior more than it helps
570+
# for opp ∈ operations(ls)
571+
# isstore(opp) && opp.ref == op.ref && return true
572+
# end
573+
false
574+
end
567575
function add_constant_offset_load_elmination_cost!(
568576
X, R, choose_to_inline, ls::LoopSet, op::Operation, iters, unrollsyms::UnrollSymbols, u₁reduces::Bool, u₂reduces::Bool, Wshift::Int, size_T::Int, opisininnerloop::Bool
569577
)
@@ -575,6 +583,9 @@ function add_constant_offset_load_elmination_cost!(
575583
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
576584
rt *= iters
577585
rp = opisininnerloop ? rp : zero(rp)
586+
# if loadintostore(ls, op) # For now, let's just avoid unrolling in this way...
587+
# rt = Inf
588+
# end
578589
# u_uid is getting eliminated
579590
# we treat this as the unrolled loop getting eliminated is split into 2 parts:
580591
# 1 a non-cost-reduced part, with factor udependent_reduction
@@ -700,7 +711,8 @@ function evaluate_cost_tile(
700711
prefetch_good_idea = true
701712
end
702713
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
703-
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
714+
rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
715+
# rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
704716
rt *= iters[id]
705717
if u₁reduces & u₂reduces
706718
cost_vec[4] += rt

src/mapreduce.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,28 @@ Vectorized version of `reduce`. Reduces the array `A` using the operator `op`.
7979
"""
8080
@inline vreduce(op, arg) = vmapreduce(identity, op, arg)
8181

82+
for (op, init) in zip((:+, :max, :min), (:zero, :identity, :identity))
83+
@eval function vreduce(::typeof($op), arg; dims = nothing)
84+
isnothing(dims) && return _vreduce($op, arg)
85+
@assert length(dims) == 1
86+
out = $init(arg[ntuple(d -> d == dims ? (1:1) : (1:size(arg, d)), ndims(arg))...])
87+
Rpre = CartesianIndices(axes(arg)[1:dims-1])
88+
Rpost = CartesianIndices(axes(arg)[dims+1:end])
89+
_vreduce_dims!(out, $op, Rpre, 1:size(arg, dims), Rpost, arg)
90+
end
91+
92+
@eval function _vreduce_dims!(out, ::typeof($op), Rpre, is, Rpost, arg)
93+
@avx for Ipost in Rpost, i in is, Ipre in Rpre
94+
out[Ipre, 1, Ipost] = $op(out[Ipre, 1, Ipost], arg[Ipre, i, Ipost])
95+
end
96+
return out
97+
end
98+
99+
@eval function _vreduce(::typeof($op), arg)
100+
s = $init(arg[1])
101+
@avx for i in 1:length(arg)
102+
s = $op(s, arg[i])
103+
end
104+
return s
105+
end
106+
end

test/mapreduce.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
if T <: Integer
1414
R = T(1):T(100)
1515
x7 = rand(R, 7); y7 = rand(R, 7);
16-
x = rand(R, 127); y = rand(R, 127);
16+
x = rand(R, 127, 7, 7); y = rand(R, 127, 7, 7);
1717
else
1818
x7 = rand(T, 7); y7 = rand(T, 7);
19-
x = rand(T, 127); y = rand(T, 127);
19+
x = rand(T, 127, 7, 7); y = rand(T, 127, 7, 7);
2020
if VERSION v"1.4"
2121
@test vmapreduce(hypot, +, x, y) mapreduce(hypot, +, x, y)
2222
@test vmapreduce(^, (a,b) -> a + b, x7, y7) mapreduce(^, +, x7, y7)
@@ -38,6 +38,12 @@
3838
@test vmapreduce(log, +, x) sum(log, x)
3939
@test vmapreduce(abs2, +, x) sum(abs2, x)
4040
@test maximum(x) == vreduce(max, x) == maximum_avx(x)
41+
42+
for d in 1:ndims(x)
43+
@test vreduce(max, x; dims = d) maximum(x; dims = d)
44+
@test vreduce(min, x; dims = d) minimum(x; dims = d)
45+
@test vreduce(+, x; dims = d) sum(x; dims = d)
46+
end
4147
end
4248

4349
end

0 commit comments

Comments
 (0)