Skip to content

Commit e64841b

Browse files
committed
Fix _choose_num_threads for statically sized loops
1 parent 0260365 commit e64841b

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ We expect that any time you use the `@avx` macro with a given block of code that
2121
1. Are not indexing an array out of bounds. `@avx` does not perform any bounds checking.
2222
2. Are not iterating over an empty collection. Iterating over an empty loop such as `for i ∈ eachindex(Float64[])` is undefined behavior, and will likely result in out-of-bounds memory accesses. Ensure that loops behave correctly.
2323
3. Are not relying on a specific execution order. `@avx` can and will re-order operations and loops inside its scope, so the correctness cannot depend on a particular order. You cannot implement `cumsum` with `@avx`.
24-
4. Loops increment by 1 on each iteration, e.g. `1:2:N` is not supported at the moment. (This requirement will eventually be lifted.)
2524

2625
## Usage
2726

src/codegen/lower_threads.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@ end
162162
# block_per_m, blocks_per_n
163163
# end
164164

165-
# Previous (removed) form of the thread-count heuristic: `C` and `NT` arrive as
# `Val` type parameters. Returns `min(ceil(0.05460264079015985 * C * sqrt(x)), NT)`
# as a `UInt`, where `x` is converted to `Float64` with an unsigned
# integer-to-float conversion.
# NOTE(review): `C` looks like a per-iteration cost estimate and `NT` the
# maximum thread count — confirm against the callers.
@inline function choose_num_threads(::Val{C}, ::Val{NT}, x) where {C,NT}
166-
fx = Base.uitofp(Float64, x)
167-
min(Base.fptoui(UInt, Base.ceil_llvm(0.05460264079015985*C*Base.sqrt_llvm(fx))), NT)
168-
end
165+
"""
    choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger)

Multiply `C` by the fixed factor `0.05460264079015985` (fast-math multiply) and
forward the scaled value, together with `NT` and `x`, to `_choose_num_threads`.
"""
@inline function choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger)
    scaledcost = Base.FastMath.mul_float_fast(C, 0.05460264079015985)
    return _choose_num_threads(scaledcost, NT, x)
end
166+
"""
    _choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger)

Return `min(ceil(C * sqrt(x)), NT)` as a `UInt`.

`x` is converted to `Float64` with an unsigned integer-to-float conversion, and
the product uses a fast-math multiply. The rounded-up result is converted back
to `UInt` and capped at `NT`.
"""
@inline function _choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger)
    # Unsigned conversion: the bits of `x` are interpreted as an unsigned integer.
    len = Base.uitofp(Float64, x)
    estimate = Base.FastMath.mul_float_fast(C, Base.sqrt_llvm(len))
    wanted = Base.fptoui(UInt, Base.ceil_llvm(estimate))
    return min(wanted, NT)
end
169167
function push_loop_length_expr!(q::Expr, ls::LoopSet)
170168
l = 1
171169
ndynamic = 0
@@ -328,12 +326,14 @@ function thread_one_loops_expr(
328326
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
329327
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
330328
)
329+
looplen = looplengthprod(ls)
330+
c = 0.05460264079015985 * c / looplen
331331
if all(isstaticloop, ls.loops)
332-
_num_threads = choose_num_threads(Val(c), Val(ntmax), 1)::UInt
332+
_num_threads = _choose_num_threads(c, ntmax, Int64(looplen))::UInt
333333
_num_threads > 1 || return avx_body(ls, UNROLL)
334334
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
335335
else
336-
choose_nthread = :(choose_num_threads(Val{$(c/looplengthprod(ls))}(), Val{$ntmax}()))
336+
choose_nthread = :(_choose_num_threads($c, $ntmax))
337337
push_loop_length_expr!(choose_nthread, ls)
338338
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
339339
end
@@ -444,12 +444,14 @@ function thread_two_loops_expr(
444444
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
445445
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
446446
)
447+
looplen = looplengthprod(ls)
448+
c = 0.05460264079015985 * c / looplen
447449
if all(isstaticloop, ls.loops)
448-
_num_threads = choose_num_threads(Val(c), Val(ntmax), 1)::UInt
450+
_num_threads = _choose_num_threads(c, ntmax, Int64(looplen))::UInt
449451
_num_threads > 1 || return avx_body(ls, UNROLL)
450452
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
451453
else
452-
choose_nthread = :(choose_num_threads(Val{$(c/looplengthprod(ls))}(), Val{$ntmax}()))
454+
choose_nthread = :(_choose_num_threads($c, $ntmax))
453455
push_loop_length_expr!(choose_nthread, ls)
454456
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
455457
end

0 commit comments

Comments
 (0)