Skip to content

Commit 4c06218

Browse files
committed
Fix for prefetch-check in case of dependency on non-loop ops.
1 parent 1b79fb1 commit 4c06218

File tree

8 files changed

+66
-11
lines changed

8 files changed

+66
-11
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange,
1616
reduced_add, reduced_prod, reduce_to_add, reduced_max, reduced_min, vsum, vprod, vmaximum, vminimum,
1717
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1818
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,
19-
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone, vadd1
19+
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone, vadd1, relu
2020
using SLEEFPirates: pow
2121
using Base.Broadcast: Broadcasted, DefaultArrayStyle
2222
using LinearAlgebra: Adjoint, Transpose

src/add_stores.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,14 @@ function add_store!(
5757
mpref = array_reference_meta!(ls, array, rawindices, elementbytes, var)
5858
add_store!(ls, mpref, elementbytes)
5959
end
60-
function add_simple_store!(ls::LoopSet, parent::Operation, ref::ArrayReference, elementbytes::Int)
61-
mref = ArrayReferenceMeta(
62-
ref, fill(true, length(getindices(ref)))
63-
)
64-
op = Operation( ls, name(mref), elementbytes, :setindex!, memstore, getindices(ref), NODEPENDENCY, [parent], mref )
60+
function add_simple_store!(ls::LoopSet, parent::Operation, mref::ArrayReferenceMeta, elementbytes::Int)
61+
op = Operation( ls, name(mref), elementbytes, :setindex!, memstore, getindices(mref.ref), NODEPENDENCY, [parent], mref )
6562
add_unique_store!(ls, op)
6663
end
67-
function add_simple_store!(ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int)
68-
add_simple_store!(ls, getop(ls, var, elementbytes), ref, elementbytes)
64+
function add_simple_store!(ls::LoopSet, var::Union{Symbol,Operation}, ref::Union{ArrayReference,ArrayReferenceMeta}, elementbytes::Int)
65+
parent = isa(var, Symbol) ? getop(ls, var, elementbytes) : var
66+
mref = isa(ref, ArrayReference) ? ArrayReferenceMeta(ref, fill(true, length(getindices(ref)))) : ref
67+
add_simple_store!(ls, parent, mref, elementbytes)
6968
end
7069
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)
7170
array, raw_indices = ref_from_ref!(ls, ex)

src/costs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ const COST = Dict{Symbol,InstructionCost}(
183183
:iseven => InstructionCost(1, 0.5),
184184
:max => InstructionCost(4,0.5),
185185
:min => InstructionCost(4,0.5),
186+
:relu => InstructionCost(4,0.5),
186187
# Instruction(:ifelse) => InstructionCost(1, 0.5),
187188
:vifelse => InstructionCost(1, 0.5),
188189
:inv => InstructionCost(13,4.0,-2.0,1),
@@ -452,6 +453,7 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
452453
# typeof(SLEEFPirates.tanh_fast) => :tanh_fast,
453454
typeof(max) => :max,
454455
typeof(min) => :min,
456+
typeof(relu) => :relu,
455457
typeof(<<) => :<<,
456458
typeof(>>) => :>>,
457459
typeof(>>>) => :>>>,

src/graphs.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ function Loop(itersymbol::Symbol, start::Union{Int,Symbol}, stop::Union{Int,Symb
7070
end
7171
Base.length(loop::Loop) = 1 + loop.stophint - loop.starthint
7272
isstaticloop(loop::Loop) = loop.startexact & loop.stopexact
73+
74+
7375
function startloop(loop::Loop, itersymbol)
7476
startexact = loop.startexact
7577
if startexact
@@ -353,11 +355,12 @@ function oporder(ls::LoopSet)
353355
end
354356
names(ls::LoopSet) = ls.loop_order.loopnames
355357
reversenames(ls::LoopSet) = ls.loop_order.bestorder
356-
function getloopid(ls::LoopSet, s::Symbol)::Int
358+
function getloopid_or_nothing(ls::LoopSet, s::Symbol)
357359
for (loopnum,sym) enumerate(ls.loopsymbols)
358360
s === sym && return loopnum
359361
end
360362
end
363+
getloopid(ls::LoopSet, s::Symbol) = getloopid_or_nothing(ls, s)::Int
361364
getloop(ls::LoopSet, s::Symbol) = ls.loops[getloopid(ls, s)]
362365
# getloop(ls::LoopSet, i::Integer) = ls.loops[i]
363366
getloopsym(ls::LoopSet, i::Integer) = ls.loopsymbols[i]
@@ -366,6 +369,10 @@ Base.length(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
366369
# isstaticloop(ls::LoopSet, s::Symbol) = isstaticloop(getloop(ls,s))
367370
# looprangehint(ls::LoopSet, s::Symbol) = length(getloop(ls, s))
368371
# looprangesym(ls::LoopSet, s::Symbol) = getloop(ls, s).rangesym
372+
373+
"""
374+
getop only works while construction a LoopSet object. You cannot use it while lowering.
375+
"""
369376
getop(ls::LoopSet, var::Number, elementbytes) = add_constant!(ls, var, elementbytes)
370377
function getop(ls::LoopSet, var::Symbol, elementbytes::Int)
371378
get!(ls.opdict, var) do
@@ -379,6 +386,16 @@ function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int)
379386
end
380387
getop(ls::LoopSet, i::Int) = ls.operations[i]
381388

389+
# """
390+
# Returns an operation with the same name as `s`.
391+
# """
392+
# function getoperation(ls::LoopSet, s::Symbol)
393+
# for op ∈ Iterators.Reverse(operations(ls))
394+
# name(op) === s && return op
395+
# end
396+
# throw("Symbol $s not found among operations(ls).")
397+
# end
398+
382399
function Operation(
383400
ls::LoopSet, variable, elementbytes, instruction,
384401
node_type, dependencies, reduced_deps, parents, ref = NOTAREFERENCE
@@ -731,10 +748,39 @@ function UnrollSpecification(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, ve
731748
end
732749

733750
"""
751+
looplengthprod(ls::LoopSet)
752+
734753
Convert to `Float64` for the sake of non-64 bit platforms.
735754
"""
736755
looplengthprod(ls::LoopSet) = prod(Float64 length, ls.loops)
737756

757+
758+
function looplength(ls::LoopSet, s::Symbol)
759+
# search_tree(parents(operations(ls)[i]), name(op)) && return true
760+
id = getloopid_or_nothing(ls, s)
761+
if isnothing(id)
762+
l = 0.0
763+
# TODO: we could double count a loop.
764+
for op operations(ls)
765+
name(op) === s || continue
766+
for opp parents(op)
767+
if isloopvalue(opp)
768+
oppname = first(loopdependencies(opp))
769+
l += looplength(ls, oppname)
770+
elseif iscompute(opp)
771+
oppname = name(opp)
772+
l += looplength(ls, oppname)
773+
# TODO elseif isconstant(opp)
774+
end
775+
end
776+
l += 1 - length(parents(op))
777+
end
778+
l
779+
else
780+
Float64(length(ls.loops[id]))
781+
end
782+
end
783+
738784
# function getunrolled(ls::LoopSet)
739785
# order = names(ls)
740786
# us = ls.unrollspecification[]

src/lower_load.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
8787
innermostloopindv = findall(map(isequal(innermostloopsym), getindices(op)))
8888
isone(length(innermostloopindv)) || return 0
8989
innermostloopind = first(innermostloopindv)
90-
if prod(s -> Float64(length(getloop(ls, s))), @view(indices[1:innermostloopind-1])) 120.0 && length(getloop(ls, innermostloopsym)) 120
90+
if prod(s -> Float64(looplength(ls, s)), @view(indices[1:innermostloopind-1])) 120.0 && length(getloop(ls, innermostloopsym)) 120
9191
if op.ref.ref.offsets[innermostloopind] < 120
9292
for opp operations(ls)
9393
iscompute(opp) && (innermostloopsym loopdependencies(opp)) && load_constrained(opp, u₁loopsym, u₂loopsym, innermostloopsym, true) && return 0

src/lowering.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ end
671671
function lower_unrollspec(ls::LoopSet)
672672
us = ls.unrollspecification[]
673673
@unpack vectorizedloopnum, u₁, u₂ = us
674+
# @show u₁, u₂
674675
order = names(ls)
675676
vectorized = order[vectorizedloopnum]
676677
setup_preamble!(ls, us)

src/operations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ function Base.show(io::IO, op::Operation)
242242
ref = Expr(:ref, name(op.ref)); append!(ref.args, getindices(op))
243243
print(io, Expr(:(=), ref, name(first(parents(op)))))
244244
elseif isloopvalue(op)
245-
print(io, Expr(:(=), op.variable, op.variable))
245+
print(io, Expr(:(=), op.variable, first(loopdependencies(op))))
246246
end
247247
end
248248

src/vectorizationbase_extensions.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ end
2323
Boff.offsets
2424
)
2525
end
26+
27+
@inline function VectorizationBase.stridedpointer_for_broadcast(A::OffsetArrays.OffsetArray)
28+
OffsetStridedPointer(
29+
VectorizationBase.stridedpointer_for_broadcast(parent(A)),
30+
VectorizationBase.staticmul(T, VectorizationBase.filter_strides_by_dimequal1(Base.tail(size(A)), VectorizationBase.tailstrides(A)))
31+
)
32+
end
2633
@inline function Base.transpose(A::OffsetStridedPointer)
2734
OffsetStridedPointer(
2835
transpose(A.ptr), A.offsets

0 commit comments

Comments
 (0)