Conversation

@christiangnrd (Member) commented:

Marked as not a draft so that the benchmarks also run.

github-actions bot (Contributor) commented Oct 22, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Suggested changes:
diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl
index 4c73cead..4cea05a4 100644
--- a/src/MetalKernels.jl
+++ b/src/MetalKernels.jl
@@ -135,28 +135,30 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
 end
 
 function KI.KIKernel(::MetalBackend, f, args...; kwargs...)
-    kern = eval(quote
-        @metal launch=false $(kwargs...) $(f)($(args...))
-    end)
-    KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
+    kern = eval(
+        quote
+            @metal launch = false $(kwargs...) $(f)($(args...))
+        end
+    )
+    return KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
 end
 
-function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...)
+function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups = nothing, workgroupsize = nothing, kwargs...)
     threadsPerThreadgroup = isnothing(workgroupsize) ? 1 : workgroupsize
     threadgroupsPerGrid = isnothing(numworkgroups) ? 1 : numworkgroups
 
-    obj.kern(args...; threads=threadsPerThreadgroup, groups=threadgroupsPerGrid, kwargs...)
+    return obj.kern(args...; threads = threadsPerThreadgroup, groups = threadgroupsPerGrid, kwargs...)
 end
 
 
-function KI.kernel_max_work_group_size(::MetalBackend, kikern::KI.KIKernel{<:MetalBackend}; max_work_items::Int=typemax(Int))::Int
-    Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
+function KI.kernel_max_work_group_size(::MetalBackend, kikern::KI.KIKernel{<:MetalBackend}; max_work_items::Int = typemax(Int))::Int
+    return Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
 end
 function KI.max_work_group_size(::MetalBackend)::Int
-    Int(device().maxThreadsPerThreadgroup.width)
+    return Int(device().maxThreadsPerThreadgroup.width)
 end
 function KI.multiprocessor_count(::MetalBackend)::Int
-    Metal.num_gpu_cores()
+    return Metal.num_gpu_cores()
 end
 
 
diff --git a/src/broadcast.jl b/src/broadcast.jl
index 72ced3ed..e90f5826 100644
--- a/src/broadcast.jl
+++ b/src/broadcast.jl
@@ -66,8 +66,8 @@ end
     if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
         ## COV_EXCL_START
         function broadcast_cartesian_static(dest, bc, Is)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
              while 1 <= i <= length(dest)
                 I = @inbounds Is[i]
                 @inbounds dest[I] = bc[I]
@@ -91,8 +91,8 @@ end
        (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
         ## COV_EXCL_START
         function broadcast_linear(dest, bc)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
              while 1 <= i <= length(dest)
                  @inbounds dest[i] = bc[i]
                  i += stride
@@ -150,8 +150,8 @@ end
     else
         ## COV_EXCL_START
         function broadcast_cartesian(dest, bc)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
              while 1 <= i <= length(dest)
                 I = @inbounds CartesianIndices(dest)[i]
                 @inbounds dest[I] = bc[I]
diff --git a/src/device/random.jl b/src/device/random.jl
index 383862d3..36b471d2 100644
--- a/src/device/random.jl
+++ b/src/device/random.jl
@@ -89,8 +89,8 @@ end
         @inbounds global_random_counters()[simdgroupId]
     elseif field === :ctr2
         globalId = KI.get_global_id().x +
-                   (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
-                   (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
+            (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
+            (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
         globalId % UInt32
     end::UInt32
 end
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 3e83e9a7..82a488a6 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -197,9 +197,9 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
     if length(Rother) >= serial_mapreduce_threshold(device(R))
         kernel = KI.KIKernel(backend, serial_mapreduce_kernel, f, op, init, Val(Rreduce), Val(Rother), R, A)
-        threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(Rother))
+        threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items = length(Rother))
         groups = cld(length(Rother), threads)
-        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups=groups, workgroupsize=threads)
+        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups = groups, workgroupsize = threads)
         return R
     end
 
@@ -224,7 +224,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each threads also loops across their inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
-    kernel = KI.KIKernel(backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+    kernel = KI.KIKernel(
+        backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                                           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
 
     # how many threads do we want?
@@ -263,7 +264,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         # we can cover the dimensions to reduce using a single group
         kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-                numworkgroups=groups, workgroupsize=threads)
+            numworkgroups = groups, workgroupsize = threads
+        )
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -274,12 +276,15 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         end
         # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
         #       might not match the original output container (e.g. if that was a view).
-        KI.KIKernel(backend, partial_mapreduce_device,
+        KI.KIKernel(
+            backend, partial_mapreduce_device,
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
-            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A
+        )(
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
             Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
-            numworkgroups=groups, workgroupsize=threads)
+            numworkgroups = groups, workgroupsize = threads
+        )
 
         GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end
diff --git a/test/kernelabstractions.jl b/test/kernelabstractions.jl
index 221ee680..6f9d5a2c 100644
--- a/test/kernelabstractions.jl
+++ b/test/kernelabstractions.jl
@@ -7,6 +7,6 @@ Testsuite.testsuite(()->MetalBackend(), "Metal", Metal, MtlArray, Metal.MtlDevic
     "Convert",           # depends on https://github.com/JuliaGPU/Metal.jl/issues/69
     "SpecialFunctions",  # no equivalent Metal intrinsics for gamma, erf, etc
     "sparse",            # not supported yet
-    "CPU synchronization",
-    "fallback test: callable types",
+            "CPU synchronization",
+            "fallback test: callable types",
 ]))
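
For context on the API this diff reformats, below is a minimal sketch of the KI.KIKernel launch pattern used in the src/mapreduce.jl hunk above: build the kernel once with @metal launch=false, size the launch from kernel_max_work_group_size, then call it with numworkgroups/workgroupsize. The kernel body and names (fill_kernel, dest, val) are hypothetical, and how the KI module is exposed to user code is not shown in this excerpt, so treat this as an illustration rather than a copy-paste example.

using Metal   # assumes the kaintr branch from this PR, where KI.KIKernel is defined
# `KI` stands for the kernel-intrinsics module referenced throughout the diff;
# its public name/exposure is not shown in this PR excerpt.

# Hypothetical element-wise kernel; 1-based indexing mirrors the broadcast kernels above.
function fill_kernel(dest, val)
    i = KI.get_global_id().x
    if 1 <= i <= length(dest)
        @inbounds dest[i] = val
    end
    return
end

dest = MtlArray{Float32}(undef, 1024)
backend = MetalBackend()

# Compile without launching, then derive the launch configuration,
# following the serial_mapreduce_kernel pattern in mapreduce.jl above.
kern = KI.KIKernel(backend, fill_kernel, dest, 1.0f0)
threads = KI.kernel_max_work_group_size(backend, kern; max_work_items = length(dest))
groups = cld(length(dest), threads)
kern(dest, 1.0f0; numworkgroups = groups, workgroupsize = threads)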

christiangnrd force-pushed the kaintr branch 2 times, most recently from 6ea1ed2 to 9ac3d49, on October 22, 2025 at 04:12
github-actions bot (Contributor) left a comment

Metal Benchmarks

Benchmark suite | Current: dbabf6f | Previous: 7a00e02 | Ratio
latency/precompile 28785892834 ns 24988042500 ns 1.15
latency/ttfp 2281610541 ns 2127632334 ns 1.07
latency/import 1370464854.5 ns 1222851792 ns 1.12
integration/metaldevrt 939833 ns 927875 ns 1.01
integration/byval/slices=1 1627375 ns 1620292 ns 1.00
integration/byval/slices=3 10568646 ns 8536708.5 ns 1.24
integration/byval/reference 1618459 ns 1612333 ns 1.00
integration/byval/slices=2 2669292 ns 2645042 ns 1.01
kernel/indexing 685125 ns 679583 ns 1.01
kernel/indexing_checked 687166 ns 693562.5 ns 0.99
kernel/launch 11500 ns 12125 ns 0.95
array/construct 6209 ns 6167 ns 1.01
array/broadcast 651437.5 ns 662458 ns 0.98
array/random/randn/Float32 779291 ns 676959 ns 1.15
array/random/randn!/Float32 635583 ns 632812.5 ns 1.00
array/random/rand!/Int64 564208 ns 568500 ns 0.99
array/random/rand!/Float32 593938 ns 602917 ns 0.99
array/random/rand/Int64 763062.5 ns 753958.5 ns 1.01
array/random/rand/Float32 598000 ns 630583 ns 0.95
array/accumulate/Int64/1d 1414125 ns 1355125 ns 1.04
array/accumulate/Int64/dims=1 1925749.5 ns 1889750 ns 1.02
array/accumulate/Int64/dims=2 2335542 ns 2192333 ns 1.07
array/accumulate/Int64/dims=1L 11694791 ns 11536666 ns 1.01
array/accumulate/Int64/dims=2L 10047042 ns 9881020.5 ns 1.02
array/accumulate/Float32/1d 1260791.5 ns 1204667 ns 1.05
array/accumulate/Float32/dims=1 1738125 ns 1616041 ns 1.08
array/accumulate/Float32/dims=2 2110292 ns 1935833 ns 1.09
array/accumulate/Float32/dims=1L 9985334 ns 9856584 ns 1.01
array/accumulate/Float32/dims=2L 8223708 ns 7308562 ns 1.13
array/reductions/reduce/Int64/1d 2841396 ns 1359750 ns 2.09
array/reductions/reduce/Int64/dims=1 1515833.5 ns 1174500 ns 1.29
array/reductions/reduce/Int64/dims=2 1738667 ns 1282354 ns 1.36
array/reductions/reduce/Int64/dims=1L 1905146 ns 2086875 ns 0.91
array/reductions/reduce/Int64/dims=2L 5492833.5 ns 3540312.5 ns 1.55
array/reductions/reduce/Float32/1d 2828313 ns 1033041 ns 2.74
array/reductions/reduce/Float32/dims=1 1242250 ns 892792 ns 1.39
array/reductions/reduce/Float32/dims=2 1297666.5 ns 845187.5 ns 1.54
array/reductions/reduce/Float32/dims=1L 1371729.5 ns 1374416.5 ns 1.00
array/reductions/reduce/Float32/dims=2L 2864375 ns 1883167 ns 1.52
array/reductions/mapreduce/Int64/1d 2843333.5 ns 1350646 ns 2.11
array/reductions/mapreduce/Int64/dims=1 1536208.5 ns 1153667 ns 1.33
array/reductions/mapreduce/Int64/dims=2 1762021 ns 1268854.5 ns 1.39
array/reductions/mapreduce/Int64/dims=1L 1878563 ns 2090792 ns 0.90
array/reductions/mapreduce/Int64/dims=2L 5553896 ns 3437250 ns 1.62
array/reductions/mapreduce/Float32/1d 2883187.5 ns 1076708 ns 2.68
array/reductions/mapreduce/Float32/dims=1 1257021 ns 883375 ns 1.42
array/reductions/mapreduce/Float32/dims=2 1309834 ns 809458.5 ns 1.62
array/reductions/mapreduce/Float32/dims=1L 1382041 ns 1342917 ns 1.03
array/reductions/mapreduce/Float32/dims=2L 2870562.5 ns 1879000 ns 1.53
array/private/copyto!/gpu_to_gpu 641145.5 ns 673209 ns 0.95
array/private/copyto!/cpu_to_gpu 817875 ns 825542 ns 0.99
array/private/copyto!/gpu_to_cpu 802228.5 ns 819667 ns 0.98
array/private/iteration/findall/int 1758291 ns 1644375 ns 1.07
array/private/iteration/findall/bool 1567771 ns 1487875 ns 1.05
array/private/iteration/findfirst/int 2876437 ns 1997458 ns 1.44
array/private/iteration/findfirst/bool 2772625 ns 1848771 ns 1.50
array/private/iteration/scalar 3819083 ns 5166791.5 ns 0.74
array/private/iteration/logical 4177146 ns 2604979 ns 1.60
array/private/iteration/findmin/1d 2962021 ns 2057958.5 ns 1.44
array/private/iteration/findmin/2d 2001313 ns 1622521 ns 1.23
array/private/copy 617084 ns 618062.5 ns 1.00
array/shared/copyto!/gpu_to_gpu 84667 ns 83583 ns 1.01
array/shared/copyto!/cpu_to_gpu 83875 ns 81875 ns 1.02
array/shared/copyto!/gpu_to_cpu 83917 ns 78833 ns 1.06
array/shared/iteration/findall/int 1790625 ns 1610667 ns 1.11
array/shared/iteration/findall/bool 1583083 ns 1511396 ns 1.05
array/shared/iteration/findfirst/int 2718542 ns 1423167 ns 1.91
array/shared/iteration/findfirst/bool 2572979 ns 1430792 ns 1.80
array/shared/iteration/scalar 153270.5 ns 151500 ns 1.01
array/shared/iteration/logical 3947146 ns 2532417 ns 1.56
array/shared/iteration/findmin/1d 2740166.5 ns 1572771 ns 1.74
array/shared/iteration/findmin/2d 1999167 ns 1627458 ns 1.23
array/shared/copy 250625 ns 255750 ns 0.98
array/permutedims/4d 2855187.5 ns 2444937.5 ns 1.17
array/permutedims/2d 1228375 ns 1178854 ns 1.04
array/permutedims/3d 1738187 ns 1722146 ns 1.01
metal/synchronization/stream 14208 ns 14834 ns 0.96
metal/synchronization/context 14750 ns 15125 ns 0.98
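
For reference, the Ratio column is the current measurement divided by the previous one; for example, latency/precompile gives 28785892834 ns / 24988042500 ns ≈ 1.15.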

This comment was automatically generated by a workflow using github-action-benchmark.

christiangnrd force-pushed the kaintr branch 2 times, most recently from b4cc06a to 22e754e, on October 22, 2025 at 13:29