Skip to content

Commit b090052

Browse files
committed
Broadcasting cannot use the new lowering approach. 32-bit builds only have 8 floating point registers available. Imrpove lowering with LLVM < 10 (10 was good already).
1 parent 51b81f2 commit b090052

13 files changed

+125
-114
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.8.0"
4+
version = "0.8.1"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -15,10 +15,10 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1515
[compat]
1616
DocStringExtensions = "0.8"
1717
OffsetArrays = "1"
18-
SIMDPirates = "0.8.3"
18+
SIMDPirates = "0.8.4"
1919
SLEEFPirates = "0.5"
2020
UnPack = "0,1"
21-
VectorizationBase = "0.12.1"
21+
VectorizationBase = "0.12.2"
2222
julia = "1.1"
2323

2424
[extras]

src/LoopVectorization.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LoopVectorization
22

33
using VectorizationBase, SIMDPirates, SLEEFPirates, UnPack, OffsetArrays
4-
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
4+
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valmulsub, valadd, valsub, _MM,
66
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, Zero, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange, unwrap, maybestaticrange,
@@ -29,6 +29,17 @@ export LowDimArray, stridedpointer,
2929

3030
const VECTORWIDTHSYMBOL, ELTYPESYMBOL = Symbol("##Wvecwidth##"), Symbol("##Tloopeltype##")
3131

32+
"""
33+
REGISTER_COUNT defined in VectorizationBase is supposed to correspond to the actual number of floating point registers on the system.
34+
It is hardcoded into a file at build time.
35+
However, someone may have multiple builds of Julia on the same system, some 32-bit and some 64-bit (e.g., they use 64-bit primarilly,
36+
but keep a 32-bit build on hand to debug test failures on Appveyor's 32-bit build). Thus, we don't want REGISTER_COUNT to be hardcoded
37+
in such a fashion.
38+
32-bit builds are limited to only 8 floating point registers, so we take care of that here.
39+
40+
If you want good performance, DO NOT use a 32-bit build of Julia if you don't have to.
41+
"""
42+
const REGISTER_COUNT = Sys.ARCH === :i686 ? 8 : VectorizationBase.REGISTER_COUNT
3243

3344
include("vectorizationbase_extensions.jl")
3445
include("predicates.jl")

src/broadcast.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ end
245245
# @show typeof(dest)
246246
loopsyms = [gensym(:n) for n 1:N]
247247
ls = LoopSet(Mod)
248+
ls.isbroadcast[] = true
248249
sizes = Expr(:tuple)
249250
for (n,itersym) enumerate(loopsyms)
250251
Nsym = gensym(:N)
@@ -271,6 +272,7 @@ end
271272
# need to construct the LoopSet
272273
loopsyms = [gensym(:n) for n 1:N]
273274
ls = LoopSet(Mod)
275+
ls.isbroadcast[] = true
274276
pushpreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
275277
sizes = Expr(:tuple)
276278
for (n,itersym) enumerate(loopsyms)

src/costs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function vector_cost(ic::InstructionCost, Wshift, sizeof_T)
9191
srt, sl, srp
9292
end
9393

94-
const OPAQUE_INSTRUCTION = InstructionCost(-1.0, 50, 50.0, VectorizationBase.REGISTER_COUNT)
94+
const OPAQUE_INSTRUCTION = InstructionCost(-1.0, 50, 50.0, REGISTER_COUNT)
9595

9696
instruction_cost(instruction::Instruction) = instruction.mod === :LoopVectorization ? COST[instruction.instr] : OPAQUE_INSTRUCTION
9797
instruction_cost(instruction::Symbol) = get(COST, instruction, OPAQUE_INSTRUCTION)

src/determinestrategy.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ function solve_unroll(
370370
W::Int, vectorized::Symbol,
371371
u₁loop::Loop, u₂loop::Loop
372372
)
373-
maxu₂base = maxu₁base = VectorizationBase.REGISTER_COUNT == 32 ? 10 : 6#8
373+
maxu₂base = maxu₁base = REGISTER_COUNT == 32 ? 10 : 6#8
374374
maxu₂ = maxu₂base#8
375375
maxu₁ = maxu₁base#8
376376
u₁L = length(u₁loop)
@@ -535,13 +535,13 @@ function load_elimination_cost_factor!(
535535
# if isstaticloop(loop) && length(loop) ≤ 4
536536
# itersym = loop.itersymbol
537537
# if itersym !== u₁loopsym && itersym !== u₂loopsym
538-
# return (0.25, VectorizationBase.REGISTER_COUNT == 32 ? 2.0 : 1.0)
538+
# return (0.25, REGISTER_COUNT == 32 ? 2.0 : 1.0)
539539
# # return (0.25, 1.0)
540540
# return true
541541
# end
542542
# end
543543
# end
544-
# # (0.25, VectorizationBase.REGISTER_COUNT == 32 ? 1.2 : 1.0)
544+
# # (0.25, REGISTER_COUNT == 32 ? 1.2 : 1.0)
545545
# (0.25, 1.0)
546546
cost_vec[1] += 0.1rt
547547
reg_pressure[1] += 0.51rp
@@ -707,7 +707,7 @@ function evaluate_cost_tile(
707707
end
708708
end
709709
# @show cost_vec reg_pressure
710-
costpenalty = (sum(reg_pressure) > VectorizationBase.REGISTER_COUNT) ? 2 : 1
710+
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
711711
# @show order, vectorized cost_vec reg_pressure
712712
# @show solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure)
713713
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized)

src/graphs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ struct LoopSet
266266
unrollspecification::Base.RefValue{UnrollSpecification}
267267
loadelimination::Base.RefValue{Bool}
268268
lssm::Base.RefValue{LoopStartStopManager}
269+
isbroadcast::Base.RefValue{Bool}
269270
mod::Symbol
270271
end
271272

@@ -353,7 +354,8 @@ function LoopSet(mod::Symbol)
353354
Matrix{Float64}(undef, 4, 2),
354355
Matrix{Float64}(undef, 5, 2),
355356
Bool[], Bool[], Ref{UnrollSpecification}(),
356-
Ref(false), Ref{LoopStartStopManager}(), mod
357+
Ref(false), Ref{LoopStartStopManager}(),
358+
Ref(false), mod
357359
)
358360
end
359361

src/loopstartstopmanager.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,31 @@ function uniquearrayrefs(ls::LoopSet)
1212
uniquerefs
1313
end
1414

15-
otherindexunrolled(loopsym::Symbol, ind::Symbol, loopdeps::Vector{Symbol}) = loopsym !== ind && loopsym loopdeps
15+
otherindexunrolled(loopsym::Symbol, ind::Symbol, loopdeps::Vector{Symbol}) = (loopsym !== ind) && (loopsym loopdeps)
1616
function otherindexunrolled(ls::LoopSet, ind::Symbol, ref::ArrayReferenceMeta)
1717
us = ls.unrollspecification[]
18-
otherindexunrolled(getloopsym(ls, us.u₁loopnum), ind, loopdependencies(ref)) || otherindexunrolled(getloopsym(ls, us.u₂loopnum), ind, loopdependencies(ref))
18+
u₁sym = names(ls)[us.u₁loopnum]
19+
u₂sym = us.u₂loopnum > 0 ? names(ls)[us.u₂loopnum] : Symbol("##undefined##")
20+
otherindexunrolled(u₁sym, ind, loopdependencies(ref)) || otherindexunrolled(u₂sym, ind, loopdependencies(ref))
1921
end
2022
multiple_with_name(n::Symbol, v::Vector{ArrayReferenceMeta}) = sum(ref -> n === vptr(ref), v) > 1
2123
# TODO: DRY between indices_calculated_by_pointer_offsets and use_loop_induct_var
2224
function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMeta)
23-
looporder = names(ls)
2425
indices = getindices(ar)
26+
ls.isbroadcast[] && return fill(false, length(indices))
27+
looporder = names(ls)
2528
offset = isdiscontiguous(ar)
2629
gespinds = Expr(:tuple)
2730
out = Vector{Bool}(undef, length(indices))
2831
li = ar.loopedindex
2932
for i eachindex(li)
3033
ii = i + offset
3134
ind = indices[ii]
32-
j = findfirst(isequal(ind), view(indices, offset+1:ii-1))
33-
if !isnothing(j)
34-
out[i] = out[j - offset]
35-
continue
36-
end
35+
# j = findfirst(isequal(ind), view(indices, offset+1:ii-1))
36+
# if !isnothing(j)
37+
# out[i] = out[j - offset]
38+
# continue
39+
# end
3740
if (!li[i]) || multiple_with_name(vptr(ar), ls.lssm[].uniquearrayrefs)
3841
out[i] = false
3942
elseif (isone(ii) && (first(looporder) === ind))
@@ -61,19 +64,21 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
6164
println(ar)
6265
throw("Length of indices and length of offset do not match!")
6366
end
67+
isbroadcast = ls.isbroadcast[]
6468
gespinds = Expr(:tuple)
6569
for i eachindex(li)
6670
ii = i + offset
6771
ind = indices[ii]
68-
j = findfirst(isequal(ind), view(indices, offset+1:ii-1))
69-
if !isnothing(j)
70-
j -= offset
71-
push!(gespinds.args, gespinds.args[j])
72-
uliv[i] = uliv[j]
73-
elseif (!li[i])
72+
# j = findfirst(isequal(ind), view(indices, offset+1:ii-1))
73+
# if !isnothing(j)
74+
# j -= offset
75+
# push!(gespinds.args, gespinds.args[j])
76+
# uliv[i] = uliv[j]
77+
# else
78+
if (!li[i])
7479
uliv[i] = 0
7580
push!(gespinds.args, Expr(:call, lv(:Zero)))
76-
elseif (isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) || multiple_with_name(vptr(ar), allarrayrefs)
81+
elseif isbroadcast || ((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) || multiple_with_name(vptr(ar), allarrayrefs))
7782
uliv[i] = -findfirst(isequal(ind), looporder)::Int
7883
push!(gespinds.args, Expr(:call, lv(:Zero)))
7984
else

src/lower_memory_common.jl

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,12 @@ function addoffset!(ret::Expr, ex, offset::Integer, _mm::Bool = false)
3939
nothing
4040
end
4141
function addoffset!(ret::Expr, offset::Int, _mm::Bool = false)
42-
if iszero(offset)
43-
ex = Expr(:call, lv(:Zero))
44-
if _mm
45-
push!(ret.args, _MMind(ex))
46-
else
47-
push!(ret.args, ex)
48-
end
49-
elseif isone(offset)
50-
ex = Expr(:call, Expr(:curly, lv(:Static), offset))
51-
if _mm
52-
push!(ret.args, _MMind(ex))
53-
else
54-
push!(ret.args, ex)
55-
end
56-
elseif _mm
57-
push!(ret.args, _MMind(offset))
42+
ex = Expr(:call, Expr(:curly, lv(:Static), offset))
43+
if _mm
44+
push!(ret.args, _MMind(ex))
5845
else
59-
push!(ret.args, offset)
60-
end
46+
push!(ret.args, ex)
47+
end
6148
nothing
6249
end
6350

@@ -146,6 +133,7 @@ function mem_offset_u(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Ve
146133
ret = Expr(:tuple)
147134
indices = getindicesonly(op)
148135
offsets = getoffsets(op)
136+
# allbasezero = all(inds_calc_by_ptr_offset) && all(iszero, offsets)
149137
loopedindex = op.ref.loopedindex
150138
if iszero(incr₁) & iszero(incr₂)
151139
return mem_offset(op, td, inds_calc_by_ptr_offset)

test/broadcast.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@
44
for T (Float32, Float64, Int32, Int64)
55
@show T, @__LINE__
66
R = T <: Integer ? (T(-100):T(100)) : T
7-
a = rand(R,100,100,100);
8-
b = rand(R,100,100,1);
7+
a = rand(R,99,99,99);
8+
b = rand(R,99,99,1);
99
bl = LowDimArray{(true,true,false)}(b);
10-
br = reshape(b, (100,100));
10+
br = reshape(b, (99,99));
1111
c1 = a .+ b;
1212
c2 = @avx a .+ bl;
1313
@test c1 c2
1414
fill!(c2, 99999); @avx c2 .= a .+ br;
1515
@test c1 c2
1616
fill!(c2, 99999); @avx c2 .= a .+ b;
1717
@test c1 c2
18-
br = reshape(b, (100,1,100));
18+
br = reshape(b, (99,1,99));
1919
bl = LowDimArray{(true,false,true)}(br);
2020
@. c1 = a + br;
2121
fill!(c2, 99999); @avx @. c2 = a + bl;
2222
@test c1 c2
2323
fill!(c2, 99999); @avx @. c2 = a + br;
2424
@test c1 c2
25-
br = reshape(b, (1,100,100));
25+
br = reshape(b, (1,99,99));
2626
bl = LowDimArray{(false,true,true)}(br);
2727
@. c1 = a + br;
2828
fill!(c2, 99999);
@@ -33,6 +33,16 @@
3333
max_ = maximum(xs, dims=1)
3434
@test (@avx exp.(xs .- LowDimArray{(false,)}(max_))) exp.(xs .- LowDimArray{(false,)}(max_))
3535

36+
if T === Int32
37+
a = rand(T(1):T(100), 73, 1)
38+
@test sqrt.(Float32.(a)) @avx sqrt.(a)
39+
elseif T === Int64
40+
a = rand(T(1):T(100), 73, 1)
41+
@test sqrt.(a) @avx sqrt.(a)
42+
else
43+
a = rand(T, 73, 1)
44+
@test sqrt.(a) @avx sqrt.(a)
45+
end
3646

3747
a = rand(R, M); B = rand(R, M, N); c = rand(R, N); c′ = c';
3848
d1 = @. a + B * c′;

0 commit comments

Comments
 (0)