Skip to content

Commit d03965e

Browse files
committed
test tweaks, use single precision in choose_num_threads
1 parent 387fac3 commit d03965e

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

src/codegen/lower_threads.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,17 @@ end
165165
# block_per_m, blocks_per_n
166166
# end
167167
if Sys.ARCH === :x86_64
168-
@inline choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger) = _choose_num_threads(Base.mul_float_fast(C, 0.05460264079015985), NT, x)
168+
@inline function choose_num_threads(C::T, NT::UInt, x::Base.BitInteger) where {T<:Union{Float32,Float64}}
169+
_choose_num_threads(Base.mul_float_fast(T(C), T(0.05460264079015985)), NT, x)
170+
end
169171
else
170-
@inline choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger) = _choose_num_threads(Base.mul_float_fast(C, 0.05460264079015985 * 0.25), NT, x)
172+
@inline function choose_num_threads(C::T, NT::UInt, x::Base.BitInteger) where {T<:Union{Float32,Float64}}
173+
_choose_num_threads(Base.mul_float_fast(C, T(0.05460264079015985) * T(0.25)), NT, x)
174+
end
175+
end
176+
@inline function _choose_num_threads(C::T, NT::UInt, x::Base.BitInteger) where {T<:Union{Float32,Float64}}
177+
min(Base.fptoui(UInt, Base.ceil_llvm(Base.mul_float_fast(C, Base.sqrt_llvm_fast(Base.uitofp(T, x))))), NT)
171178
end
172-
@inline _choose_num_threads(C::Float64, NT::UInt, x::Base.BitInteger) = min(Base.fptoui(UInt, Base.ceil_llvm(Base.mul_float_fast(C, Base.sqrt_llvm(Base.uitofp(Float64, x))))), NT)
173179
function push_loop_length_expr!(q::Expr, ls::LoopSet)
174180
l = 1
175181
ndynamic = 0
@@ -342,7 +348,7 @@ function thread_one_loops_expr(
342348
_num_threads > 1 || return avx_body(ls, UNROLL)
343349
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
344350
else
345-
choose_nthread = :(_choose_num_threads($c, $ntmax))
351+
choose_nthread = :(_choose_num_threads($(Float32(c)), $ntmax))
346352
push_loop_length_expr!(choose_nthread, ls)
347353
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
348354
end
@@ -474,7 +480,7 @@ function thread_two_loops_expr(
474480
_num_threads > 1 || return avx_body(ls, UNROLL)
475481
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
476482
else
477-
choose_nthread = :(_choose_num_threads($c, $ntmax))
483+
choose_nthread = :(_choose_num_threads($(Float32(c)), $ntmax))
478484
push_loop_length_expr!(choose_nthread, ls)
479485
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
480486
end

test/outer_reductions.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
function awmean_lv(x::Array{T1}, σ::Array{T2}) where {T1<:Number,T2<:Number}
2+
function awmean_lv(x::AbstractArray{T1}, σ::AbstractArray{T2}) where {T1<:Number,T2<:Number}
33
n = length(x)
44
T3 = promote_type(T1,T2)
55
T = sizeof(T3) ≤ 4 ? Float32 : Float64
@@ -17,7 +17,7 @@ function awmean_lv(x::Array{T1}, σ::Array{T2}) where {T1<:Number,T2<:Number}
1717
wσ = sqrt(one(T) / sum_of_weights)
1818
return wx, wσ, mswd
1919
end
20-
function awmean_simd(x::Array{T1}, σ::Array{T2}) where {T1<:Number,T2<:Number}
20+
function awmean_simd(x::AbstractArray{T1}, σ::AbstractArray{T2}) where {T1<:Number,T2<:Number}
2121
n = length(x)
2222
T3 = promote_type(T1,T2)
2323
T = sizeof(T3) ≤ 4 ? Float32 : Float64
@@ -39,18 +39,18 @@ end
3939
function test_awmean(::Type{T}) where {T}
4040
for n ∈ 2:100
4141
if T <: Integer
42-
x = rand(T(-100):T(100), n)
43-
σ = rand(T(1):T(10), n)
42+
x = view(rand(T(-100):T(100), n + 32), 17:n+16)
43+
σ = view(rand(T(1):T(10), n + 32), 17:n+16)
4444
else
45-
x = randn(T, n)
46-
σ = rand(T, n)
45+
x = view(randn(T, n + 32), 17:n+16)
46+
σ = view(rand(T, n + 32), 17:n+16)
4747
end
4848
wx, wσ, mswd = awmean_simd(x, σ)
4949
@test iszero(@allocated((wxlv, wσlv, mswdlv) = awmean_lv(x, σ)))
5050
wxlv, wσlv, mswdlv = awmean_lv(x, σ)
51-
@test wx ≈ wxlv
52-
@test wσ ≈ wσlv
53-
@test mswd ≈ mswdlv
51+
isfinite(wx) && @test wx ≈ wxlv
52+
isfinite(wσ) && @test wσ ≈ wσlv
53+
isfinite(mswd) && @test mswd ≈ mswdlv
5454
end
5555
end
5656

0 commit comments

Comments
 (0)