Skip to content

Commit eae7620

Browse files
committed
Fix broken precompilation warning, and bump SIMDPirates requirement.
1 parent a692739 commit eae7620

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515
[compat]
1616
DocStringExtensions = "0.8"
1717
OffsetArrays = "1"
18-
SIMDPirates = "0.8.4"
18+
SIMDPirates = "0.8.6"
1919
SLEEFPirates = "0.5"
2020
UnPack = "0,1"
2121
VectorizationBase = "0.12.6"

src/determinestrategy.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function unroll_no_reductions(ls, order, unrolled, vectorized, Wshift, size_T)
186186
# @show compute_rt, load_rt
187187
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
188188
rt = max(compute_rt, load_rt)
189-
rt == 0.0 && return 4
189+
iszero(rt) && return 4
190190
max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))
191191
end
192192
function determine_unroll_factor(
@@ -286,9 +286,10 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
286286
discriminant < 0 && return -1,-1,Inf
287287
u₁float = max(float(u₁step), (sqrt(discriminant) + b) / (-2a)) # must be at least 1
288288
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
289-
if !(isfinite(u₂float) && isfinite(u₁float))
290-
return 4, 4, unroll_cost(X, 4, 4, u₁L, u₂L)
291-
# return itertilesize(X, u₁L, u₂L)
289+
if !(isfinite(u₂float) & isfinite(u₁float)) # brute force
290+
u₁low = u₂low = 1
291+
u₁high = u₂high = REGISTER_COUNT == 32 ? 10 : 6#8
292+
return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
292293
end
293294
u₁low = floor(Int, u₁float)
294295
u₂low = max(u₂step, floor(Int, u₂float)) # must be at least 1
@@ -564,6 +565,13 @@ function load_elimination_cost_factor!(
564565
false
565566
end
566567
end
568+
function loadintostore(ls::LoopSet, op::Operation)
569+
isload(op) || return false
570+
for opp operations(ls)
571+
isstore(opp) && opp.ref == op.ref && return true
572+
end
573+
false
574+
end
567575
function add_constant_offset_load_elmination_cost!(
568576
X, R, choose_to_inline, ls::LoopSet, op::Operation, iters, unrollsyms::UnrollSymbols, u₁reduces::Bool, u₂reduces::Bool, Wshift::Int, size_T::Int, opisininnerloop::Bool
569577
)
@@ -575,6 +583,9 @@ function add_constant_offset_load_elmination_cost!(
575583
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
576584
rt *= iters
577585
rp = opisininnerloop ? rp : zero(rp)
586+
# if loadintostore(ls, op) # For now, let's just avoid unrolling in this way...
587+
# rt = Inf
588+
# end
578589
# u_uid is getting eliminated
579590
# we treat this as the unrolled loop getting eliminated is split into 2 parts:
580591
# 1 a non-cost-reduced part, with factor udependent_reduction
@@ -700,7 +711,8 @@ function evaluate_cost_tile(
700711
prefetch_good_idea = true
701712
end
702713
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
703-
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
714+
rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
715+
# rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
704716
rt *= iters[id]
705717
if u₁reduces & u₂reduces
706718
cost_vec[4] += rt

src/mapreduce.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ Vectorized version of `reduce`. Reduces the array `A` using the operator `op`.
8080
@inline vreduce(op, arg) = vmapreduce(identity, op, arg)
8181

8282
for (op, init) in zip((:+, :max, :min), (:zero, :identity, :identity))
83-
@eval function vreduce(::typeof($op), arg; dims)
83+
@eval function vreduce(::typeof($op), arg; dims = nothing)
84+
isnothing(dims) && return _vreduce($op, arg)
8485
@assert length(dims) == 1
8586
out = $init(arg[ntuple(d -> d == dims ? (1:1) : (1:size(arg, d)), ndims(arg))...])
8687
Rpre = CartesianIndices(axes(arg)[1:dims-1])
@@ -95,11 +96,11 @@ for (op, init) in zip((:+, :max, :min), (:zero, :identity, :identity))
9596
return out
9697
end
9798

98-
@eval function vreduce(::typeof($op), arg)
99+
@eval function _vreduce(::typeof($op), arg)
99100
s = $init(arg[1])
100101
@avx for i in 1:length(arg)
101102
s = $op(s, arg[i])
102103
end
103104
return s
104105
end
105-
end
106+
end

0 commit comments

Comments
 (0)