@christiangnrd
Member

Atomics changes:

  • Puts the types under test in an array and removes the comment about Float32 being supported on threadgroup memory
  • Removes the conditional addition of Float32 to the test types. I'm not sure about this one, since there has been talk of potentially supporting compilation for lower Metal versions. However, Metal 3 was introduced with macOS 13, which is the oldest supported version of macOS.

Integer changes are analogous to #531.
Integer Intrinsics:

  • Overrides the Base function for each intrinsic and also associates it with the MSL name and the unexported function.
  • Is 1-arg min/max a valid intrinsic??
  • Julia 1.10+ has _mul_high, so remove the unused code
  • Add some tests for these
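
As a rough illustration of what such tests can compare against, here is a minimal CPU-only sketch. It assumes that the overridden bit-manipulation intrinsics should agree with the corresponding Base functions, and that the high-multiply intrinsic returns the upper half of the full-width product. `IntTypes` and `mulhi_ref` are hypothetical names introduced only for this example; they are not part of Metal.jl or Base.

```julia
# Hypothetical CPU reference for a "multiply high" operation: the upper half
# of the full-width product. Not part of Metal.jl or Base.
const IntTypes = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}

function mulhi_ref(x::T, y::T) where {T <: IntTypes}
    wide = widen(x) * widen(y)              # exact product in a type twice as wide
    return (wide >> (8 * sizeof(T))) % T    # shift the upper half down, truncate to T
end

# Base functions that the bit-manipulation overrides are assumed to mirror.
@assert leading_zeros(0x01) == 7       # UInt8: 0000_0001
@assert trailing_zeros(0x08) == 3      # UInt8: 0000_1000
@assert count_ones(0xff) == 8
@assert bitreverse(0x01) == 0x80

@assert mulhi_ref(0xff, 0xff) == 0xfe                       # 0xff * 0xff == 0xfe01
@assert mulhi_ref(Int8(-1), Int8(1)) == Int8(-1)            # sign-extended upper half
@assert mulhi_ref(Int32(65536), Int32(65536)) == Int32(1)   # 2^16 * 2^16 == 2^32
```

On the GPU side, a kernel's output for each integer type would then be compared elementwise against CPU values like these, which is the pattern the "int math" testsets in this PR follow.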

Very weirdly, overriding Base.max(x::Int64, y::Int64) breaks matrix multiplication. I have no idea what's going on.

Should we split up the intrinsics tests by their respective files (or at least atomics, math, others)? I originally had that in, but it made the PR harder to review, so I'll open a new one if we go that way.

@github-actions
Contributor

github-actions bot commented Feb 15, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index ed4c221a..73a71576 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -343,91 +343,91 @@ end
 @device_override Base.abs(x::Int8)    = ccall("extern air.abs.s.i8", llvmcall, Int8, (Int8,), x)
 @device_override Base.abs(x::UInt8)   = ccall("extern air.abs.u.i8", llvmcall, UInt8, (UInt8,), x)
 
-@device_override Base.min(x::Int64, y::Int64)   = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
+@device_override Base.min(x::Int64, y::Int64) = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
 @device_override Base.min(x::UInt64, y::UInt64) = ccall("extern air.min.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.min(x::Int32, y::Int32)   = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.min(x::Int32, y::Int32) = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
 @device_override Base.min(x::UInt32, y::UInt32) = ccall("extern air.min.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.min(x::Int16, y::Int16)   = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.min(x::Int16, y::Int16) = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
 @device_override Base.min(x::UInt16, y::UInt16) = ccall("extern air.min.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.min(x::Int8, y::Int8)     = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.min(x::UInt8, y::UInt8)   = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.min(x::Int8, y::Int8) = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.min(x::UInt8, y::UInt8) = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
 
 # XXX: Breaks mul! when uncommented. MWE: using Revise, Metal;A, x = mtl(rand(Int32, 4, 4)), mtl(rand(Int32, 4)); A*x
 # @device_override Base.max(x::Int64, y::Int64)   = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
 @device_override Base.max(x::UInt64, y::UInt64) = ccall("extern air.max.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.max(x::Int32, y::Int32)   = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.max(x::Int32, y::Int32) = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
 @device_override Base.max(x::UInt32, y::UInt32) = ccall("extern air.max.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.max(x::Int16, y::Int16)   = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.max(x::Int16, y::Int16) = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
 @device_override Base.max(x::UInt16, y::UInt16) = ccall("extern air.max.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.max(x::Int8, y::Int8)     = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.max(x::UInt8, y::UInt8)   = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.max(x::Int8, y::Int8) = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.max(x::UInt8, y::UInt8) = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
 
-@device_override Base.min(x::Int64, y::Int64, z::Int64)    = ccall("extern air.min3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
+@device_override Base.min(x::Int64, y::Int64, z::Int64) = ccall("extern air.min3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
 @device_override Base.min(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.min3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
-@device_override Base.min(x::Int32, y::Int32, z::Int32)    = ccall("extern air.min3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
+@device_override Base.min(x::Int32, y::Int32, z::Int32) = ccall("extern air.min3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
 @device_override Base.min(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.min3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
-@device_override Base.min(x::Int16, y::Int16, z::Int16)    = ccall("extern air.min3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
+@device_override Base.min(x::Int16, y::Int16, z::Int16) = ccall("extern air.min3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
 @device_override Base.min(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.min3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
-@device_override Base.min(x::Int8, y::Int8, z::Int8)       = ccall("extern air.min3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
-@device_override Base.min(x::UInt8, y::UInt8, z::UInt8)    = ccall("extern air.min3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
+@device_override Base.min(x::Int8, y::Int8, z::Int8) = ccall("extern air.min3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
+@device_override Base.min(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.min3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
 
-@device_override Base.max(x::Int64, y::Int64, z::Int64)    = ccall("extern air.max3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
+@device_override Base.max(x::Int64, y::Int64, z::Int64) = ccall("extern air.max3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
 @device_override Base.max(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.max3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
-@device_override Base.max(x::Int32, y::Int32, z::Int32)    = ccall("extern air.max3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
+@device_override Base.max(x::Int32, y::Int32, z::Int32) = ccall("extern air.max3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
 @device_override Base.max(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.max3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
-@device_override Base.max(x::Int16, y::Int16, z::Int16)    = ccall("extern air.max3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
+@device_override Base.max(x::Int16, y::Int16, z::Int16) = ccall("extern air.max3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
 @device_override Base.max(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.max3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
-@device_override Base.max(x::Int8, y::Int8, z::Int8)       = ccall("extern air.max3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
-@device_override Base.max(x::UInt8, y::UInt8, z::UInt8)    = ccall("extern air.max3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
+@device_override Base.max(x::Int8, y::Int8, z::Int8) = ccall("extern air.max3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
+@device_override Base.max(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.max3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
 
-@device_override Base.leading_zeros(x::Int64)  = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.leading_zeros(x::Int64) = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
 @device_override Base.leading_zeros(x::UInt64) = ccall("extern air.clz.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.leading_zeros(x::Int32)  = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.leading_zeros(x::Int32) = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
 @device_override Base.leading_zeros(x::UInt32) = ccall("extern air.clz.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.leading_zeros(x::Int16)  = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.leading_zeros(x::Int16) = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
 @device_override Base.leading_zeros(x::UInt16) = ccall("extern air.clz.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.leading_zeros(x::Int8)   = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.leading_zeros(x::UInt8)  = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.leading_zeros(x::Int8) = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.leading_zeros(x::UInt8) = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
 const clz = leading_zeros
 
-@device_override Base.trailing_zeros(x::Int64)  = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.trailing_zeros(x::Int64) = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
 @device_override Base.trailing_zeros(x::UInt64) = ccall("extern air.ctz.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.trailing_zeros(x::Int32)  = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.trailing_zeros(x::Int32) = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
 @device_override Base.trailing_zeros(x::UInt32) = ccall("extern air.ctz.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.trailing_zeros(x::Int16)  = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.trailing_zeros(x::Int16) = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
 @device_override Base.trailing_zeros(x::UInt16) = ccall("extern air.ctz.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.trailing_zeros(x::Int8)   = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.trailing_zeros(x::UInt8)  = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.trailing_zeros(x::Int8) = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.trailing_zeros(x::UInt8) = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
 const ctz = trailing_zeros
 
-@device_override Base.count_ones(x::Int64)  = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.count_ones(x::Int64) = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
 @device_override Base.count_ones(x::UInt64) = ccall("extern air.popcount.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.count_ones(x::Int32)  = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.count_ones(x::Int32) = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
 @device_override Base.count_ones(x::UInt32) = ccall("extern air.popcount.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.count_ones(x::Int16)  = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.count_ones(x::Int16) = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
 @device_override Base.count_ones(x::UInt16) = ccall("extern air.popcount.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.count_ones(x::Int8)   = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.count_ones(x::UInt8)  = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.count_ones(x::Int8) = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.count_ones(x::UInt8) = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
 const popcount = count_ones
 
-@device_override Base.bitreverse(x::Int64)  = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.bitreverse(x::Int64) = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
 @device_override Base.bitreverse(x::UInt64) = ccall("extern air.reverse_bits.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.bitreverse(x::Int32)  = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.bitreverse(x::Int32) = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
 @device_override Base.bitreverse(x::UInt32) = ccall("extern air.reverse_bits.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.bitreverse(x::Int16)  = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.bitreverse(x::Int16) = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
 @device_override Base.bitreverse(x::UInt16) = ccall("extern air.reverse_bits.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.bitreverse(x::Int8)   = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.bitreverse(x::UInt8)  = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.bitreverse(x::Int8) = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.bitreverse(x::UInt8) = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
 const reverse_bits = bitreverse
 
-@device_override Base.MultiplicativeInverses._mul_high(x::Int64, y::Int64)   = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
 @device_override Base.MultiplicativeInverses._mul_high(x::UInt64, y::UInt64) = ccall("extern air.mul_hi.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int32, y::Int32)   = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
 @device_override Base.MultiplicativeInverses._mul_high(x::UInt32, y::UInt32) = ccall("extern air.mul_hi.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int16, y::Int16)   = ccall("extern air.mul_hi.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int16, y::Int16) = ccall("extern air.mul_hi.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
 @device_override Base.MultiplicativeInverses._mul_high(x::UInt16, y::UInt16) = ccall("extern air.mul_hi.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int8, y::Int8)     = ccall("extern air.mul_hi.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::UInt8, y::UInt8)   = ccall("extern air.mul_hi.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int8, y::Int8) = ccall("extern air.mul_hi.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::UInt8, y::UInt8) = ccall("extern air.mul_hi.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
 const mulhi = Base.MultiplicativeInverses._mul_high
 
 # From: https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-expm1f/48085/2
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 8bd018f9..c2661566 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -174,7 +174,7 @@ FLOAT_MATH_INTR_FUNCS_3_ARG = [
 
 @testset "float math" begin
 # 1-arg functions
-@testset "$(fun)()::$T" for fun in FLOAT_MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+    @testset "$(fun)()::$T" for fun in FLOAT_MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
     cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
         rand(T, 4)
     else
@@ -194,7 +194,7 @@ FLOAT_MATH_INTR_FUNCS_3_ARG = [
     @eval @test Array($mtlout) ≈ $fun.($cpuarr)
 end
 # 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_2_ARG
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_2_ARG
     N = 4
     arr1 = randn(T, N)
     arr2 = randn(T, N)
@@ -212,7 +212,7 @@ end
     @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
 end
 # 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_3_ARG
+    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_3_ARG
     N = 4
     arr1 = randn(T, N)
     arr2 = randn(T, N)
@@ -429,61 +429,61 @@ INT_MATH_INTR_FUNCS_3_ARG = [
 ]
 
 @testset "int math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in INT_MATH_INTR_FUNCS_1_ARG, T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64)
-    cpuarr = T[0.0, -0.0, rand(T), -rand(T)]
+    # 1-arg functions
+    @testset "$(fun)()::$T" for fun in INT_MATH_INTR_FUNCS_1_ARG, T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64)
+        cpuarr = T[0.0, -0.0, rand(T), -rand(T)]
 
-    mtlarr = MtlArray(cpuarr)
+        mtlarr = MtlArray(cpuarr)
 
-    mtlout = fill!(similar(mtlarr), 0)
+        mtlout = fill!(similar(mtlarr), 0)
 
-    function kernel(res, arr)
-        idx = thread_position_in_grid_1d()
-        res[idx] = fun(arr[idx])
-        return nothing
+        function kernel(res, arr)
+            idx = thread_position_in_grid_1d()
+            res[idx] = fun(arr[idx])
+            return nothing
+        end
+        Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+        @eval @test Array($mtlout) ≈ $fun.($cpuarr)
     end
-    Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
-    @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_2_ARG
-    N = 4
-    arr1 = rand(T, N)
-    arr2 = rand(T, N)
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
+    # 2-arg functions
+    @testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_2_ARG
+        N = 4
+        arr1 = rand(T, N)
+        arr2 = rand(T, N)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y)
-        idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx])
-        return nothing
+        function kernel(res, x, y)
+            idx = thread_position_in_grid_1d()
+            res[idx] = fun(x[idx], y[idx])
+            return nothing
+        end
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_3_ARG
-    N = 4
-    arr1 = rand(T, N)
-    arr2 = rand(T, N)
-    arr3 = rand(T, N)
+    # 3-arg functions
+    @testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_3_ARG
+        N = 4
+        arr1 = rand(T, N)
+        arr2 = rand(T, N)
+        arr3 = rand(T, N)
 
-    mtlarr1 = MtlArray(arr1)
-    mtlarr2 = MtlArray(arr2)
-    mtlarr3 = MtlArray(arr3)
+        mtlarr1 = MtlArray(arr1)
+        mtlarr2 = MtlArray(arr2)
+        mtlarr3 = MtlArray(arr3)
 
-    mtlout = fill!(similar(mtlarr1), 0)
+        mtlout = fill!(similar(mtlarr1), 0)
 
-    function kernel(res, x, y, z)
-        idx = thread_position_in_grid_1d()
-        res[idx] = fun(x[idx], y[idx], z[idx])
-        return nothing
+        function kernel(res, x, y, z)
+            idx = thread_position_in_grid_1d()
+            res[idx] = fun(x[idx], y[idx], z[idx])
+            return nothing
+        end
+        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
     end
-    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
-    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
 end
 
 ############################################################################################
@@ -757,9 +757,9 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
 @testset "low-level" begin
     # TODO: make these tests actually write to the overlapping memory locations
 
-    atomic_store_load_exch_cmpexch_types = (Int32, UInt32, Float32)
-    # The Metal Shading Language spec states: "Metal 3 supports the atomic_float for device memory only"
-    local_atomic_store_load_exch_cmpexch_types = setdiff(atomic_store_load_exch_cmpexch_types, [Float32])
+        atomic_store_load_exch_cmpexch_types = (Int32, UInt32, Float32)
+        # The Metal Shading Language spec states: "Metal 3 supports the atomic_float for device memory only"
+        local_atomic_store_load_exch_cmpexch_types = setdiff(atomic_store_load_exch_cmpexch_types, [Float32])
 
     @testset "store_explicit" begin
         function global_kernel(a, val)
@@ -768,7 +768,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in atomic_store_load_exch_cmpexch_types
+            @testset for T in atomic_store_load_exch_cmpexch_types
             a = Metal.zeros(T, n)
             @metal threads=n global_kernel(a, T(42))
             @test all(isequal(42), Array(a))
@@ -782,7 +782,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in local_atomic_store_load_exch_cmpexch_types
+            @testset for T in local_atomic_store_load_exch_cmpexch_types
             a = Metal.zeros(T, n)
             @metal threads=n local_kernel(a, T(42))
             @test all(isequal(42), Array(a))
@@ -797,7 +797,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in atomic_store_load_exch_cmpexch_types
+            @testset for T in atomic_store_load_exch_cmpexch_types
             a = MtlArray(rand(T, n))
             b = Metal.zeros(T, n)
             @metal threads=n global_kernel(a, b)
@@ -816,7 +816,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in local_atomic_store_load_exch_cmpexch_types
+            @testset for T in local_atomic_store_load_exch_cmpexch_types
             a = MtlArray(rand(T, n))
             b = Metal.zeros(T, n)
             @metal threads=n local_kernel(a, b)
@@ -831,7 +831,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in atomic_store_load_exch_cmpexch_types
+            @testset for T in atomic_store_load_exch_cmpexch_types
             a = MtlArray(rand(T, n))
             @metal threads=n global_kernel(a, T(42))
             @test all(isequal(42), Array(a))
@@ -845,7 +845,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in local_atomic_store_load_exch_cmpexch_types
+            @testset for T in local_atomic_store_load_exch_cmpexch_types
             a = Metal.zeros(T, n)
             @metal threads=n local_kernel(a, T(42))
             @test all(isequal(42), Array(a))
@@ -861,7 +861,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in atomic_store_load_exch_cmpexch_types
+            @testset for T in atomic_store_load_exch_cmpexch_types
             a = MtlArray(rand(T, n))
             expected = copy(a)
             desired = T(42)
@@ -884,7 +884,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in local_atomic_store_load_exch_cmpexch_types
+            @testset for T in local_atomic_store_load_exch_cmpexch_types
             a = Metal.zeros(T, n)
             expected = copy(a)
             desired = T(42)
@@ -894,7 +894,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
     end
 
     @testset "fetch and modify" begin
-        add_sub_types = [Int32, UInt32, Float32]
+            add_sub_types = [Int32, UInt32, Float32]
         other_types = [Int32, UInt32]
         for (jlfun, mtlfun, types) in [(min, Metal.atomic_fetch_min_explicit, other_types),
                                        (max, Metal.atomic_fetch_max_explicit, other_types),
@@ -953,7 +953,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
             return
         end
 
-        @testset for T in (Int32, UInt32, Float32)
+            @testset for T in (Int32, UInt32, Float32)
             a = rand(T, n)
             b = MtlArray(a)
             val = rand(T)
@@ -989,7 +989,7 @@ end
     #       covered by the low-level tests above, but only the atomic macro functionality.
 
     @testset "load" begin
-        types = [Int32, UInt32, Float32]
+            types = [Int32, UInt32, Float32]
 
         function kernel(a, b)
             i = thread_position_in_grid_1d()
@@ -1006,7 +1006,7 @@ end
     end
 
     @testset "store" begin
-        types = [Int32, UInt32, Float32]
+            types = [Int32, UInt32, Float32]
 
         function kernel(a, b)
             i = thread_position_in_grid_1d()
@@ -1024,7 +1024,7 @@ end
     end
 
     @testset "add" begin
-        types = [Int32, UInt32, Float32]
+            types = [Int32, UInt32, Float32]
 
         function kernel(a)
             Metal.@atomic a[1] = a[1] + 1
@@ -1040,7 +1040,7 @@ end
     end
 
     @testset "sub" begin
-        types = [Int32, UInt32, Float32]
+            types = [Int32, UInt32, Float32]
 
         function kernel(a)
             Metal.@atomic a[1] = a[1] - 1

Contributor

@github-actions github-actions bot left a comment

Metal Benchmarks

Benchmark suite Current: 74109bc Previous: 4324871 Ratio
private array/construct 24364.583333333336 ns 24639 ns 0.99
private array/broadcast 459125 ns 465708 ns 0.99
private array/random/randn/Float32 917375 ns 830416 ns 1.10
private array/random/randn!/Float32 608291 ns 632792 ns 0.96
private array/random/rand!/Int64 557645.5 ns 557083 ns 1.00
private array/random/rand!/Float32 563541 ns 600125 ns 0.94
private array/random/rand/Int64 923312.5 ns 766062.5 ns 1.21
private array/random/rand/Float32 818625 ns 634208 ns 1.29
private array/copyto!/gpu_to_gpu 575250 ns 664208 ns 0.87
private array/copyto!/cpu_to_gpu 663000 ns 773250 ns 0.86
private array/copyto!/gpu_to_cpu 724770.5 ns 699709 ns 1.04
private array/accumulate/1d 1406292 ns 1339667 ns 1.05
private array/accumulate/2d 1481000 ns 1386666.5 ns 1.07
private array/iteration/findall/int 2270833 ns 2103625 ns 1.08
private array/iteration/findall/bool 1991166.5 ns 1839124.5 ns 1.08
private array/iteration/findfirst/int 1830541.5 ns 1695250 ns 1.08
private array/iteration/findfirst/bool 1733917 ns 1666458 ns 1.04
private array/iteration/scalar 2498959 ns 3433416 ns 0.73
private array/iteration/logical 3453437 ns 3197875 ns 1.08
private array/iteration/findmin/1d 1863584 ns 1765958 ns 1.06
private array/iteration/findmin/2d 1423292 ns 1344812.5 ns 1.06
private array/reductions/reduce/1d 988958 ns 1043625 ns 0.95
private array/reductions/reduce/2d 698834 ns 661229 ns 1.06
private array/reductions/mapreduce/1d 972042 ns 1014646 ns 0.96
private array/reductions/mapreduce/2d 698083 ns 666687.5 ns 1.05
private array/permutedims/4d 2625333 ns 2533875 ns 1.04
private array/permutedims/2d 1093416 ns 1025083.5 ns 1.07
private array/permutedims/3d 1812125 ns 1582229 ns 1.15
private array/copy 836000 ns 579250 ns 1.44
latency/precompile 9189597500 ns 9071946333 ns 1.01
latency/ttfp 3732490041 ns 3672313458 ns 1.02
latency/import 1261503020.5 ns 1239159916 ns 1.02
integration/metaldevrt 750750 ns 723334 ns 1.04
integration/byval/slices=1 1639500 ns 1627542 ns 1.01
integration/byval/slices=3 20364291.5 ns 10224103.5 ns 1.99
integration/byval/reference 1640083 ns 1593833 ns 1.03
integration/byval/slices=2 2816875 ns 2576042 ns 1.09
kernel/indexing 457250 ns 459437.5 ns 1.00
kernel/indexing_checked 459125 ns 464146 ns 0.99
kernel/launch 8083 ns 8000 ns 1.01
metal/synchronization/stream 15292 ns 14709 ns 1.04
metal/synchronization/context 15812.5 ns 14834 ns 1.07
shared array/construct 24000 ns 24382 ns 0.98
shared array/broadcast 460687 ns 459666.5 ns 1.00
shared array/random/randn/Float32 914916.5 ns 841104 ns 1.09
shared array/random/randn!/Float32 605500 ns 640708 ns 0.95
shared array/random/rand!/Int64 563708 ns 571334 ns 0.99
shared array/random/rand!/Float32 558145.5 ns 596354.5 ns 0.94
shared array/random/rand/Int64 904125 ns 774333 ns 1.17
shared array/random/rand/Float32 819875 ns 644416 ns 1.27
shared array/copyto!/gpu_to_gpu 80250 ns 82959 ns 0.97
shared array/copyto!/cpu_to_gpu 79916 ns 83750 ns 0.95
shared array/copyto!/gpu_to_cpu 80375 ns 82583.5 ns 0.97
shared array/accumulate/1d 1430979.5 ns 1341375 ns 1.07
shared array/accumulate/2d 1487000 ns 1394854 ns 1.07
shared array/iteration/findall/int 2032792 ns 1790000 ns 1.14
shared array/iteration/findall/bool 1752270.5 ns 1571083 ns 1.12
shared array/iteration/findfirst/int 1508166 ns 1381542 ns 1.09
shared array/iteration/findfirst/bool 1419833 ns 1367708 ns 1.04
shared array/iteration/scalar 163292 ns 157917 ns 1.03
shared array/iteration/logical 3281833.5 ns 2978354.5 ns 1.10
shared array/iteration/findmin/1d 1565750 ns 1465666.5 ns 1.07
shared array/iteration/findmin/2d 1434166 ns 1367417 ns 1.05
shared array/reductions/reduce/1d 707208 ns 733625 ns 0.96
shared array/reductions/reduce/2d 708500 ns 661083 ns 1.07
shared array/reductions/mapreduce/1d 745625 ns 735437.5 ns 1.01
shared array/reductions/mapreduce/2d 703354.5 ns 665875 ns 1.06
shared array/permutedims/4d 2661541.5 ns 2500209 ns 1.06
shared array/permutedims/2d 1094042 ns 1022354 ns 1.07
shared array/permutedims/3d 1780145.5 ns 1576083 ns 1.13
shared array/copy 210375 ns 239334 ns 0.88

This comment was automatically generated by workflow using github-action-benchmark.

@maleadt
Member

maleadt commented Feb 18, 2025

LGTM. The single-arg min looks like a typo indeed.

Very weirdly, overriding Base.max(x::Int64, y::Int64) breaks matrix multiplication. I have no idea what's going on.

Mind opening a separate issue for this?

@christiangnrd
Member Author

Mind opening a separate issue for this?

#547

@maleadt maleadt merged commit 5af28d2 into main Feb 18, 2025
7 checks passed
@maleadt maleadt deleted the atomixfix branch February 18, 2025 12:51