-
Notifications
You must be signed in to change notification settings - Fork 47
Integer & atomic Intrinsics improvements #544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index ed4c221a..73a71576 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -343,91 +343,91 @@ end
@device_override Base.abs(x::Int8) = ccall("extern air.abs.s.i8", llvmcall, Int8, (Int8,), x)
@device_override Base.abs(x::UInt8) = ccall("extern air.abs.u.i8", llvmcall, UInt8, (UInt8,), x)
-@device_override Base.min(x::Int64, y::Int64) = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
+@device_override Base.min(x::Int64, y::Int64) = ccall("extern air.min.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
@device_override Base.min(x::UInt64, y::UInt64) = ccall("extern air.min.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.min(x::Int32, y::Int32) = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.min(x::Int32, y::Int32) = ccall("extern air.min.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
@device_override Base.min(x::UInt32, y::UInt32) = ccall("extern air.min.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.min(x::Int16, y::Int16) = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.min(x::Int16, y::Int16) = ccall("extern air.min.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
@device_override Base.min(x::UInt16, y::UInt16) = ccall("extern air.min.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.min(x::Int8, y::Int8) = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.min(x::UInt8, y::UInt8) = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.min(x::Int8, y::Int8) = ccall("extern air.min.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.min(x::UInt8, y::UInt8) = ccall("extern air.min.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
# XXX: Breaks mul! when uncommented. MWE: using Revise, Metal;A, x = mtl(rand(Int32, 4, 4)), mtl(rand(Int32, 4)); A*x
# @device_override Base.max(x::Int64, y::Int64) = ccall("extern air.max.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
@device_override Base.max(x::UInt64, y::UInt64) = ccall("extern air.max.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.max(x::Int32, y::Int32) = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.max(x::Int32, y::Int32) = ccall("extern air.max.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
@device_override Base.max(x::UInt32, y::UInt32) = ccall("extern air.max.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.max(x::Int16, y::Int16) = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.max(x::Int16, y::Int16) = ccall("extern air.max.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
@device_override Base.max(x::UInt16, y::UInt16) = ccall("extern air.max.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.max(x::Int8, y::Int8) = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.max(x::UInt8, y::UInt8) = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.max(x::Int8, y::Int8) = ccall("extern air.max.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.max(x::UInt8, y::UInt8) = ccall("extern air.max.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
-@device_override Base.min(x::Int64, y::Int64, z::Int64) = ccall("extern air.min3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
+@device_override Base.min(x::Int64, y::Int64, z::Int64) = ccall("extern air.min3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
@device_override Base.min(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.min3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
-@device_override Base.min(x::Int32, y::Int32, z::Int32) = ccall("extern air.min3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
+@device_override Base.min(x::Int32, y::Int32, z::Int32) = ccall("extern air.min3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
@device_override Base.min(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.min3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
-@device_override Base.min(x::Int16, y::Int16, z::Int16) = ccall("extern air.min3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
+@device_override Base.min(x::Int16, y::Int16, z::Int16) = ccall("extern air.min3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
@device_override Base.min(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.min3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
-@device_override Base.min(x::Int8, y::Int8, z::Int8) = ccall("extern air.min3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
-@device_override Base.min(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.min3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
+@device_override Base.min(x::Int8, y::Int8, z::Int8) = ccall("extern air.min3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
+@device_override Base.min(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.min3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
-@device_override Base.max(x::Int64, y::Int64, z::Int64) = ccall("extern air.max3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
+@device_override Base.max(x::Int64, y::Int64, z::Int64) = ccall("extern air.max3.s.i64", llvmcall, Int64, (Int64, Int64, Int64), x, y, z)
@device_override Base.max(x::UInt64, y::UInt64, z::UInt64) = ccall("extern air.max3.u.i64", llvmcall, UInt64, (UInt64, UInt64, UInt64), x, y, z)
-@device_override Base.max(x::Int32, y::Int32, z::Int32) = ccall("extern air.max3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
+@device_override Base.max(x::Int32, y::Int32, z::Int32) = ccall("extern air.max3.s.i32", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
@device_override Base.max(x::UInt32, y::UInt32, z::UInt32) = ccall("extern air.max3.u.i32", llvmcall, UInt32, (UInt32, UInt32, UInt32), x, y, z)
-@device_override Base.max(x::Int16, y::Int16, z::Int16) = ccall("extern air.max3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
+@device_override Base.max(x::Int16, y::Int16, z::Int16) = ccall("extern air.max3.s.i16", llvmcall, Int16, (Int16, Int16, Int16), x, y, z)
@device_override Base.max(x::UInt16, y::UInt16, z::UInt16) = ccall("extern air.max3.u.i16", llvmcall, UInt16, (UInt16, UInt16, UInt16), x, y, z)
-@device_override Base.max(x::Int8, y::Int8, z::Int8) = ccall("extern air.max3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
-@device_override Base.max(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.max3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
+@device_override Base.max(x::Int8, y::Int8, z::Int8) = ccall("extern air.max3.s.i8", llvmcall, Int8, (Int8, Int8, Int8), x, y, z)
+@device_override Base.max(x::UInt8, y::UInt8, z::UInt8) = ccall("extern air.max3.u.i8", llvmcall, UInt8, (UInt8, UInt8, UInt8), x, y, z)
-@device_override Base.leading_zeros(x::Int64) = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.leading_zeros(x::Int64) = ccall("extern air.clz.i64", llvmcall, Int64, (Int64,), x)
@device_override Base.leading_zeros(x::UInt64) = ccall("extern air.clz.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.leading_zeros(x::Int32) = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.leading_zeros(x::Int32) = ccall("extern air.clz.i32", llvmcall, Int32, (Int32,), x)
@device_override Base.leading_zeros(x::UInt32) = ccall("extern air.clz.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.leading_zeros(x::Int16) = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.leading_zeros(x::Int16) = ccall("extern air.clz.i16", llvmcall, Int16, (Int16,), x)
@device_override Base.leading_zeros(x::UInt16) = ccall("extern air.clz.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.leading_zeros(x::Int8) = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.leading_zeros(x::UInt8) = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.leading_zeros(x::Int8) = ccall("extern air.clz.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.leading_zeros(x::UInt8) = ccall("extern air.clz.i8", llvmcall, UInt8, (UInt8,), x)
const clz = leading_zeros
-@device_override Base.trailing_zeros(x::Int64) = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.trailing_zeros(x::Int64) = ccall("extern air.ctz.i64", llvmcall, Int64, (Int64,), x)
@device_override Base.trailing_zeros(x::UInt64) = ccall("extern air.ctz.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.trailing_zeros(x::Int32) = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.trailing_zeros(x::Int32) = ccall("extern air.ctz.i32", llvmcall, Int32, (Int32,), x)
@device_override Base.trailing_zeros(x::UInt32) = ccall("extern air.ctz.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.trailing_zeros(x::Int16) = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.trailing_zeros(x::Int16) = ccall("extern air.ctz.i16", llvmcall, Int16, (Int16,), x)
@device_override Base.trailing_zeros(x::UInt16) = ccall("extern air.ctz.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.trailing_zeros(x::Int8) = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.trailing_zeros(x::UInt8) = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.trailing_zeros(x::Int8) = ccall("extern air.ctz.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.trailing_zeros(x::UInt8) = ccall("extern air.ctz.i8", llvmcall, UInt8, (UInt8,), x)
const ctz = trailing_zeros
-@device_override Base.count_ones(x::Int64) = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.count_ones(x::Int64) = ccall("extern air.popcount.i64", llvmcall, Int64, (Int64,), x)
@device_override Base.count_ones(x::UInt64) = ccall("extern air.popcount.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.count_ones(x::Int32) = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.count_ones(x::Int32) = ccall("extern air.popcount.i32", llvmcall, Int32, (Int32,), x)
@device_override Base.count_ones(x::UInt32) = ccall("extern air.popcount.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.count_ones(x::Int16) = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.count_ones(x::Int16) = ccall("extern air.popcount.i16", llvmcall, Int16, (Int16,), x)
@device_override Base.count_ones(x::UInt16) = ccall("extern air.popcount.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.count_ones(x::Int8) = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.count_ones(x::UInt8) = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.count_ones(x::Int8) = ccall("extern air.popcount.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.count_ones(x::UInt8) = ccall("extern air.popcount.i8", llvmcall, UInt8, (UInt8,), x)
const popcount = count_ones
-@device_override Base.bitreverse(x::Int64) = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
+@device_override Base.bitreverse(x::Int64) = ccall("extern air.reverse_bits.i64", llvmcall, Int64, (Int64,), x)
@device_override Base.bitreverse(x::UInt64) = ccall("extern air.reverse_bits.i64", llvmcall, UInt64, (UInt64,), x)
-@device_override Base.bitreverse(x::Int32) = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
+@device_override Base.bitreverse(x::Int32) = ccall("extern air.reverse_bits.i32", llvmcall, Int32, (Int32,), x)
@device_override Base.bitreverse(x::UInt32) = ccall("extern air.reverse_bits.i32", llvmcall, UInt32, (UInt32,), x)
-@device_override Base.bitreverse(x::Int16) = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
+@device_override Base.bitreverse(x::Int16) = ccall("extern air.reverse_bits.i16", llvmcall, Int16, (Int16,), x)
@device_override Base.bitreverse(x::UInt16) = ccall("extern air.reverse_bits.i16", llvmcall, UInt16, (UInt16,), x)
-@device_override Base.bitreverse(x::Int8) = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
-@device_override Base.bitreverse(x::UInt8) = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
+@device_override Base.bitreverse(x::Int8) = ccall("extern air.reverse_bits.i8", llvmcall, Int8, (Int8,), x)
+@device_override Base.bitreverse(x::UInt8) = ccall("extern air.reverse_bits.i8", llvmcall, UInt8, (UInt8,), x)
const reverse_bits = bitreverse
-@device_override Base.MultiplicativeInverses._mul_high(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
@device_override Base.MultiplicativeInverses._mul_high(x::UInt64, y::UInt64) = ccall("extern air.mul_hi.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
@device_override Base.MultiplicativeInverses._mul_high(x::UInt32, y::UInt32) = ccall("extern air.mul_hi.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int16, y::Int16) = ccall("extern air.mul_hi.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int16, y::Int16) = ccall("extern air.mul_hi.s.i16", llvmcall, Int16, (Int16, Int16), x, y)
@device_override Base.MultiplicativeInverses._mul_high(x::UInt16, y::UInt16) = ccall("extern air.mul_hi.u.i16", llvmcall, UInt16, (UInt16, UInt16), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::Int8, y::Int8) = ccall("extern air.mul_hi.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
-@device_override Base.MultiplicativeInverses._mul_high(x::UInt8, y::UInt8) = ccall("extern air.mul_hi.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::Int8, y::Int8) = ccall("extern air.mul_hi.s.i8", llvmcall, Int8, (Int8, Int8), x, y)
+@device_override Base.MultiplicativeInverses._mul_high(x::UInt8, y::UInt8) = ccall("extern air.mul_hi.u.i8", llvmcall, UInt8, (UInt8, UInt8), x, y)
const mulhi = Base.MultiplicativeInverses._mul_high
# From: https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-expm1f/48085/2
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 8bd018f9..c2661566 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -174,7 +174,7 @@ FLOAT_MATH_INTR_FUNCS_3_ARG = [
@testset "float math" begin
# 1-arg functions
-@testset "$(fun)()::$T" for fun in FLOAT_MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
+ @testset "$(fun)()::$T" for fun in FLOAT_MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
rand(T, 4)
else
@@ -194,7 +194,7 @@ FLOAT_MATH_INTR_FUNCS_3_ARG = [
@eval @test Array($mtlout) ≈ $fun.($cpuarr)
end
# 2-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_2_ARG
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_2_ARG
N = 4
arr1 = randn(T, N)
arr2 = randn(T, N)
@@ -212,7 +212,7 @@ end
@eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
end
# 3-arg functions
-@testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_3_ARG
+ @testset "$(fun)()::$T" for T in (Float32, Float16), fun in FLOAT_MATH_INTR_FUNCS_3_ARG
N = 4
arr1 = randn(T, N)
arr2 = randn(T, N)
@@ -429,61 +429,61 @@ INT_MATH_INTR_FUNCS_3_ARG = [
]
@testset "int math" begin
-# 1-arg functions
-@testset "$(fun)()::$T" for fun in INT_MATH_INTR_FUNCS_1_ARG, T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64)
- cpuarr = T[0.0, -0.0, rand(T), -rand(T)]
+ # 1-arg functions
+ @testset "$(fun)()::$T" for fun in INT_MATH_INTR_FUNCS_1_ARG, T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64)
+ cpuarr = T[0.0, -0.0, rand(T), -rand(T)]
- mtlarr = MtlArray(cpuarr)
+ mtlarr = MtlArray(cpuarr)
- mtlout = fill!(similar(mtlarr), 0)
+ mtlout = fill!(similar(mtlarr), 0)
- function kernel(res, arr)
- idx = thread_position_in_grid_1d()
- res[idx] = fun(arr[idx])
- return nothing
+ function kernel(res, arr)
+ idx = thread_position_in_grid_1d()
+ res[idx] = fun(arr[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
+ @eval @test Array($mtlout) ≈ $fun.($cpuarr)
end
- Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
- @eval @test Array($mtlout) ≈ $fun.($cpuarr)
-end
-# 2-arg functions
-@testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_2_ARG
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
+ # 2-arg functions
+ @testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_2_ARG
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y)
- idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx])
- return nothing
+ function kernel(res, x, y)
+ idx = thread_position_in_grid_1d()
+ res[idx] = fun(x[idx], y[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
-end
-# 3-arg functions
-@testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_3_ARG
- N = 4
- arr1 = rand(T, N)
- arr2 = rand(T, N)
- arr3 = rand(T, N)
+ # 3-arg functions
+ @testset "$(fun)()::$T" for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64), fun in INT_MATH_INTR_FUNCS_3_ARG
+ N = 4
+ arr1 = rand(T, N)
+ arr2 = rand(T, N)
+ arr3 = rand(T, N)
- mtlarr1 = MtlArray(arr1)
- mtlarr2 = MtlArray(arr2)
- mtlarr3 = MtlArray(arr3)
+ mtlarr1 = MtlArray(arr1)
+ mtlarr2 = MtlArray(arr2)
+ mtlarr3 = MtlArray(arr3)
- mtlout = fill!(similar(mtlarr1), 0)
+ mtlout = fill!(similar(mtlarr1), 0)
- function kernel(res, x, y, z)
- idx = thread_position_in_grid_1d()
- res[idx] = fun(x[idx], y[idx], z[idx])
- return nothing
+ function kernel(res, x, y, z)
+ idx = thread_position_in_grid_1d()
+ res[idx] = fun(x[idx], y[idx], z[idx])
+ return nothing
+ end
+ Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
+ @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
end
- Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
- @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
-end
end
############################################################################################
@@ -757,9 +757,9 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
@testset "low-level" begin
# TODO: make these tests actually write to the overlapping memory locations
- atomic_store_load_exch_cmpexch_types = (Int32, UInt32, Float32)
- # The Metal Shading Language spec states: "Metal 3 supports the atomic_float for device memory only"
- local_atomic_store_load_exch_cmpexch_types = setdiff(atomic_store_load_exch_cmpexch_types, [Float32])
+ atomic_store_load_exch_cmpexch_types = (Int32, UInt32, Float32)
+ # The Metal Shading Language spec states: "Metal 3 supports the atomic_float for device memory only"
+ local_atomic_store_load_exch_cmpexch_types = setdiff(atomic_store_load_exch_cmpexch_types, [Float32])
@testset "store_explicit" begin
function global_kernel(a, val)
@@ -768,7 +768,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in atomic_store_load_exch_cmpexch_types
+ @testset for T in atomic_store_load_exch_cmpexch_types
a = Metal.zeros(T, n)
@metal threads=n global_kernel(a, T(42))
@test all(isequal(42), Array(a))
@@ -782,7 +782,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in local_atomic_store_load_exch_cmpexch_types
+ @testset for T in local_atomic_store_load_exch_cmpexch_types
a = Metal.zeros(T, n)
@metal threads=n local_kernel(a, T(42))
@test all(isequal(42), Array(a))
@@ -797,7 +797,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in atomic_store_load_exch_cmpexch_types
+ @testset for T in atomic_store_load_exch_cmpexch_types
a = MtlArray(rand(T, n))
b = Metal.zeros(T, n)
@metal threads=n global_kernel(a, b)
@@ -816,7 +816,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in local_atomic_store_load_exch_cmpexch_types
+ @testset for T in local_atomic_store_load_exch_cmpexch_types
a = MtlArray(rand(T, n))
b = Metal.zeros(T, n)
@metal threads=n local_kernel(a, b)
@@ -831,7 +831,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in atomic_store_load_exch_cmpexch_types
+ @testset for T in atomic_store_load_exch_cmpexch_types
a = MtlArray(rand(T, n))
@metal threads=n global_kernel(a, T(42))
@test all(isequal(42), Array(a))
@@ -845,7 +845,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in local_atomic_store_load_exch_cmpexch_types
+ @testset for T in local_atomic_store_load_exch_cmpexch_types
a = Metal.zeros(T, n)
@metal threads=n local_kernel(a, T(42))
@test all(isequal(42), Array(a))
@@ -861,7 +861,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in atomic_store_load_exch_cmpexch_types
+ @testset for T in atomic_store_load_exch_cmpexch_types
a = MtlArray(rand(T, n))
expected = copy(a)
desired = T(42)
@@ -884,7 +884,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in local_atomic_store_load_exch_cmpexch_types
+ @testset for T in local_atomic_store_load_exch_cmpexch_types
a = Metal.zeros(T, n)
expected = copy(a)
desired = T(42)
@@ -894,7 +894,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
end
@testset "fetch and modify" begin
- add_sub_types = [Int32, UInt32, Float32]
+ add_sub_types = [Int32, UInt32, Float32]
other_types = [Int32, UInt32]
for (jlfun, mtlfun, types) in [(min, Metal.atomic_fetch_min_explicit, other_types),
(max, Metal.atomic_fetch_max_explicit, other_types),
@@ -953,7 +953,7 @@ n = 128 # NOTE: also hard-coded in MtlThreadGroupArray constructors
return
end
- @testset for T in (Int32, UInt32, Float32)
+ @testset for T in (Int32, UInt32, Float32)
a = rand(T, n)
b = MtlArray(a)
val = rand(T)
@@ -989,7 +989,7 @@ end
# covered by the low-level tests above, but only the atomic macro functionality.
@testset "load" begin
- types = [Int32, UInt32, Float32]
+ types = [Int32, UInt32, Float32]
function kernel(a, b)
i = thread_position_in_grid_1d()
@@ -1006,7 +1006,7 @@ end
end
@testset "store" begin
- types = [Int32, UInt32, Float32]
+ types = [Int32, UInt32, Float32]
function kernel(a, b)
i = thread_position_in_grid_1d()
@@ -1024,7 +1024,7 @@ end
end
@testset "add" begin
- types = [Int32, UInt32, Float32]
+ types = [Int32, UInt32, Float32]
function kernel(a)
Metal.@atomic a[1] = a[1] + 1
@@ -1040,7 +1040,7 @@ end
end
@testset "sub" begin
- types = [Int32, UInt32, Float32]
+ types = [Int32, UInt32, Float32]
function kernel(a)
Metal.@atomic a[1] = a[1] - 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metal Benchmarks
| Benchmark suite | Current: 74109bc | Previous: 4324871 | Ratio |
|---|---|---|---|
private array/construct |
24364.583333333336 ns |
24639 ns |
0.99 |
private array/broadcast |
459125 ns |
465708 ns |
0.99 |
private array/random/randn/Float32 |
917375 ns |
830416 ns |
1.10 |
private array/random/randn!/Float32 |
608291 ns |
632792 ns |
0.96 |
private array/random/rand!/Int64 |
557645.5 ns |
557083 ns |
1.00 |
private array/random/rand!/Float32 |
563541 ns |
600125 ns |
0.94 |
private array/random/rand/Int64 |
923312.5 ns |
766062.5 ns |
1.21 |
private array/random/rand/Float32 |
818625 ns |
634208 ns |
1.29 |
private array/copyto!/gpu_to_gpu |
575250 ns |
664208 ns |
0.87 |
private array/copyto!/cpu_to_gpu |
663000 ns |
773250 ns |
0.86 |
private array/copyto!/gpu_to_cpu |
724770.5 ns |
699709 ns |
1.04 |
private array/accumulate/1d |
1406292 ns |
1339667 ns |
1.05 |
private array/accumulate/2d |
1481000 ns |
1386666.5 ns |
1.07 |
private array/iteration/findall/int |
2270833 ns |
2103625 ns |
1.08 |
private array/iteration/findall/bool |
1991166.5 ns |
1839124.5 ns |
1.08 |
private array/iteration/findfirst/int |
1830541.5 ns |
1695250 ns |
1.08 |
private array/iteration/findfirst/bool |
1733917 ns |
1666458 ns |
1.04 |
private array/iteration/scalar |
2498959 ns |
3433416 ns |
0.73 |
private array/iteration/logical |
3453437 ns |
3197875 ns |
1.08 |
private array/iteration/findmin/1d |
1863584 ns |
1765958 ns |
1.06 |
private array/iteration/findmin/2d |
1423292 ns |
1344812.5 ns |
1.06 |
private array/reductions/reduce/1d |
988958 ns |
1043625 ns |
0.95 |
private array/reductions/reduce/2d |
698834 ns |
661229 ns |
1.06 |
private array/reductions/mapreduce/1d |
972042 ns |
1014646 ns |
0.96 |
private array/reductions/mapreduce/2d |
698083 ns |
666687.5 ns |
1.05 |
private array/permutedims/4d |
2625333 ns |
2533875 ns |
1.04 |
private array/permutedims/2d |
1093416 ns |
1025083.5 ns |
1.07 |
private array/permutedims/3d |
1812125 ns |
1582229 ns |
1.15 |
private array/copy |
836000 ns |
579250 ns |
1.44 |
latency/precompile |
9189597500 ns |
9071946333 ns |
1.01 |
latency/ttfp |
3732490041 ns |
3672313458 ns |
1.02 |
latency/import |
1261503020.5 ns |
1239159916 ns |
1.02 |
integration/metaldevrt |
750750 ns |
723334 ns |
1.04 |
integration/byval/slices=1 |
1639500 ns |
1627542 ns |
1.01 |
integration/byval/slices=3 |
20364291.5 ns |
10224103.5 ns |
1.99 |
integration/byval/reference |
1640083 ns |
1593833 ns |
1.03 |
integration/byval/slices=2 |
2816875 ns |
2576042 ns |
1.09 |
kernel/indexing |
457250 ns |
459437.5 ns |
1.00 |
kernel/indexing_checked |
459125 ns |
464146 ns |
0.99 |
kernel/launch |
8083 ns |
8000 ns |
1.01 |
metal/synchronization/stream |
15292 ns |
14709 ns |
1.04 |
metal/synchronization/context |
15812.5 ns |
14834 ns |
1.07 |
shared array/construct |
24000 ns |
24382 ns |
0.98 |
shared array/broadcast |
460687 ns |
459666.5 ns |
1.00 |
shared array/random/randn/Float32 |
914916.5 ns |
841104 ns |
1.09 |
shared array/random/randn!/Float32 |
605500 ns |
640708 ns |
0.95 |
shared array/random/rand!/Int64 |
563708 ns |
571334 ns |
0.99 |
shared array/random/rand!/Float32 |
558145.5 ns |
596354.5 ns |
0.94 |
shared array/random/rand/Int64 |
904125 ns |
774333 ns |
1.17 |
shared array/random/rand/Float32 |
819875 ns |
644416 ns |
1.27 |
shared array/copyto!/gpu_to_gpu |
80250 ns |
82959 ns |
0.97 |
shared array/copyto!/cpu_to_gpu |
79916 ns |
83750 ns |
0.95 |
shared array/copyto!/gpu_to_cpu |
80375 ns |
82583.5 ns |
0.97 |
shared array/accumulate/1d |
1430979.5 ns |
1341375 ns |
1.07 |
shared array/accumulate/2d |
1487000 ns |
1394854 ns |
1.07 |
shared array/iteration/findall/int |
2032792 ns |
1790000 ns |
1.14 |
shared array/iteration/findall/bool |
1752270.5 ns |
1571083 ns |
1.12 |
shared array/iteration/findfirst/int |
1508166 ns |
1381542 ns |
1.09 |
shared array/iteration/findfirst/bool |
1419833 ns |
1367708 ns |
1.04 |
shared array/iteration/scalar |
163292 ns |
157917 ns |
1.03 |
shared array/iteration/logical |
3281833.5 ns |
2978354.5 ns |
1.10 |
shared array/iteration/findmin/1d |
1565750 ns |
1465666.5 ns |
1.07 |
shared array/iteration/findmin/2d |
1434166 ns |
1367417 ns |
1.05 |
shared array/reductions/reduce/1d |
707208 ns |
733625 ns |
0.96 |
shared array/reductions/reduce/2d |
708500 ns |
661083 ns |
1.07 |
shared array/reductions/mapreduce/1d |
745625 ns |
735437.5 ns |
1.01 |
shared array/reductions/mapreduce/2d |
703354.5 ns |
665875 ns |
1.06 |
shared array/permutedims/4d |
2661541.5 ns |
2500209 ns |
1.06 |
shared array/permutedims/2d |
1094042 ns |
1022354 ns |
1.07 |
shared array/permutedims/3d |
1780145.5 ns |
1576083 ns |
1.13 |
shared array/copy |
210375 ns |
239334 ns |
0.88 |
This comment was automatically generated by workflow using github-action-benchmark.
a78d65b to
74109bc
Compare
|
LGTM. The single-arg
Mind opening a separate issue for this? |
|
Atomics changes:
Integer changes are analogous to #531.
Integer Intrinsics:
min/maxa valid intrinsic??_mul_highso remove unused codeVery weirdly, overriding
Base.max(x::Int64, y::Int64)breaks matrix multiplication. I have not idea what's going on.Should we split up the intrinsics tests by their respective files (or at least atomics, math, others)? I originally had it in but it made the PR harder to review so I'll make a new one if we go that way.