Skip to content

Commit fdff9e3

Browse files
Fix serial reductions (#662)
* Test serial_mapreduce
* Fix serial mapreduce kernel
* Properly calculate number of threads for serial mapreduce
* Bump version
1 parent 8bf1b70 commit fdff9e3

File tree

5 files changed

+47
-10
lines changed

5 files changed

+47
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Metal"
22
uuid = "dde4c033-4e86-420c-a63e-0dd931031962"
3-
version = "1.8.0"
3+
version = "1.8.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ julia> Metal.versioninfo()
6666
macOS 26.0.0, Darwin 25.0.0
6767
6868
Toolchain:
69-
- Julia: 1.11.6
69+
- Julia: 1.11.7
7070
- LLVM: 16.0.6
7171
7272
Julia packages:
73-
- Metal.jl: 1.8.0
73+
- Metal.jl: 1.8.1
7474
- GPUArrays: 11.2.5
7575
- GPUCompiler: 1.6.1
7676
- KernelAbstractions: 0.9.38

src/mapreduce.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
140140
return
141141
end
142142

143-
function big_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
144-
grid_idx = thread_position_in_threadgroup_1d() + (threadgroup_position_in_grid_1d() - 1u32) * threadgroups_per_grid_1d()
143+
function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
144+
grid_idx = thread_position_in_grid_1d()
145145

146146
@inbounds if grid_idx <= length(Rother)
147147
Iother = Rother[grid_idx]
@@ -166,7 +166,7 @@ end
166166

167167
## COV_EXCL_STOP
168168

169-
_big_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
169+
"""
    serial_mapreduce_threshold(dev)

Return the product of `dev.maxThreadsPerThreadgroup.width` and `num_gpu_cores()`.

`mapreducedim!` compares `length(Rother)` against this value: once the number of
non-reduced indices reaches it, the serial (naive loop) kernel is launched instead of
the partial-reduction kernels.
"""
function serial_mapreduce_threshold(dev)
    threadgroup_width = dev.maxThreadsPerThreadgroup.width
    return threadgroup_width * num_gpu_cores()
end
170170

171171
function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
172172
A::Union{AbstractArray,Broadcast.Broadcasted};
@@ -194,10 +194,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
194194
@assert length(Rother) > 0
195195

196196
# If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
197-
if length(Rother) >= _big_mapreduce_threshold(device(R))
198-
threads = min(length(Rreduce), 512)
197+
if length(Rother) >= serial_mapreduce_threshold(device(R))
198+
kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
199+
threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
199200
groups = cld(length(Rother), threads)
200-
kernel = @metal threads groups big_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
201+
kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
201202
return R
202203
end
203204

src/utilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function versioninfo(io::IO=stdout)
2222
println(io, "- LLVM: $(LLVM.version())")
2323
println(io)
2424

25-
println(io, "Julia packages: ")
25+
println(io, "Julia packages:")
2626
println(io, "- Metal.jl: $(Base.pkgversion(Metal))")
2727
for name in [:GPUArrays, :GPUCompiler, :KernelAbstractions, :ObjectiveC,
2828
:LLVM, :LLVMDowngrader_jll]

test/array.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,39 @@ end
568568

569569

570570
end
571+
572+
@testset "large map reduce" begin
573+
dev = device()
574+
575+
big_size = Metal.serial_mapreduce_threshold(dev) + 5
576+
a = rand(Float32, big_size, 31)
577+
c = MtlArray(a)
578+
579+
expected = minimum(a, dims=2)
580+
actual = minimum(c, dims=2)
581+
@test expected == Array(actual)
582+
583+
expected = findmax(a, dims=2)
584+
actual = findmax(c, dims=2)
585+
@test expected == map(Array, actual)
586+
587+
expected = sum(a, dims=2)
588+
actual = sum(c, dims=2)
589+
@test expected == Array(actual)
590+
591+
a = rand(Int, big_size, 31)
592+
c = MtlArray(a)
593+
594+
expected = minimum(a, dims=2)
595+
actual = minimum(c, dims=2)
596+
@test expected == Array(actual)
597+
598+
expected = findmax(a, dims=2)
599+
actual = findmax(c, dims=2)
600+
@test expected == map(Array, actual)
601+
602+
expected = sum(a, dims=2)
603+
actual = sum(c, dims=2)
604+
@test expected == Array(actual)
605+
end
606+

0 commit comments

Comments
 (0)