Skip to content

Commit 6f4ed3f

Browse files
committed
Modify stride penalty, and fix removal of check for indices like A[3 - i] in search for constant offsets (we need a subtraction for a reversal, so constant offset [3 in this example] may as well be rolled in there).
1 parent 2472156 commit 6f4ed3f

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

src/determinestrategy.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,11 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopf
519519
opstrides[1] = 1.0
520520
end
521521
# loops = map(s -> getloop(ls, s), loopdeps)
522+
l = length(getloop(ls, first(loopdeps)))
522523
for i 2:length(loopdeps)
523-
opstrides[i] = opstrides[i-1] * length(getloop(ls, loopdeps[i-1]))
524+
looplength = length(getloop(ls, loopdeps[i-1]))
525+
opstrides[i] = opstrides[i-1] * looplength
526+
l *= looplength
524527
# opstrides[i] = opstrides[i-1] * length(loops[i-1])
525528
end
526529
penalty = 0.0
@@ -530,7 +533,7 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopf
530533
penalty += loopfreqs[i] * opstrides[id]
531534
end
532535
end
533-
penalty
536+
penalty * l
534537
end
535538
function stride_penalty(ls::LoopSet, order::Vector{Symbol})
536539
stridepenaltydict = Dict{Symbol,Vector{Float64}}()
@@ -545,7 +548,7 @@ function stride_penalty(ls::LoopSet, order::Vector{Symbol})
545548
push!(v, stride_penalty(ls, op, order, loopfreqs))
546549
end
547550
end
548-
sum(maximum, values(stridepenaltydict)) #* prod(length, ls.loops) / 1024^length(order)
551+
sum(maximum, values(stridepenaltydict)) * 10 / 1024^length(order) #* prod(length, ls.loops)
549552
end
550553
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
551554
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms

src/memory_ops_common.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function checkforoffset!(
112112
factor = f === :+ ? 1 : -1
113113
arg1 = ind.args[2]
114114
arg2 = ind.args[3]
115-
if arg1 isa Integer# && isone(factor)
115+
if arg1 isa Integer && isone(factor) # we want to return false when we're subtracting the index, e.g. A[3 - i]
116116
if arg2 isa Symbol && arg2 ls.loopsymbols
117117
addoffset!(ls, indices, offsets, loopedindex, loopdependencies, arg2, arg1 * factor)
118118
elseif arg2 isa Expr

test/gemv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
# T = Float32
44
@testset "GEMV" begin
55
# Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 6)
6-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
6+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (2, 6) : (2, 8)
77
gemvq = :(for i eachindex(y)
88
yᵢ = 0.0
99
for j eachindex(x)

test/miscellaneous.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test
44

55
@testset "Miscellaneous" begin
66

7-
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
7+
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (2, 6) : (2, 8)
88
dot3q = :(for m 1:M, n 1:N
99
s += x[m] * A[m,n] * y[n]
1010
end);

0 commit comments

Comments
 (0)