Skip to content

Commit 422c43f

Browse files
committed
Port mapreduce optimization from CUDA
1 parent ef12be1 commit 422c43f

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/Metal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module Metal
33
using GPUArrays
44
using Adapt
55
using GPUCompiler
6-
using GPUToolbox: SimpleVersion, @sv_str
6+
using GPUToolbox
77
using LLVM
88
using LLVM.Interop
99
import LLVMDowngrader_jll

src/mapreduce.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,34 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
140140
return
141141
end
142142

143+
# Naive fallback kernel for reductions with many independent output slices:
# each thread serially reduces one slice of `As` (iterated via `Rreduce`) into
# the corresponding element of `R` (selected via `Rother`).
#
# - `f`: transform applied to every input element
# - `op`: associative reduction operator
# - `neutral`: neutral element for `op`, or `nothing` to seed from `R`
# - `Rreduce`/`Rother`: CartesianIndices over the reduced / kept dimensions,
#   lifted to the type domain via `Val` so the loop specializes
function big_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
    # Global linear thread index. NOTE(fix): the multiplier has to be the
    # threadgroup *size* (CUDA's `blockDim().x`), not the number of threadgroups;
    # with `groups = cld(length(Rother), threads)` the original expression made
    # groups overlap and left trailing elements of `Rother` unprocessed.
    grid_idx = thread_position_in_threadgroup_1d() +
               (threadgroup_position_in_grid_1d() - 1u32) * threads_per_threadgroup_1d()

    @inbounds if grid_idx <= length(Rother)
        Iother = Rother[grid_idx]

        # load the neutral value, seeding from the destination when none is given
        # (NOTE(review): presumably `R` is pre-initialized to a neutral element in
        # that case, since the seed is folded in below — confirm against GPUArrays)
        neutral = if neutral === nothing
            R[Iother]
        else
            neutral
        end

        val = op(neutral, neutral)

        # serially fold this thread's slice; `Iother + Ireduce - Ibegin` shifts
        # the reduction index onto the slice anchored at `Iother`
        Ibegin = Rreduce[1]
        for Ireduce in Rreduce
            val = op(val, f(As[Iother + Ireduce - Ibegin]))
        end
        R[Iother] = val
    end
    return
end
166+
143167
## COV_EXCL_STOP
144168

169+
# Heuristic cut-over point for the naive reduction kernel: once the number of
# independent output slices reaches about one full threadgroup's worth of work
# per GPU core, a simple per-thread loop is expected to beat partial reductions.
function _big_mapreduce_threshold(dev)
    return dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
end
170+
145171
function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
146172
A::Union{AbstractArray,Broadcast.Broadcasted};
147173
init=nothing) where {F, OP, T}
@@ -165,6 +191,15 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
165191
# NOTE: we hard-code `OneTo` (`first.(axes(A))` would work too) or we get a
166192
# CartesianIndices object with UnitRanges that behave badly on the GPU.
167193
@assert length(Rall) == length(Rother) * length(Rreduce)
194+
@assert length(Rother) > 0
195+
196+
# If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
197+
if length(Rother) >= _big_mapreduce_threshold(device(R))
198+
threads = min(length(Rreduce), 512)
199+
groups = cld(length(Rother), threads)
200+
kernel = @metal threads groups big_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
201+
return R
202+
end
168203

169204
# when the reduction dimension is contiguous in memory, we can improve performance
170205
# by having each thread read multiple consecutive elements. base on experiments,

0 commit comments

Comments
 (0)