Skip to content

Commit bb2aef7

Browse files
committed
Fix offset adjustment for offset indices that don't have constant stride. Fixes #287.
1 parent dfddc8d commit bb2aef7

File tree

5 files changed

+141
-18
lines changed

5 files changed

+141
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ Static = "0.2"
3030
StrideArraysCore = "0.1.12"
3131
ThreadingUtilities = "0.4.2"
3232
UnPack = "1"
33-
VectorizationBase = "0.20.17"
33+
VectorizationBase = "0.20.18"
3434
julia = "1.5"

src/codegen/lower_memory_common.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,10 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
138138
indvectorized = _mm & (ind === vloopsym)
139139
offset = offsets[n] % Int
140140
stride = strides[n] % Int
141-
ind_by_offset = inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX)
142-
if !ind_by_offset
143-
offset += (stride - 1)
141+
if ind CONSTANTZEROINDEX
142+
offset += (stride - 1)
144143
end
144+
ind_by_offset = inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX)
145145
@unpack vstep = td
146146
if loopedindex[n]
147147
addoffset!(ret, indvectorized, vstep, stride, ind, offset, ind_by_offset) # 7 arg
@@ -297,8 +297,8 @@ function mem_offset_u(
297297
stride = convert(Int, strides[n])
298298
indvectorized = ind === vloopsym
299299
indvectorizedmm = _mm & indvectorized
300-
if !ind_by_offset
301-
offset += (stride - 1)
300+
if ind CONSTANTZEROINDEX
301+
offset += (stride - 1)
302302
end
303303
if ind === u₁loopsym
304304
addvectoroffset!(ret, indvectorizedmm, incr₁, u₁step, vstep, stride, ind, offset, ind_by_offset, indvectorized) # 9 arg

src/modeling/graphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ mutable struct LoopSet
383383
preamble::Expr
384384
prepreamble::Expr # performs extractions that must be performed first, and don't need further registering
385385
preamble_symsym::Vector{Tuple{Int,Symbol}}
386-
preamble_symint::Vector{Tuple{Int,Tuple{Int,Int32,Bool}}}
386+
preamble_symint::Vector{Tuple{Int,Tuple{Int,Int32,Bool}}} # (id,(intval,intsz,signed))
387387
preamble_symfloat::Vector{Tuple{Int,Float64}}
388388
preamble_zeros::Vector{Tuple{Int,NumberType}}
389389
preamble_funcofeltypes::Vector{Tuple{Int,Float64}}

src/parse/memory_ops_common.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ byterepresentable(x::Integer)::Bool = typemin(Int8) ≤ x ≤ typemax(Int8)
113113
function _addoffset!(indices, offsets, strides, loopedindex, loopdependencies, ind, offset, stride)
114114
push!(indices, ind)
115115
push!(offsets, offset % Int8)
116+
# push!(offsets, (offset+stride-1) % Int8)
116117
push!(strides, stride % Int8)
117118
push!(loopedindex, true)
118119
push!(loopdependencies, ind)
@@ -249,8 +250,35 @@ function checkforoffset!(
249250
loopedindex::Vector{Bool}, loopdependencies::Vector{Symbol}, reduceddeps::Vector{Symbol}, ind::Expr
250251
)::Symbol
251252

252-
offset, mult_syms = affine_index_expression(ls, ind)
253-
if !byterepresentable(offset)
253+
offset, mult_syms = affine_index_expression(ls, ind)
254+
let deleted = 0, N = length(mult_syms)
255+
for n 1:N
256+
ntemp = n - deleted
257+
mlt, sym = mult_syms[ntemp]
258+
opm = get(ls.opdict, sym, nothing)
259+
opm === nothing && continue
260+
isconstant(opm) || continue
261+
found = false
262+
for (opid,(intval,intsz,signed)) ls.preamble_symint
263+
if opid == identifier(opm)
264+
offset += intval * mlt
265+
deleted += 1
266+
deleteat!(mult_syms, ntemp)
267+
found = true
268+
break
269+
end
270+
end
271+
found && continue
272+
for (opid,nt) ls.preamble_zeros
273+
if opid == identifier(opm)
274+
deleted += 1
275+
deleteat!(mult_syms, ntemp)
276+
break
277+
end
278+
end
279+
end
280+
end
281+
if !byterepresentable(offset)
254282
if length(mult_syms) == 1
255283
mlt,sym = only(mult_syms)
256284
if !byterepresentable(mlt)
@@ -263,7 +291,6 @@ function checkforoffset!(
263291
vptrarray = gesp_const_offset!(ls, vptrarray, ninds, indices, loopedindex, 1, offset - r)
264292
offset = r
265293
end
266-
267294
# (success && byterepresentable(offset)) || return false, vptrarray
268295
if length(mult_syms) == 0
269296
addconstindex!(indices, offsets, strides, loopedindex, offset)

test/shuffleloadstores.jl

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
2-
3-
41
function dot_simd(a::AbstractVector, b::AbstractVector)
52
s = zero(eltype(a))
63
@fastmath @inbounds @simd for i eachindex(a)
@@ -196,6 +193,94 @@ function sumdim2!(r1, r2)
196193
r1
197194
end
198195

196+
# Issue 287
197+
function my_gemm_noturbo!(out, s::Matrix{UInt8}, V)
198+
Vcols = size(V, 2)
199+
srows = size(s, 1)
200+
scols = size(s, 2)
201+
k = srows >> 2
202+
rem = srows & 3
203+
@inbounds @fastmath for c in 1:Vcols
204+
for j in 1:scols
205+
for l in 1:k
206+
block = s[l, j]
207+
for p in 1:4
208+
Aij = (block >> (2 * (p - 1))) & 3
209+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
210+
end
211+
end
212+
end
213+
end
214+
# TODO handle rem
215+
end
216+
function my_gemm_unroll(out, s::Matrix{UInt8}, V)
217+
Vcols = size(V, 2)
218+
srows = size(s, 1)
219+
scols = size(s, 2)
220+
k = srows >> 2
221+
rem = srows & 3
222+
@avx for c in 1:Vcols
223+
for j in 1:scols
224+
for l in 1:k
225+
block = s[l, j]
226+
for p in 1:4
227+
Aij = (block >> (2 * (p - 1))) & 3
228+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
229+
end
230+
end
231+
end
232+
end
233+
# TODO handle rem
234+
end
235+
function my_gemm_manual_unroll(out, s::Matrix{UInt8}, V)
236+
Vcols = size(V, 2)
237+
srows = size(s, 1)
238+
scols = size(s, 2)
239+
k = srows >> 2
240+
rem = srows & 3
241+
@avx for c in 1:Vcols
242+
for j in 1:scols
243+
for l in 1:k
244+
block = s[l, j]
245+
# unrolled loop
246+
p = 1
247+
Aij = (block >> (2 * (p - 1))) & 3
248+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
249+
p = 2
250+
Aij = (block >> (2 * (p - 1))) & 3
251+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
252+
p = 3
253+
Aij = (block >> (2 * (p - 1))) & 3
254+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
255+
p = 4
256+
Aij = (block >> (2 * (p - 1))) & 3
257+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
258+
end
259+
end
260+
end
261+
# TODO handle rem
262+
end
263+
function my_gemm_nexpr_unroll(out, s::Matrix{UInt8}, V)
264+
Vcols = size(V, 2)
265+
srows = size(s, 1)
266+
scols = size(s, 2)
267+
k = srows >> 2
268+
rem = srows & 3
269+
@turbo for c in 1:Vcols
270+
for j in 1:scols
271+
for l in 1:k
272+
block = s[l, j]
273+
# unrolled loop
274+
Base.Cartesian.@nexprs 4 p -> begin
275+
Aij = (block >> (2 * (p - 1))) & 3
276+
out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
277+
end
278+
end
279+
end
280+
end
281+
# TODO handle rem
282+
end
283+
199284
@testset "shuffles load/stores" begin
200285
@show @__LINE__
201286
for i 1:128
@@ -206,7 +291,7 @@ end
206291
@test dsimd cdot_mat(ac, bc)
207292
end
208293
@test dsimd cdot_affine(ac, bc) cdot_stride(ac, bc)
209-
294+
210295

211296
xq = [ntuple(_ -> rand(), Val(4)) for _ 1:i];
212297
yq = [ntuple(_ -> rand(), Val(4)) for _ 1:i];
@@ -230,7 +315,7 @@ end
230315
Aca = reinterpret(reshape, Float64, Ac);
231316
Bca = reinterpret(reshape, Float64, Bc);
232317
cmatmul_array!(Cca, Aca, Bca)
233-
318+
234319
@test Cc1 Cc2# ≈ Cc3
235320
end
236321
end
@@ -250,11 +335,22 @@ end
250335
ϕ = view(fill(1e5+1e7im, 2*J+17, G+17, H+17, M+17), 9:2*J+9, 9:G+9, 9:H+9, 9:M+9) .= rand.() .+ rand.().*im;
251336
@test issue209(M, G, J, H, B, ϕ) issue209_noavx(M, G, J, H, B, ϕ)
252337
end
253-
338+
254339
s = Array{Float64}(undef, 4, 128, 128);
255340
s2 = rand(4, 2, 128, 128);
256341
@test sumdim2_turbo!(s, s2) sumdim2!(similar(s), s2)
257342

343+
# issue 287
344+
out_test = zeros(100, 10);
345+
out_test1 = zeros(100, 10);
346+
s = rand(UInt8, 25, 100);
347+
V = rand(100, 10);
348+
my_gemm_noturbo!(out_test, s, V);
349+
my_gemm_unroll(out_test1, s, V);
350+
@test out_test out_test1
351+
my_gemm_manual_unroll(fill!(out_test1, 0), s, V);
352+
@test out_test out_test1
353+
my_gemm_nexpr_unroll(fill!(out_test1, 0), s, V);
354+
@test out_test out_test1
355+
258356
end
259-
260-

0 commit comments

Comments
 (0)