@@ -166,11 +166,6 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # CartesianIndices object with UnitRanges that behave badly on the GPU.
     @assert length(Rall) == length(Rother) * length(Rreduce)

-    # allocate an additional, empty dimension to write the reduced value to.
-    # this does not affect the actual location in memory of the final values,
-    # but allows us to write a generalized kernel supporting partial reductions.
-    R′ = reshape(R, (size(R)..., 1))
-
     # when the reduction dimension is contiguous in memory, we can improve performance
     # by having each thread read multiple consecutive elements. based on experiments,
     # 16 / sizeof(T) elements is usually a good choice.
@@ -193,7 +188,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # that's why each thread also loops across its inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
     kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
-                                                          Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
+                                                          Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)

     # how many threads do we want?
     #
@@ -208,7 +203,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         end
     end

-    reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
+    # XXX: properly fix this (issue #616): the maxTotalThreadsPerThreadgroup of the unlaunched
+    # kernel above may be greater than that of the eventually launched kernel below,
+    # causing errors.
+    # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
+    reduce_threads = compute_threads(512)

     # how many groups should we launch?
     #
@@ -225,9 +224,9 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # perform the actual reduction
     if reduce_groups == 1
         # we can cover the dimensions to reduce using a single group
-        @metal threads groups partial_mapreduce_device(
-            f, op, init, Val(threads), Val(Rreduce), Val(Rother),
-            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R′, A)
+        kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+               Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
+               threads, groups)
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -236,11 +235,13 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
             # use broadcasting to extend singleton dimensions
             partial .= R
         end
+        # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
+        # might not match the original output container (e.g. if that was a view).
         @metal threads groups partial_mapreduce_device(
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
             Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)

-        GPUArrays.mapreducedim!(identity, op, R′, partial; init=init)
+        GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end

     return R
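For context, a minimal sketch of the kind of call that exercises this reduction path; the array shape and element type are arbitrary examples, not taken from this commit:

    using Metal
    A = MtlArray(rand(Float32, 1024, 1024))
    R = sum(A; dims=1)  # ends up in the GPUArrays.mapreducedim! method patched above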