Skip to content

Commit e522dc2

Browse files
committed
Support threaded broadcasts
1 parent f2a48c3 commit e522dc2

File tree

11 files changed

+201
-195
lines changed

11 files changed

+201
-195
lines changed

src/broadcast.jl

Lines changed: 159 additions & 151 deletions
Large diffs are not rendered by default.

src/codegen/lower_threads.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ function define_block_size(threadedloop, vloop, tn, W)
322322
end
323323
function thread_one_loops_expr(
324324
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
325-
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
325+
UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
326326
)
327327
looplen = looplengthprod(ls)
328328
c = 0.05460264079015985 * c / looplen
@@ -440,7 +440,7 @@ function define_thread_blocks(threadedloop1, threadedloop2, vloop, u₁loop, u
440440
end
441441
function thread_two_loops_expr(
442442
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
443-
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
443+
UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
444444
)
445445
looplen = looplengthprod(ls)
446446
c = 0.05460264079015985 * c / looplen
@@ -608,7 +608,7 @@ function valid_thread_loops(ls::LoopSet)
608608
valid_thread_loop, ua, c
609609
end
610610
function avx_threads_expr(
611-
ls::LoopSet, UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt},
611+
ls::LoopSet, UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int,Int,Int,Int,UInt},
612612
nt::UInt, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
613613
)
614614
valid_thread_loop, ua, c = valid_thread_loops(ls)

src/condense_loopset.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ function shifted_loopset(ls::LoopSet, loopsyms::Vector{Symbol})
105105
end
106106
ld
107107
end
108+
# loopdeps_uint(ls::LoopSet, op::Operation) = (@show op; shifted_loopset(ls, loopdependencies(op)))
108109
loopdeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, loopdependencies(op))
109110
reduceddeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, reduceddependencies(op))
110111
childdeps_uint(ls::LoopSet, op::Operation) = shifted_loopset(ls, reducedchildren(op))
@@ -299,18 +300,19 @@ end
299300
# first_cache_size() = _first_cache_size(cache_size(first_cache()))
300301

301302
@generated function _avx_config_val(
302-
::Val{inline}, ::Val{u₁}, ::Val{u₂}, ::Val{thread}, ::StaticInt{W},
303-
::StaticInt{RS}, ::StaticInt{AR}, ::StaticInt{NT},
303+
::Val{CNFARG}, ::StaticInt{W}, ::StaticInt{RS}, ::StaticInt{AR}, ::StaticInt{NT},
304304
::StaticInt{CLS}, ::StaticInt{L1}, ::StaticInt{L2}, ::StaticInt{L3}
305-
) where {inline,u₁,u₂,thread,W,RS,AR,CLS,L1,L2,L3,NT}
305+
) where {CNFARG,W,RS,AR,CLS,L1,L2,L3,NT}
306+
inline,u₁,u₂,BROADCAST,thread = CNFARG
306307
nt = min(thread % UInt, NT % UInt)
307-
t = Expr(:tuple, inline, u₁, u₂, W, RS, AR, CLS, L1,L2,L3, nt)
308+
t = Expr(:tuple, inline, u₁, u₂, BROADCAST, W, RS, AR, CLS, L1,L2,L3, nt)
308309
Expr(:call, Expr(:curly, :Val, t))
309310
end
310-
@inline function avx_config_val(::Val{inline}, ::Val{u₁}, ::Val{u₂}, ::Val{thread}, ::StaticInt{W}) where {inline,u₁,u₂,thread,W}
311+
@inline function avx_config_val(
312+
::Val{CNFARG}, ::StaticInt{W}
313+
) where {CNFARG,W}
311314
_avx_config_val(
312-
Val{inline}(), Val{u₁}(), Val{u₂}(), Val{thread}(), StaticInt{W}(),
313-
register_size(), available_registers(), lv_max_num_threads(),
315+
Val{CNFARG}(), StaticInt{W}(), register_size(), available_registers(), lv_max_num_threads(),
314316
cache_linesize(), cache_size(StaticInt(1)), cache_size(StaticInt(2)), cache_size(StaticInt(3))
315317
)
316318
end
@@ -338,7 +340,8 @@ function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, t
338340
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
339341
func = debug ? lv(:_avx_loopset_debug) : lv(:_avx_!)
340342
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
341-
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$inline}()), :(Val{$u₁}()), :(Val{$u₂}()), :(Val{$thread}()), VECTORWIDTHSYMBOL)
343+
configarg = (inline,u₁,u₂,ls.isbroadcast[],thread)
344+
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
342345
q = Expr(:call, func, unroll_param_tup, val(operation_descriptions), val(arrayref_descriptions), val(argmeta), val(loop_syms))
343346
# debug && deleteat!(q.args, 2)
344347
vargs_as_tuple = true#!debug

src/constructors.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,20 @@ end
3838
function substitute_broadcast(q::Expr, mod::Symbol, inline, u₁, u₂, threads)
3939
ci = first(Meta.lower(LoopVectorization, q).args).code
4040
nargs = length(ci)-1
41-
ex = Expr(:block,)
41+
ex = Expr(:block)
4242
syms = [gensym() for _ ∈ 1:nargs]
43-
valarg = Expr(:call, lv(:avx_config_val), :(Val{$inline}()), :(Val{$u₁}()), :(Val{$u₂}()), :(Val{$threads}()), staticexpr(0))
43+
configarg = (inline,u₁,u₂,true,threads)
44+
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
4445
for n ∈ 1:nargs
4546
ciₙ = ci[n]
4647
ciₙargs = ciₙ.args
4748
f = first(ciₙargs)
4849
if ciₙ.head === :(=)
4950
push!(ex.args, Expr(:(=), f, syms[((ciₙargs[2])::Core.SSAValue).id]))
5051
elseif isglobalref(f, Base, :materialize!)
51-
add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n, valarg, mod)
52+
add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n, unroll_param_tup, mod)
5253
elseif isglobalref(f, Base, :materialize)
53-
add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n, valarg, mod)
54+
add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n, unroll_param_tup, mod)
5455
else
5556
add_ci_call!(ex, f, ciₙargs, syms, n)
5657
end

src/modeling/graphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ function LoopSet(mod::Symbol)
558558
Matrix{Float64}(undef, 4, 2), # reg_pres
559559
Bool[], Bool[], Ref{UnrollSpecification}(),
560560
Ref(false), Ref{LoopStartStopManager}(),
561-
Ref(0), Ref(0), Ref(false),
561+
Ref(0), Ref(0), Ref(false), # vector width, sym counter, isbroadcast
562562
Ref(0), Ref(0), Ref(0), Ref((0,0,0)),# hw params
563563
Ref(-1), # Ureduct
564564
Tuple{Vector{Symbol},Vector{Int}}[],

src/modeling/operations.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ struct ArrayReference
2323
offsets::Vector{Int8}
2424
strides::Vector{Int8}
2525
end
26-
ArrayReference(array, indices) = ArrayReference(array, indices, zeros(Int8, length(indices)), ones(Int8, length(indices)))
26+
function ArrayReference(array, indices)
27+
ninds = length(indices)
28+
if ninds > 0
29+
ninds -= (first(indices) === DISCONTIGUOUS)
30+
end
31+
ArrayReference(array, indices, zeros(Int8, ninds), ones(Int8, ninds))
32+
end
2733
function sameref(x::ArrayReference, y::ArrayReference)
2834
(x.array === y.array) && (x.indices == y.indices)
2935
end
@@ -330,7 +336,7 @@ function Operation(id::Int, var::Symbol, elementbytes::Int, instr, optype::Opera
330336
end
331337
Base.:(==)(x::ArrayReferenceMetaPosition, y::ArrayReferenceMetaPosition) = x.mref == y.mref
332338
# Avoid memory allocations by using this for ops that aren't references
333-
const NOTAREFERENCE = ArrayReferenceMeta(ArrayReference(Symbol(""), Union{Symbol,Int}[]),Bool[],Symbol(""))
339+
const NOTAREFERENCE = ArrayReferenceMeta(ArrayReference(Symbol(""), Symbol[]),Bool[],Symbol(""))
334340
const NOTAREFERENCEMP = ArrayReferenceMetaPosition(NOTAREFERENCE, NOPARENTS, Symbol[], Symbol[],Symbol(""))
335341
varname(::Nothing) = nothing
336342
varname(mpref::ArrayReferenceMetaPosition) = mpref.varname

src/parse/add_loads.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function maybeaddref!(ls::LoopSet, op)
1+
function maybeaddref!(ls::LoopSet, op::Operation)
22
ref = op.ref
33
id = findfirst(==(ref), ls.refs_aliasing_syms)
44
# try to CSE
@@ -40,10 +40,8 @@ function add_simple_load!(
4040
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int,
4141
actualarray::Bool = true, broadcast::Bool = false
4242
)
43-
loopdeps = Symbol[s for s ∈ ref.indices]
44-
mref = ArrayReferenceMeta(
45-
ref, fill(true, length(loopdeps) - isdiscontiguous(ref))
46-
)
43+
loopdeps = copy(getindicesonly(ref))
44+
mref = ArrayReferenceMeta(ref, fill(true, length(loopdeps)))
4745
add_simple_load!(ls, var, mref, loopdeps, elementbytes, actualarray, broadcast)
4846
end
4947
function add_simple_load!(

src/parse/memory_ops_common.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,11 @@ function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol, actualarray::B
3737
if !includesarray(ls, array)
3838
push!(ls.includedarrays, array)
3939
actualarray && push!(ls.includedactualarrays, array)
40-
func = lv(broadcast ? :stridedpointer_for_broadcast : :stridedpointer)
41-
pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, func, array)))
42-
# if broadcast
43-
# pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, lv(:stridedpointer_for_broadcast), array)))
44-
# else
45-
# pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, lv(:stridedpointer), array)))
46-
# # pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, lv(:noaliasstridedpointer), array)))
47-
# end
40+
broadcast || pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, lv(:stridedpointer), array)))
4841
end
4942
nothing
5043
end
5144

52-
# @inline valsum() = Val{0}()
53-
# @inline valsum(::Val{M}) where {M} = Val{M}()
54-
# @generated valsum(::Val{M}, ::Val{N}) where {M,N} = Val{M+N}()
55-
# @inline valsum(::Val{M}, ::Val{N}, ::Val{K}, args...) where {M,N,K} = valsum(valsum(Val{M}(), Val{N}()), Val{K}(), args...)
5645
@inline staticdims(::Any) = One()
5746
@inline staticdims(::CartesianIndices{N}) where {N} = StaticInt{N}()
5847

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function _precompile_()
22
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
33

44
Base.precompile(Tuple{Type{ArrayRefStruct},LoopSet,ArrayReferenceMeta,Vector{Symbol},Vector{Int64}})
5-
Base.precompile(Tuple{typeof(_avx_loopset),Any,Any,Any,Any,Core.SimpleVector,Core.SimpleVector,Tuple{Bool, Int8, Int8, Int64, Int64, Int64, Int64, Int64, Int64, Int64, UInt64}})
5+
Base.precompile(Tuple{typeof(_avx_loopset),Any,Any,Any,Any,Core.SimpleVector,Core.SimpleVector,Tuple{Bool, Int8, Int8, Bool, Int64, Int64, Int64, Int64, Int64, Int64, Int64, UInt64}})
66
Base.precompile(Tuple{typeof(add_broadcast!),LoopSet,Symbol,Symbol,Vector{Symbol},Type{SubArray{Float32, 2, Array{Float32, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}},Int64})
77
Base.precompile(Tuple{typeof(add_broadcast!),LoopSet,Symbol,Symbol,Vector{Symbol},Type{SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}},Int64})
88
Base.precompile(Tuple{typeof(add_broadcast!),LoopSet,Symbol,Symbol,Vector{Symbol},Type{SubArray{Int32, 2, Array{Int32, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}},Int64})

src/reconstruct_loopset.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -559,9 +559,10 @@ function avx_loopset!(
559559
num_params = extract_external_functions!(ls, num_params, vargs)
560560
ls
561561
end
562-
function avx_body(ls::LoopSet, UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt})
563-
inline, u₁, u₂, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
562+
function avx_body(ls::LoopSet, UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int,Int,Int,Int,UInt})
563+
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
564564
q = iszero(u₁) ? lower_and_split_loops(ls, inline % Int) : lower(ls, u₁ % Int, u₂ % Int, inline % Int)
565+
ls.isbroadcast[] = isbroadcast
565566
iszero(length(ls.outer_reductions)) ? push!(q.args, nothing) : push!(q.args, loopset_return_value(ls, Val(true)))
566567
q
567568
end
@@ -584,14 +585,14 @@ function tovector(@nospecialize(t))
584585
end
585586
function _avx_loopset(
586587
@nospecialize(OPSsv), @nospecialize(ARFsv), @nospecialize(AMsv), @nospecialize(LPSYMsv), LBsv::Core.SimpleVector, vargs::Core.SimpleVector,
587-
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}
588+
UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int,Int,Int,Int,UInt}
588589
)
589590
nops = length(OPSsv) ÷ 3
590591
instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i ∈ 0:nops-1]
591592
ops = OperationStruct[ OPSsv[3i] for i ∈ 1:nops ]
592593
ls = LoopSet(:LoopVectorization)
593-
inline, u₁, u₂, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
594-
set_hw!(ls, rs, rc, cls, l1, l2, l3); ls.vector_width[] = W
594+
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
595+
set_hw!(ls, rs, rc, cls, l1, l2, l3); ls.vector_width[] = W; ls.isbroadcast[] = isbroadcast
595596
avx_loopset!(
596597
ls, instr, ops,
597598
ArrayRefStruct[ARFsv...],
@@ -626,10 +627,10 @@ Execute an `@avx` block. The block's code is represented via the arguments:
626627
ls = _avx_loopset(OPS, ARF, AM, LPSYM, LB.parameters, V.parameters, UNROLL)
627628
# return @show avx_body(ls, UNROLL)
628629
if last(UNROLL) > 1
629-
inline, u₁, u₂, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
630+
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
630631
# wrap in `OPS, ARF, AM, LPSYM` in `Expr` to homogenize types
631632
avx_threads_expr(
632-
ls, (inline, u₁, u₂, W, rs, rc, cls, l1, l2, l3, one(UInt)), nt,
633+
ls, (inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, one(UInt)), nt,
633634
:(Val{$OPS}()), :(Val{$ARF}()), :(Val{$AM}()), :(Val{$LPSYM}())
634635
)
635636
else

0 commit comments

Comments
 (0)