Skip to content

Commit 7c2b820

Browse files
Remove the unnecessary reshape during mapreduce (#615)
* Remove the unnecessary reshape during mapreduce
* Work around issue #616
* Add details
* Only create a new kernel when needed; carry over the explanatory comment from the CUDA implementation
1 parent 2a99e5c commit 7c2b820

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/mapreduce.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,6 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
166166
# CartesianIndices object with UnitRanges that behave badly on the GPU.
167167
@assert length(Rall) == length(Rother) * length(Rreduce)
168168

169-
# allocate an additional, empty dimension to write the reduced value to.
170-
# this does not affect the actual location in memory of the final values,
171-
# but allows us to write a generalized kernel supporting partial reductions.
172-
R′ = reshape(R, (size(R)..., 1))
173-
174169
# when the reduction dimension is contiguous in memory, we can improve performance
175170
# by having each thread read multiple consecutive elements. base on experiments,
176171
# 16 / sizeof(T) elements is usually a good choice.
@@ -193,7 +188,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
193188
# that's why each threads also loops across their inputs, processing multiple values
194189
# so that we can span the entire reduction dimension using a single item group.
195190
kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
196-
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
191+
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
197192

198193
# how many threads do we want?
199194
#
@@ -208,7 +203,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
208203
end
209204
end
210205

211-
reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
206+
# XXX: Properly fix (issue #616) the issue is that the maxTotalThreadsPerThreadgroup of the unlaunched
207+
# kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
208+
# kernel below, causing errors
209+
# reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
210+
reduce_threads = compute_threads(512)
212211

213212
# how many groups should we launch?
214213
#
@@ -225,9 +224,9 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
225224
# perform the actual reduction
226225
if reduce_groups == 1
227226
# we can cover the dimensions to reduce using a single group
228-
@metal threads groups partial_mapreduce_device(
229-
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
230-
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
227+
kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
228+
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
229+
threads, groups)
231230
else
232231
# we need multiple steps to cover all values to reduce
233232
partial = similar(R, (size(R)..., reduce_groups))
@@ -236,11 +235,13 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
236235
# use broadcasting to extend singleton dimensions
237236
partial .= R
238237
end
238+
# NOTE: we can't use the previously-compiled kernel, since the type of `partial`
239+
# might not match the original output container (e.g. if that was a view).
239240
@metal threads groups partial_mapreduce_device(
240241
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
241242
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)
242243

243-
GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
244+
GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
244245
end
245246

246247
return R

0 commit comments

Comments (0)