Skip to content

Commit 0a29447

Browse files
committed
Fix threads (currently broken on master). Bump VectorizationBase version, fixes #263 (bug fix in grouped_strided_pointers), fixes #265 (defines vload(::FastRange, ::Unroll))
1 parent 57a1bf9 commit 0a29447

File tree

5 files changed

+51
-46
lines changed

5 files changed

+51
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Static = "0.2"
3030
StrideArraysCore = "0.1.5"
3131
ThreadingUtilities = "0.4.2"
3232
UnPack = "1"
33-
VectorizationBase = "0.20.1"
33+
VectorizationBase = "0.20.4"
3434
julia = "1.5"
3535

3636
[extras]

src/codegen/lower_threads.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,35 @@
1-
struct AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V} <: Function end
1+
struct AVX{UNROLL,OPS,ARF,AM,LPSYM,LBV,FLBV} <: Function end
22

33
# This should call the same `_avx_!(Val{UNROLL}(), Val{OPS}(), Val{ARF}(), Val{AM}(), Val{LPSYM}(), _vargs)` as normal so that this
44
# hopefully shouldn't add much to compile time.
55

6-
function (::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V})(p::Ptr{UInt}) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
7-
(_, _vargs) = ThreadingUtilities.load(p, Tuple{LB,V}, 2*sizeof(UInt))
8-
# Main.VARGS[Threads.threadid()] = first(_vargs)
9-
ret = _avx_!(Val{UNROLL}(), Val{OPS}(), Val{ARF}(), Val{AM}(), Val{LPSYM}(), Val(Tuple{LB,V}), flatten_to_tuple(_vargs)...)
6+
function (::AVX{UNROLL,OPS,ARF,AM,LPSYM,LBV,FLBV})(p::Ptr{UInt}) where {UNROLL,OPS,ARF,AM,LPSYM,K,LBV,FLBV<:Tuple{Vararg{Any,K}}}
7+
(_, _vargs) = ThreadingUtilities.load(p, FLBV, 2*sizeof(UInt))
8+
# Main.VARGS[Threads.threadid()] = first(_vargs)
9+
# Threads.threadid() == 2 && Core.println(typeof(_vargs))
10+
ret = _avx_!(Val{UNROLL}(), Val{OPS}(), Val{ARF}(), Val{AM}(), Val{LPSYM}(), Val{LBV}(), _vargs...)
1011
ThreadingUtilities.store!(p, ret, Int(register_size()))
1112
nothing
1213
end
13-
@generated function Base.pointer(::AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
14-
f = AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()
14+
@generated function Base.pointer(::AVX{UNROLL,OPS,ARF,AM,LPSYM,LBV,FLBV}) where {UNROLL,OPS,ARF,AM,LPSYM,K,LBV,FLBV<:Tuple{Vararg{Any,K}}}
15+
f = AVX{UNROLL,OPS,ARF,AM,LPSYM,LBV,FLBV}()
1516
precompile(f, (Ptr{UInt},))
1617
quote
1718
$(Expr(:meta,:inline))
1819
@cfunction($f, Cvoid, (Ptr{UInt},))
1920
end
2021
end
2122

22-
@inline function setup_avx_threads!(p::Ptr{UInt}, fptr::Ptr{Cvoid}, args::Tuple{LB,V}) where {LB,V}
23+
@inline function setup_avx_threads!(p::Ptr{UInt}, fptr::Ptr{Cvoid}, args::LBV) where {K,LBV<:Tuple{Vararg{Any,K}}}
2324
offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt))
2425
offset = ThreadingUtilities.store!(p, args, offset)
2526
nothing
2627
end
2728
@inline function avx_launch(
28-
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, lb::LB, vargs::V, tid
29-
) where {UNROLL,OPS,ARF,AM,LPSYM,LB,V}
30-
ThreadingUtilities.launch(setup_avx_threads!, tid, pointer(AVX{UNROLL,OPS,ARF,AM,LPSYM,LB,V}()), (lb,vargs))
29+
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, lbvargs::LBV, tid
30+
) where {UNROLL,OPS,ARF,AM,LPSYM,K,LBV<:Tuple{Vararg{Any,K}}}
31+
fargs = flatten_to_tuple(lbvargs)
32+
ThreadingUtilities.launch(setup_avx_threads!, tid, pointer(AVX{UNROLL,OPS,ARF,AM,LPSYM,LBV,typeof(fargs)}()), fargs)
3133
end
3234

3335
# function approx_cbrt(x)
@@ -397,7 +399,7 @@ function thread_one_loops_expr(
397399

398400
avx_launch(
399401
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
400-
$loopboundexpr, var"#vargs#", var"#thread#id#"
402+
($loopboundexpr, var"#vargs#"), var"#thread#id#"
401403
)
402404

403405
var"#thread#mask#" >>>= var"#trailzing#zeros#"
@@ -587,7 +589,7 @@ function thread_two_loops_expr(
587589
# @show var"#thread#id#" $loopboundexpr
588590
avx_launch(
589591
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
590-
$loopboundexpr, var"#vargs#", var"#thread#id#"
592+
($loopboundexpr, var"#vargs#"), var"#thread#id#"
591593
)
592594
var"#thread#mask#" >>>= var"#trailzing#zeros#"
593595

src/condense_loopset.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it)
55

66
struct StaticType{T} end
77
@inline gettype(::StaticType{T}) where {T} = T
8+
89
function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
910
gf = GlobalRef(Core,:getfield)
1011
for f ∈ 1:fieldcount(T)

src/reconstruct_loopset.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ function add_mref!(
185185
) where {T}
186186
@assert B ≤ 0 "Batched arrays not supported yet."
187187
_add_mref!(sptrs, ls, ar, typetosym(T), C, B, sp, name)
188+
sizeof(T)
188189
end
189190
typetosym(::Type{T}) where {T<:NativeTypes} = (VectorizationBase.JULIA_TYPES[T])::Symbol
190191
typetosym(T) = T
@@ -239,14 +240,15 @@ function add_mref!(
239240
sptrs::Expr, ::LoopSet, ::ArrayReferenceMeta, @nospecialize(_::Type{VectorizationBase.FastRange{T,F,S,O}}),
240241
::Int, ::Int, sp::Vector{Int}, name::Symbol
241242
) where {T,F,S,O}
242-
extract_gsp!(sptrs, name)
243+
extract_gsp!(sptrs, name)
244+
sizeof(T)
243245
end
244246
function create_mrefs!(
245247
ls::LoopSet, arf::Vector{ArrayRefStruct}, as::Vector{Symbol}, os::Vector{Symbol},
246248
nopsv::Vector{NOpsType}, expanded::Vector{Bool}, ::Type{Tuple{}}
247249
)
248250
length(arf) == 0 || throw(ArgumentError("Length of array ref vector should be 0 if there are no stridedpointers."))
249-
Vector{ArrayReferenceMeta}(undef, length(arf))
251+
Vector{ArrayReferenceMeta}(undef, length(arf)), Int[]
250252
end
251253
function stabilize_grouped_stridedpointer_type(C, B, R)
252254
N = (length(C))::Int
@@ -271,12 +273,12 @@ function create_mrefs!(
271273
Cv,Bv,Rv = stabilize_grouped_stridedpointer_type(C, B, R)
272274
_create_mrefs!(ls, arf, as, os, nopsv, expanded, P.parameters, Cv, Bv, Rv)
273275
end
274-
275276
function _create_mrefs!(
276277
ls::LoopSet, arf::Vector{ArrayRefStruct}, as::Vector{Symbol}, os::Vector{Symbol},
277278
nopsv::Vector{NOpsType}, expanded::Vector{Bool}, P::Core.SimpleVector, C::Vector{Int}, B::Vector{Int}, R::Vector{Tuple{NTuple{8,Int},Int}}
278279
)
279280
mrefs::Vector{ArrayReferenceMeta} = Vector{ArrayReferenceMeta}(undef, length(arf))
281+
elementbytes::Vector{Int} = Vector{Int}(undef, length(arf))
280282
sptrs = Expr(:tuple)
281283
# pushpreamble!(ls, Expr(:(=), sptrs, :(VectorizationBase.stridedpointers(getfield(vargs, 1, false)))))
282284
pushpreamble!(ls, Expr(:(=), sptrs, :(VectorizationBase.stridedpointers(getfield(var"#vargs#", 1, false)))))
@@ -292,6 +294,7 @@ function _create_mrefs!(
292294
# if isassigned(rank_to_sps, k)
293295
Cₖ, sp = rank_to_sps[k]
294296
permute_mref!(ar, Cₖ, sp)
297+
elementbytes[i] = elementbytes[k]
295298
# end
296299
break
297300
end
@@ -300,11 +303,11 @@ function _create_mrefs!(
300303
j += 1
301304
sp = rank_to_sortperm(R[j])::Vector{Int}
302305
rank_to_sps[i] = (C[j],sp)
303-
add_mref!(sptrs, ls, ar, P[j], C[j], B[j], sp, vptr(ar))
306+
elementbytes[i] = add_mref!(sptrs, ls, ar, P[j], C[j], B[j], sp, vptr(ar))
304307
end
305308
mrefs[i] = ar
306309
end
307-
mrefs
310+
mrefs, elementbytes
308311
end
309312

310313
function num_parameters(AM)
@@ -408,11 +411,19 @@ function isexpanded(ls::LoopSet, ops::Vector{OperationStruct}, nopsv::Vector{NOp
408411
false
409412
end
410413
end
414+
function mref_elbytes(os::OperationStruct, mrefs::Vector{ArrayReferenceMeta}, elementbytes::Vector{Int})
415+
if isload(os) | isstore(os)
416+
mrefs[os.array], elementbytes[os.array]
417+
else
418+
NOTAREFERENCE, 4
419+
end
420+
end
411421
function add_op!(
412422
ls::LoopSet, instr::Instruction, ops::Vector{OperationStruct}, nopsv::Vector{NOpsType}, expandedv::Vector{Bool}, i::Int,
413-
mrefs::Vector{ArrayReferenceMeta}, opsymbol, elementbytes::Int
423+
mrefs::Vector{ArrayReferenceMeta}, opsymbol, elementbytes::Vector{Int}
414424
)
415425
os = ops[i]
426+
mref, elbytes = mref_elbytes(os, mrefs, elementbytes)
416427
# opsymbol = (isconstant(os) && instr != LOOPCONSTANT) ? instr.instr : opsymbol
417428
# If it's a CartesianIndex add or subtract, we may have to add multiple operations
418429
expanded = expandedv[i]# isexpanded(ls, ops, nopsv, i)
@@ -421,10 +432,9 @@ function add_op!(
421432
optyp = optype(os)
422433
if !expanded
423434
op = Operation(
424-
length(operations(ls)), opsymbol, elementbytes, instr,
435+
length(operations(ls)), opsymbol, elbytes, instr,
425436
optyp, loopdependencies(ls, os, true), reduceddependencies(ls, os, true),
426-
Operation[], (isload(os) | isstore(os)) ? mrefs[os.array] : NOTAREFERENCE,
427-
childdependencies(ls, os, true)
437+
Operation[], mref, childdependencies(ls, os, true)
428438
)
429439
push!(ls.operations, op)
430440
push!(opoffsets, opoffsets[end] + 1)
@@ -435,10 +445,9 @@ function add_op!(
435445
for offset = 0:nops-1
436446
sym = nops === 1 ? opsymbol : expandedopname(opsymbol, offset)
437447
op = Operation(
438-
length(operations(ls)), sym, elementbytes, instr,
439-
optyp, loopdependencies(ls, os, false, offset), reduceddependencies(ls, os, false, offset),
440-
Operation[], (isload(os) | isstore(os)) ? mrefs[os.array] : NOTAREFERENCE,
441-
childdependencies(ls, os, false, offset)
448+
length(operations(ls)), sym, elbytes, instr, optyp,
449+
loopdependencies(ls, os, false, offset), reduceddependencies(ls, os, false, offset),
450+
Operation[], mref, childdependencies(ls, os, false, offset)
442451
)
443452
push!(ls.operations, op)
444453
end
@@ -491,8 +500,8 @@ function add_parents_to_ops!(ls::LoopSet, ops::Vector{OperationStruct}, constoff
491500
constoffset
492501
end
493502
function add_ops!(
494-
ls::LoopSet, instr::Vector{Instruction}, ops::Vector{OperationStruct}, mrefs::Vector{ArrayReferenceMeta},
495-
opsymbols::Vector{Symbol}, constoffset::Int, nopsv::Vector{NOpsType}, expandedv::Vector{Bool}, elementbytes::Int
503+
ls::LoopSet, instr::Vector{Instruction}, ops::Vector{OperationStruct}, mrefs::Vector{ArrayReferenceMeta}, elementbytes::Vector{Int},
504+
opsymbols::Vector{Symbol}, constoffset::Int, nopsv::Vector{NOpsType}, expandedv::Vector{Bool}
496505
)
497506
# @show ls.loopsymbols ls.loopsymbol_offsets
498507
for i ∈ eachindex(ops)
@@ -584,12 +593,6 @@ function avx_loopset!(
584593
ls::LoopSet, instr::Vector{Instruction}, ops::Vector{OperationStruct}, arf::Vector{ArrayRefStruct},
585594
AM::Vector{Any}, LPSYM::Vector{Any}, LB::Core.SimpleVector, vargs::Core.SimpleVector
586595
)
587-
# TODO: check outer reduction types instead
588-
elementbytes = if length(vargs[1].parameters) > 0
589-
sizeofeltypes(vargs[1].parameters[1].parameters)
590-
else
591-
8
592-
end
593596
pushpreamble!(ls, :((var"#loop#bounds#", var"#vargs#") = var"#lv#tuple#args#"))
594597
add_loops!(ls, LPSYM, LB)
595598
resize!(ls.loop_order, ls.loopsymbol_offsets[end])
@@ -599,12 +602,12 @@ function avx_loopset!(
599602
expandedv = [isexpanded(ls, ops, nopsv, i) for i ∈ eachindex(ops)]
600603

601604
resize!(ls.loopindexesbit, length(ls.loops)); fill!(ls.loopindexesbit, false);
602-
mrefs = create_mrefs!(ls, arf, arraysymbolinds, opsymbols, nopsv, expandedv, vargs[1])
605+
mrefs, elementbytes = create_mrefs!(ls, arf, arraysymbolinds, opsymbols, nopsv, expandedv, vargs[1])
603606
for mref ∈ mrefs
604607
push!(ls.includedactualarrays, vptr(mref))
605608
end
606609
# extra args extraction
607-
extractind = add_ops!(ls, instr, ops, mrefs, opsymbols, 1, nopsv, expandedv, elementbytes)
610+
extractind = add_ops!(ls, instr, ops, mrefs, elementbytes, opsymbols, 1, nopsv, expandedv)
608611
extractind = process_metadata!(ls, AM, extractind)
609612
extractind = add_array_symbols!(ls, arraysymbolinds, extractind)
610613
extractind = extract_external_functions!(ls, extractind, vargs)
@@ -645,12 +648,11 @@ function _avx_loopset(
645648
ls = LoopSet(:LoopVectorization)
646649
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, nt = UNROLL
647650
set_hw!(ls, rs, rc, cls, l1, l2, l3); ls.vector_width = W; ls.isbroadcast = isbroadcast
648-
avx_loopset!(
649-
ls, instr, ops,
650-
ArrayRefStruct[ARFsv...],
651-
tovector(AMsv), tovector(LPSYMsv), LBsv, vargs
652-
)::LoopSet
653-
ls
651+
arsv = Vector{ArrayRefStruct}(undef, length(ARFsv))
652+
for i ∈ eachindex(arsv)
653+
arsv[i] = ARFsv[i]
654+
end
655+
avx_loopset!(ls, instr, ops, arsv, tovector(AMsv), tovector(LPSYMsv), LBsv, vargs)
654656
end
655657

656658
@static if VERSION ≥ v"1.7.0-DEV.421"

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ const START_TIME = time()
3131

3232
@time include("shuffleloadstores.jl")
3333

34-
if (v"1.5" < VERSION < v"1.7") && Sys.iswindows()
35-
println("Skipping Zygote tests.")
36-
else
34+
if VERSION < v"1.7-DEV"
3735
@time include("zygote.jl")
36+
else
37+
println("Skipping Zygote tests.")
3838
end
3939

4040
@time include("offsetarrays.jl")

0 commit comments

Comments
 (0)