Skip to content

Commit 4c91232

Browse files
committed
brrr -> vectorize, determine outer reduct type from init type
1 parent c4b7d1e commit 4c91232

11 files changed

+224
-168
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.19"
4+
version = "0.12.20"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using Requires
5151

5252

5353
export LowDimArray, stridedpointer, indices,
54-
@avx, @avxt, @brrr, @tbrrr, *ˡ, _avx_!,
54+
@avx, @avxt, @vectorize, @tvectorize, *ˡ, _avx_!,
5555
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
5656
tanh_fast, sigmoid_fast,
5757
vfilter, vfilter!, vmapreduce, vreduce

src/codegen/lowering.jl

Lines changed: 38 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -426,22 +426,23 @@ end
426426

427427
function outer_reduction_zero(op::Operation, u₁u::Bool, Umax::Int, reduct_class::Float64, rs::Expr)
428428
reduct_zero = reduction_zero(reduct_class)
429+
Tsym = outer_reduct_init_typename(op)
429430
if isvectorized(op)
430431
if Umax == 1 || !u₁u
431432
if reduct_zero === :zero
432-
Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, ELTYPESYMBOL, rs)
433+
Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, Tsym, rs)
433434
else
434-
Expr(:call, lv(:_vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, reduct_zero, ELTYPESYMBOL), rs)
435+
Expr(:call, lv(:_vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, reduct_zero, Tsym), rs)
435436
end
436437
else
437438
if reduct_zero === :zero
438-
Expr(:call, lv(:zero_vecunroll), staticexpr(Umax), VECTORWIDTHSYMBOL, ELTYPESYMBOL, rs)
439+
Expr(:call, lv(:zero_vecunroll), staticexpr(Umax), VECTORWIDTHSYMBOL, Tsym, rs)
439440
else
440-
Expr(:call, lv(:vbroadcast_vecunroll), staticexpr(Umax), VECTORWIDTHSYMBOL, Expr(:call, reduct_zero, ELTYPESYMBOL), rs)
441+
Expr(:call, lv(:vbroadcast_vecunroll), staticexpr(Umax), VECTORWIDTHSYMBOL, Expr(:call, reduct_zero, Tsym), rs)
441442
end
442443
end
443444
else
444-
Expr(:call, reduct_zero, ELTYPESYMBOL)
445+
Expr(:call, reduct_zero, Tsym)
445446
end
446447
end
447448
# function outer_reduction_name(mvar::Symbol, _Umax::Int, u₂::Int, u₁u::Bool)
@@ -614,68 +615,37 @@ function gc_preserve(ls::LoopSet, q::Expr)
614615
# Expr(:block, gcp)
615616
end
616617

617-
function typeof_outer_reduction_init(ls::LoopSet, op::Operation)
618-
opid = identifier(op)
619-
for (id, sym) ∈ ls.preamble_symsym
620-
opid == id && return Expr(:call, :typeof, sym)
621-
end
622-
for (id,(intval,intsz,signed)) ∈ ls.preamble_symint
623-
if opid == id
624-
if intsz == 1
625-
return :Bool
626-
elseif signed
627-
return Symbol(:Int,intsz)
628-
else
629-
return Symbol(:UInt,intsz)
630-
end
631-
end
632-
end
633-
for (id,floatval) ∈ ls.preamble_symfloat
634-
opid == id && return :Float64
635-
end
636-
for (id,typ) ∈ ls.preamble_zeros
637-
instruction(ops[id]) === LOOPCONSTANT || continue
638-
opid == id || continue
639-
if typ == IntOrFloat
640-
return :Float64
641-
elseif typ == HardInt
642-
return :Int
643-
else#if typ == HardFloat
644-
return :Float64
645-
end
646-
end
647-
throw("Could not find initializing constant.")
648-
end
649-
function typeof_outer_reduction(ls::LoopSet, op::Operation)
650-
for opp ∈ operations(ls)
651-
opp === op && continue
652-
name(op) === name(opp) && return typeof_outer_reduction_init(ls, opp)
653-
end
654-
throw("Could not find initialization op.")
655-
end
656-
657-
function determine_eltype(ls::LoopSet)::Union{Symbol,Expr}
658-
if length(ls.includedactualarrays) == 0
659-
if length(ls.outer_reductions) == 0
660-
return Expr(:call, lv(:typeof), 0)
661-
elseif length(ls.outer_reductions) == 1
662-
op = ls.operations[only(ls.outer_reductions)]
663-
return typeof_outer_reduction(ls, op)
664-
else
665-
pt = Expr(:call, lv(:promote_type))
666-
for j ∈ ls.outer_reductions
667-
push!(pt.args, typeof_outer_reduction(ls, ls.operations[j]))
668-
end
669-
return pt
670-
end
671-
elseif length(ls.includedactualarrays) == 1
672-
return Expr(:call, lv(:eltype), first(ls.includedactualarrays))
618+
function determine_eltype(ls::LoopSet, ortypdefined::Bool)::Union{Symbol,Expr}
619+
narrays = length(ls.includedactualarrays)
620+
noreduc = length(ls.outer_reductions)
621+
ntyp = narrays + noreduc
622+
if ntyp == 0
623+
return Expr(:call, lv(:typeof), 0)
624+
elseif ntyp == 1
625+
if narrays == 1
626+
return Expr(:call, lv(:eltype), first(ls.includedactualarrays))
627+
else
628+
oreducop = ls.operations[only(ls.outer_reductions)]
629+
if ortypdefined
630+
return typeof_expr(oreducop)
631+
else
632+
return outer_reduct_init_typename(oreducop)
633+
end
673634
end
674-
promote_q = Expr(:call, lv(:promote_type))
675-
for array ∈ ls.includedactualarrays
676-
push!(promote_q.args, Expr(:call, lv(:eltype), array))
635+
end
636+
pt = Expr(:call, lv(:promote_type))
637+
for array ∈ ls.includedactualarrays
638+
push!(pt.args, Expr(:call, lv(:eltype), array))
639+
end
640+
for j ∈ ls.outer_reductions
641+
oreducop = ls.operations[j]
642+
if ortypdefined
643+
push!(pt.args, typeof_expr(oreducop))
644+
else
645+
push!(pt.args, outer_reduct_init_typename(oreducop))
677646
end
678-
promote_q
647+
end
648+
return pt
679649
end
680650
@inline _eltype(x) = eltype(x)
681651
@inline _eltype(::BitArray) = VectorizationBase.Bit
@@ -752,8 +722,8 @@ function definemask(loop::Loop)
752722
maskexpr(addexpr(lenexpr, 1))
753723
end
754724
end
755-
function define_eltype_vec_width!(q::Expr, ls::LoopSet, vectorized)
756-
push!(q.args, Expr(:(=), ELTYPESYMBOL, determine_eltype(ls)))
725+
function define_eltype_vec_width!(q::Expr, ls::LoopSet, vectorized, ortypdefined::Bool)
726+
push!(q.args, Expr(:(=), ELTYPESYMBOL, determine_eltype(ls, ortypdefined)))
757727
push!(q.args, Expr(:(=), VECTORWIDTHSYMBOL, determine_width(ls, vectorized)))
758728
nothing
759729
end
@@ -764,7 +734,7 @@ function setup_preamble!(ls::LoopSet, us::UnrollSpecification, Ureduct::Int)
764734
u₂loopsym = order[u₂loopnum]
765735
vectorized = order[vloopnum]
766736
set_vector_width!(ls, vectorized)
767-
iszero(length(ls.includedactualarrays) + length(ls.outer_reductions)) || define_eltype_vec_width!(ls.preamble, ls, vectorized)
737+
iszero(length(ls.includedactualarrays) + length(ls.outer_reductions)) || define_eltype_vec_width!(ls.preamble, ls, vectorized, false)
768738
lower_licm_constants!(ls)
769739
isone(num_loops(ls)) || pushpreamble!(ls, definemask(getloop(ls, vectorized)))#, u₁ > 1 && u₁loopnum == vloopnum))
770740
if (Ureduct == u₁) || (u₂ != -1) || (Ureduct == -1)

src/condense_loopset.jl

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Base.:|(u::Unsigned, it::IndexType) = u | UInt8(it)
44
Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it)
55

6+
struct StaticType{T} end
7+
@inline gettype(::StaticType{T}) where {T} = T
68
function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
79
gf = GlobalRef(Core,:getfield)
810
for f ∈ 1:fieldcount(T)
@@ -13,7 +15,7 @@ function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
1315
elseif fieldcount(TF) ≡ 0
1416
push!(t.args, gfcall)
1517
elseif TF <: DataType
16-
push!(t.args, Expr(:call, GlobalRef(Base,:Val), gfcall))
18+
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticType), gfcall)))
1719
else
1820
newsym = gensym(sym)
1921
push!(body.args, Expr(:(=), newsym, gfcall))
@@ -31,7 +33,7 @@ end
3133
elseif fieldcount(T) ≡ 0
3234
push!(t.args, :r)
3335
elseif T <: DataType
34-
push!(t.args, Expr(:call, GlobalRef(Base,:Val), :r))
36+
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticType), :r)))
3537
else
3638
_append_fields!(t, body, :r, T)
3739
end
@@ -49,7 +51,7 @@ function rebuild_fields(offset::Int, ::Type{T}) where {T}
4951
elseif fieldcount(TF) ≡ 0
5052
push!(call.args, Expr(:call, gf, :t, (offset += 1), false))
5153
elseif TF <: DataType
52-
push!(call.args, Expr(:call, GlobalRef(VectorizationBase, :unwrap), Expr(:call, gf, :t, (offset += 1), false)))
54+
push!(call.args, Expr(:call, lv(:gettype), Expr(:call, gf, :t, (offset += 1), false)))
5355
else
5456
arg, offset = rebuild_fields(offset, TF)
5557
push!(call.args, arg)
@@ -63,7 +65,7 @@ end
6365
elseif fieldcount(T) ≡ 0
6466
call = Expr(:call, GlobalRef(Core,:getfield), :t, 1, false)
6567
elseif T <: DataType
66-
call = Expr(:call, GlobalRef(VectorizationBase, :unwrap), Expr(:call, GlobalRef(Core,:getfield), :t, 1, false))
68+
call = Expr(:call, lv(:gettype), Expr(:call, GlobalRef(Core,:getfield), :t, 1, false))
6769
else
6870
call, _ = rebuild_fields(0, T)
6971
end
@@ -196,16 +198,16 @@ function getroots(ls::LoopSet)::Vector{Bool}
196198
getroots!(rooted, ls)
197199
end
198200
function getroots!(rooted::Vector{Bool}, ls::LoopSet)
199-
fill!(rooted, false)
200-
ops = operations(ls)
201-
for or ∈ ls.outer_reductions
202-
recursively_set_parents_true!(rooted, ops[or])
203-
end
204-
for op ∈ ops
205-
isstore(op) && recursively_set_parents_true!(rooted, op)
206-
end
207-
length(ls.includedactualarrays) == 0 || remove_outer_reducts!(rooted, ls)
208-
return rooted
201+
fill!(rooted, false)
202+
ops = operations(ls)
203+
for or ∈ ls.outer_reductions
204+
recursively_set_parents_true!(rooted, ops[or])
205+
end
206+
for op ∈ ops
207+
isstore(op) && recursively_set_parents_true!(rooted, op)
208+
end
209+
remove_outer_reducts!(rooted, ls)
210+
return rooted
209211
end
210212
function OperationStruct!(varnames::Vector{Symbol}, ids::Vector{Int}, ls::LoopSet, op::Operation)
211213
instr = instruction(op)
@@ -635,49 +637,51 @@ function generate_call_split(
635637
end
636638
return generate_call_types(ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
637639
end
640+
638641
# Try to condense in type stable manner
639642
function generate_call_types(
640643
ls::LoopSet, preserve::Vector{Symbol}, shouldindbyind::Vector{Bool}, roots::Vector{Bool}, extra_args::Expr, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool
641644
)
642645
# good place to check for split
643-
operation_descriptions = Expr(:tuple)
644-
varnames = Symbol[]; ids = Vector{Int}(undef, length(operations(ls)))
645-
ops = operations(ls)
646-
for op ∈ ops
647-
instr::Instruction = instruction(op)
648-
if (isconstant(op) && (instr == LOOPCONSTANT)) && (!roots[identifier(op)])
649-
instr = op.instruction = DROPPEDCONSTANT
650-
end
651-
push!(operation_descriptions.args, QuoteNode(instr.mod))
652-
push!(operation_descriptions.args, QuoteNode(instr.instr))
653-
push!(operation_descriptions.args, OperationStruct!(varnames, ids, ls, op))
654-
end
655-
arraysymbolinds = Symbol[]
656-
arrayref_descriptions = Expr(:tuple)
657-
duplicate_ref = fill(false, length(ls.refs_aliasing_syms))
658-
for (j,ref) ∈ enumerate(ls.refs_aliasing_syms)
659-
vpref = vptr(ref)
660-
# duplicate_ref[j] ≠ 0 && continue
661-
duplicate_ref[j] && continue
662-
push!(arrayref_descriptions.args, ArrayRefStruct(ls, ref, arraysymbolinds, ids))
663-
end
664-
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
665-
loop_bounds = loop_boundaries(ls, shouldindbyind)
666-
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
667-
func = debug ? lv(:_avx_loopset_debug) : lv(:_avx_!)
668-
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
669-
configarg = (inline,u₁,u₂,ls.isbroadcast,thread)
670-
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
671-
q = Expr(:call, func, unroll_param_tup, val(operation_descriptions), val(arrayref_descriptions), val(argmeta), val(loop_syms))
672-
673-
add_reassigned_syms!(extra_args, ls) # counterpart to `add_ops!` constants
674-
for (opid,sym) ∈ ls.preamble_symsym # counterpart to process_metadata! symsym extraction
675-
if instruction(ops[opid]) ≢ DROPPEDCONSTANT
676-
push!(extra_args.args, sym)
677-
end
646+
operation_descriptions = Expr(:tuple)
647+
varnames = Symbol[]; ids = Vector{Int}(undef, length(operations(ls)))
648+
ops = operations(ls)
649+
for op ∈ ops
650+
instr::Instruction = instruction(op)
651+
if (isconstant(op) && (instr == LOOPCONSTANT)) && (!roots[identifier(op)])
652+
instr = op.instruction = DROPPEDCONSTANT
653+
end
654+
push!(operation_descriptions.args, QuoteNode(instr.mod))
655+
push!(operation_descriptions.args, QuoteNode(instr.instr))
656+
push!(operation_descriptions.args, OperationStruct!(varnames, ids, ls, op))
657+
end
658+
arraysymbolinds = Symbol[]
659+
arrayref_descriptions = Expr(:tuple)
660+
duplicate_ref = fill(false, length(ls.refs_aliasing_syms))
661+
for (j,ref) ∈ enumerate(ls.refs_aliasing_syms)
662+
vpref = vptr(ref)
663+
# duplicate_ref[j] ≠ 0 && continue
664+
duplicate_ref[j] && continue
665+
push!(arrayref_descriptions.args, ArrayRefStruct(ls, ref, arraysymbolinds, ids))
666+
end
667+
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
668+
loop_bounds = loop_boundaries(ls, shouldindbyind)
669+
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
670+
func = debug ? lv(:_avx_loopset_debug) : lv(:_avx_!)
671+
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
672+
configarg = (inline,u₁,u₂,ls.isbroadcast,thread)
673+
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
674+
q = Expr(:call, func, unroll_param_tup, val(operation_descriptions), val(arrayref_descriptions), val(argmeta), val(loop_syms))
675+
676+
add_reassigned_syms!(extra_args, ls) # counterpart to `add_ops!` constants
677+
for (opid,sym) ∈ ls.preamble_symsym # counterpart to process_metadata! symsym extraction
678+
if instruction(ops[opid]) ≢ DROPPEDCONSTANT
679+
push!(extra_args.args, sym)
678680
end
679-
append!(extra_args.args, arraysymbolinds) # add_array_symbols!
681+
end
682+
append!(extra_args.args, arraysymbolinds) # add_array_symbols!
680683
add_external_functions!(extra_args, ls) # extract_external_functions!
684+
add_outerreduct_types!(extra_args, ls) # extract_outerreduct_types!
681685
if debug
682686
vecwidthdefq = Expr(:block)
683687
push!(q.args, Expr(:tuple, lbarg, extra_args))
@@ -686,17 +690,22 @@ function generate_call_types(
686690
vecwidthdefq = Expr(:block, Expr(:(=), vargsym, Expr(:tuple, lbarg, extra_args)))
687691
push!(q.args, Expr(:call, GlobalRef(Base,:Val), Expr(:call, GlobalRef(Base,:typeof), vargsym)), Expr(:(...), Expr(:call, lv(:flatten_to_tuple), vargsym)))
688692
end
689-
define_eltype_vec_width!(vecwidthdefq, ls, nothing)
690-
push!(vecwidthdefq.args, q)
691-
if debug
692-
pushpreamble!(ls,vecwidthdefq)
693-
Expr(:block, ls.prepreamble, ls.preamble)
694-
else
695-
setup_call_final(ls, setup_outerreduct_preserve(ls, vecwidthdefq, preserve))
696-
end
693+
define_eltype_vec_width!(vecwidthdefq, ls, nothing, true)
694+
push!(vecwidthdefq.args, q)
695+
if debug
696+
pushpreamble!(ls,vecwidthdefq)
697+
Expr(:block, ls.prepreamble, ls.preamble)
698+
else
699+
setup_call_final(ls, setup_outerreduct_preserve(ls, vecwidthdefq, preserve))
700+
end
701+
end
702+
# @inline reductinittype(::T) where {T} = StaticType{T}()
703+
typeof_expr(op::Operation) = Expr(:call, GlobalRef(Base,:typeof), name(op))
704+
function add_outerreduct_types!(extra_args::Expr, ls::LoopSet) # extract_outerreduct_types!
705+
for or ∈ ls.outer_reductions
706+
push!(extra_args.args, typeof_expr(operations(ls)[or]))
707+
end
697708
end
698-
699-
700709
"""
701710
check_args(::Vararg{AbstractArray})
702711

0 commit comments

Comments
 (0)