Skip to content

Commit 8c8c876

Browse files
committed
Set series of dependent operations to have the same name. This solution is going to be broken in some cases, but at least fixes #259.
1 parent 3f605b9 commit 8c8c876

File tree

4 files changed

+68
-16
lines changed

4 files changed

+68
-16
lines changed

src/codegen/loopstartstopmanager.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,14 @@ function pushgespind!(
425425
if index_by_index
426426
if gespsymbol === Symbol("")
427427
if constoffset == 0
428-
push!(gespinds.args, Expr(:call, GlobalRef(VectorizationBase, :NullStep)))
428+
ns = Expr(:call, GlobalRef(VectorizationBase, :NullStep))
429+
if fromgsp
430+
loop = getloop(ls, ind)
431+
if loop.rangesym Symbol("")
432+
ns = Expr(:call, lv(:similardims), loop.rangesym, ns)
433+
end
434+
end
435+
push!(gespinds.args, ns)
429436
else
430437
push!(gespinds.args, staticexpr(constoffset))
431438
end

src/condense_loopset.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ val(x) = Expr(:call, Expr(:curly, :Val, x))
279279
@inline gespf1(x::StridedPointer{T,1}, i::Tuple{Zero}) where {T} = x
280280
@inline gespf1(x::StridedBitPointer{T,1}, i::Tuple{Zero}) where {T} = x
281281
@generated function gespf1(x::StridedPointer{T,N,C,B,R}, i::Tuple{I}) where {T,N,I<:Integer,C,B,R}
282-
I === Zero && return :x
282+
# I === Zero && return :x
283283
ri = 0; rm = typemax(Int)
284284
for (i, r) enumerate(R)
285285
if r < rm
@@ -295,6 +295,34 @@ val(x) = Expr(:call, Expr(:curly, :Val, x))
295295
StridedPointer{$T,1,$(C===1 ? 1 : 0),$(B===1 ? 1 : 0),$(R[ri],)}(ptr, (getfield(getfield(x,:strd), $ri, 1),), (Zero(),))
296296
end
297297
end
298+
@generated function gespf1(x::StridedPointer{T,N,C,B,R}, ::Tuple{VectorizationBase.NullStep}) where {T,N,C,B,R}
299+
ri = 0; rm = typemax(Int)
300+
for (i, r) enumerate(R)
301+
if r < rm
302+
rm = r
303+
ri = i
304+
end
305+
end
306+
ri = max(1, ri)
307+
quote
308+
$(Expr(:meta,:inline))
309+
StridedPointer{$T,1,$(C===1 ? 1 : 0),$(B===1 ? 1 : 0),$(R[ri],)}(pointer(x), (getfield(getfield(x,:strd), $ri, 1),), (getfield(getfield(x,:offsets), $ri, 1),))
310+
end
311+
end
312+
@generated function gespf1(x::StridedBitPointer{N,C,B,R}, ::Tuple{VectorizationBase.NullStep}) where {N,C,B,R}
313+
ri = 0; rm = typemax(Int)
314+
for (i, r) enumerate(R)
315+
if r < rm
316+
rm = r
317+
ri = i
318+
end
319+
end
320+
ri = max(1, ri)
321+
quote
322+
$(Expr(:meta,:inline))
323+
StridedBitPointer{1,$(C===1 ? 1 : 0),$(B===1 ? 1 : 0),$(R[ri],)}(pointer(x), (getfield(getfield(x,:strd), $ri, 1),), (getfield(getfield(x,:offsets), $ri, 1),))
324+
end
325+
end
298326
@generated function gespf1(x::StridedBitPointer{T,N,C,B,R}, i::Tuple{I}) where {T,N,I<:Integer,C,B,R}
299327
I === Zero && return :x
300328
quote

src/parse/add_compute.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,27 @@ function add_reduced_deps!(op::Operation, reduceddeps::Vector{Symbol})
154154
end
155155

156156
function substitute_op_in_parents!(
157-
vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}
157+
vparents::Vector{Operation}, replacer::Operation, replacee::Operation, reduceddeps::Vector{Symbol}, reductsym::Symbol
158158
)
159-
found = false
160-
for i eachindex(vparents)
161-
opp = vparents[i]
162-
if opp === replacee
163-
vparents[i] = replacer
164-
found = true
165-
else
166-
fopp = substitute_op_in_parents!(parents(opp), replacer, replacee, reduceddeps)
167-
fopp && add_reduced_deps!(opp, reduceddeps)
168-
found |= fopp
169-
end
159+
found = false
160+
for i eachindex(vparents)
161+
opp = vparents[i]
162+
if opp === replacee
163+
vparents[i] = replacer
164+
found = true
165+
else
166+
fopp = substitute_op_in_parents!(parents(opp), replacer, replacee, reduceddeps, reductsym)
167+
if fopp
168+
add_reduced_deps!(opp, reduceddeps)
169+
# FIXME: https://github.com/JuliaSIMD/LoopVectorization.jl/issues/259
170+
#
171+
opp.variable = reductsym
172+
opp.mangledvariable = Symbol("##", reductsym, :_)
173+
end
174+
found |= fopp
170175
end
171-
found
176+
end
177+
found
172178
end
173179

174180

@@ -216,7 +222,7 @@ function add_reduction_update_parent!(
216222
update_deps!(deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
217223
end
218224
elseif !isouterreduction && reductinit !== parent
219-
substitute_op_in_parents!(vparents, reductinit, parent, reduceddeps)
225+
substitute_op_in_parents!(vparents, reductinit, parent, reduceddeps, reductsym)
220226
end
221227
update_reduction_status!(vparents, reduceddeps, name(reductinit))
222228
# this is the op added by add_compute

test/gemv.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ using Test
227227
end
228228
end
229229

230+
function depchain_with_different_deps!(c1,c2,A,b)
231+
@avx for j in axes(A,1), k in axes(A,2)
232+
c1[j] += A[j,k] * b[k]
233+
c2[j] += A[j,k] * b[k] - 0 * b[j]
234+
end
235+
end
236+
230237
M, K, N = 51, 49, 61
231238
for T (Float32, Float64, Int32, Int64)
232239
@show T, @__LINE__
@@ -324,5 +331,9 @@ using Test
324331
@test Y0 Y1
325332
@test dY0 dY1
326333

334+
Y2 = zeros(TC, N, 2); Y3 = zeros(TC,N,2);
335+
depchain_with_different_deps!(Y2, Y3, A, b)
336+
@test view(Y2,:,1) == view(Y3,:,1)
337+
@test view(Y2,:,1) A*b
327338
end
328339
end

0 commit comments

Comments
 (0)