Skip to content

Commit 9f23773

Browse files
authored
Minor mapreduce improvements (#303)
1 parent 0e06817 commit 9f23773

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/device/intrinsics/simd.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ Returns `a * b + c`.
8686

8787
## SIMD Shuffle Up/Down
8888

89-
simd_shuffle_map = ((Float32, "f16"),
90-
(Float16, "f32"),
89+
simd_shuffle_map = ((Float32, "f32"),
90+
(Float16, "f16"),
9191
(Int32, "s.i32"),
9292
(UInt32, "u.i32"),
9393
(Int16, "s.i16"),
@@ -133,4 +133,4 @@ modify the lower delta lanes of data because it doesn’t wrap values around the
133133
134134
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135135
"""
136-
simd_shuffle_up
136+
simd_shuffle_up

src/mapreduce.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
148148
Base.check_reducedims(R, A)
149149
length(A) == 0 && return R # isempty(::Broadcasted) iterates
150150

151-
# be conservative about using shuffle instructions
152-
shuffle = T <: Union{Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8}
151+
# be conservative about using shuffle instructions
152+
shuffle = T <: Union{Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8}
153153

154154
# add singleton dimensions to the output container, if needed
155155
if ndims(R) < ndims(A)
@@ -184,8 +184,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
184184
Int(dev.maxThreadgroupMemoryLength) ÷ sizeof(T))
185185

186186
# also want to make sure the grain size is not too high as to starve threads of work.
187-
# as a simple heuristic, assume we can launch the maximum number of threads.
188-
grain = min(grain, prevpow(2, cld(length(Rreduce), maxthreads)))
187+
# as a simple heuristic, ensure we can launch the maximum number of threads.
188+
grain = min(grain, nextpow(2, cld(length(Rreduce), maxthreads)))
189189

190190
# how many threads can we launch?
191191
#

0 commit comments

Comments
 (0)