Skip to content

Commit 278c566

Browse files
committed
Major progress on the threading front; a simple example works when on the right branch of all the deps
1 parent 5e54996 commit 278c566

File tree

8 files changed

+540
-210
lines changed

8 files changed

+540
-210
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.12.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
CheapThreads = "b630d9fa-e28e-4980-896d-83ce5e2106b2"
89
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
910
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -18,6 +19,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1819

1920
[compat]
2021
ArrayInterface = "3"
22+
CheapThreads = "0.1.1"
2123
DocStringExtensions = "0.8"
2224
IfElse = "0.1"
2325
OffsetArrays = "1.4.1, 1.5"

src/LoopVectorization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LoopVectorization
22

33
using Static: StaticInt, gt
44
using VectorizationBase, SLEEFPirates, UnPack, OffsetArrays
5-
using VectorizationBase: register_size, register_count, cache_linesize, has_opmask_registers,
5+
using VectorizationBase: register_size, register_count, cache_linesize, cache_size, has_opmask_registers,
66
mask, pick_vector_width, MM, AbstractMask, data, grouped_strided_pointer,
77
maybestaticlength, maybestaticsize, staticm1, staticp1, staticmul, vzero,
88
maybestaticrange, offsetprecalc, lazymul,
@@ -24,7 +24,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, has_opma
2424

2525
using IfElse: ifelse
2626

27-
using ThreadingUtilities
27+
using ThreadingUtilities, CheapThreads
2828
using SLEEFPirates: pow
2929
using Base.Broadcast: Broadcasted, DefaultArrayStyle
3030
using LinearAlgebra: Adjoint, Transpose
@@ -43,7 +43,7 @@ using Requires
4343

4444

4545
export LowDimArray, stridedpointer, indices,
46-
@avx, @_avx, *ˡ, _avx_!,
46+
@avx, @avxt, @_avx, *ˡ, _avx_!,
4747
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
4848
tanh_fast, sigmoid_fast,
4949
vfilter, vfilter!, vmapreduce, vreduce

src/broadcast.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,8 @@ end
356356
# function vmaterialize!(
357357
@generated function vmaterialize!(
358358
dest::AbstractArray{T,N}, bc::BC,
359-
::Val{Mod}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
360-
) where {T <: NativeTypes, N, BC <: Union{Broadcasted,Product}, Mod, RS, RC, CLS}
359+
::Val{Mod}, ::Val{UNROLL}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
360+
) where {T <: NativeTypes, N, BC <: Union{Broadcasted,Product}, Mod, UNROLL, RS, RC, CLS}
361361
# 2+1
362362
# we have an N dimensional loop.
363363
# need to construct the LoopSet
@@ -372,7 +372,8 @@ end
372372
add_simple_store!(ls, :dest, ArrayReference(:dest, loopsyms), elementbytes)
373373
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
374374
# return ls
375-
q = lower(ls, 0)
375+
inline, u₁, u₂, threads = UNROLL
376+
q = lower(ls, u₁ % Int, u₂ % Int, inline % Int)
376377
push!(q.args, :dest)
377378
# @show q
378379
# q
@@ -388,8 +389,8 @@ end
388389
end
389390
@generated function vmaterialize!(
390391
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::BC,
391-
::Val{Mod}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
392-
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, BC <: Union{Broadcasted,Product}, Mod, RS, RC, CLS}
392+
::Val{Mod}, ::Val{UNROLL}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
393+
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, BC <: Union{Broadcasted,Product}, Mod, UNROLL, RS, RC, CLS}
393394
# we have an N dimensional loop.
394395
# need to construct the LoopSet
395396
ls = LoopSet(Mod)
@@ -402,7 +403,8 @@ end
402403
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
403404
add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)
404405
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
405-
q = lower(ls, 0)
406+
inline, u₁, u₂, threads = UNROLL
407+
q = lower(ls, u₁ % Int, u₂ % Int, inline % Int)
406408
push!(q.args, :dest′)
407409
q = Expr(
408410
:block,
@@ -414,32 +416,42 @@ end
414416
# ls
415417
end
416418
# these are marked `@inline` so the `@avx` itself can choose whether or not to inline.
417-
@inline function vmaterialize!(
419+
@generated function vmaterialize!(
418420
dest::AbstractArray{T,N}, bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}},
419-
::Val{Mod}, RS::Static, RC::Static, CLS::Static
420-
) where {T <: NativeTypes, N, T2 <: Number, Mod}
421-
arg = T(first(bc.args))
422-
@avx for i eachindex(dest)
423-
dest[i] = arg
421+
::Val{Mod}, ::Val{UNROLL}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
422+
) where {T <: NativeTypes, N, T2 <: Number, Mod, UNROLL,RS,RC,CLS}
423+
inline, u₁, u₂, threads = UNROLL
424+
quote
425+
$(Expr(:meta,:inline))
426+
arg = T(first(bc.args))
427+
@avx inline=$inline unroll=($u₁,$u₂) thread=$threads for i eachindex(dest)
428+
dest[i] = arg
429+
end
430+
dest
424431
end
425-
dest
426432
end
427-
@inline function vmaterialize!(
433+
@generated function vmaterialize!(
428434
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(identity),Tuple{T2}},
429-
::Val{Mod}, RS::Static, RC::Static, CLS::Static
430-
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, T2 <: Number, Mod}
431-
arg = T(first(bc.args))
432-
dest = parent(dest′)
433-
@avx for i eachindex(dest)
434-
dest[i] = arg
435+
::Val{Mod}, ::Val{UNROLL}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
436+
) where {T <: NativeTypes, N, A <: AbstractArray{T,N}, T2 <: Number, Mod, UNROLL,RS,RC,CLS}
437+
inline, u₁, u₂, threads = UNROLL
438+
quote
439+
$(Expr(:meta,:inline))
440+
arg = T(first(bc.args))
441+
dest = parent(dest′)
442+
@avx inline=$inline unroll=($u₁,$u₂) thread=$threads for i eachindex(dest)
443+
dest[i] = arg
444+
end
445+
dest′
435446
end
436-
dest′
437447
end
438448

439-
@inline function vmaterialize(bc::Broadcasted, ::Val{Mod}, RS::Static, RC::Static, CLS::Static) where {Mod}
449+
@inline function vmaterialize(
450+
bc::Broadcasted, ::Val{Mod}, ::Val{UNROLL}, ::StaticInt{RS}, ::StaticInt{RC}, ::StaticInt{CLS}
451+
) where {Mod,UNROLL,RS,RC,CLS}
440452
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
441-
vmaterialize!(similar(bc, ElType), bc, Val{Mod}(), RS, RC, CLS)
453+
vmaterialize!(similar(bc, ElType), bc, Val{Mod}(), StaticInt{UNROLL}(), StaticInt{RS}(), StaticInt{RC}(), StaticInt{CLS}())
442454
end
443455

444-
vmaterialize!(dest, bc, ::Val{mod}, ::StaticInt, ::StaticInt, ::StaticInt) where {mod} = Base.Broadcast.materialize!(dest, bc)
456+
vmaterialize!(dest, bc, ::Val, ::Val, ::StaticInt, ::StaticInt, ::StaticInt) = Base.Broadcast.materialize!(dest, bc)
445457

0 commit comments

Comments
 (0)