[Do not merge] Test KernelIntrinsics #688
base: main
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.

diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl
index 4c73cead..4cea05a4 100644
--- a/src/MetalKernels.jl
+++ b/src/MetalKernels.jl
@@ -135,28 +135,30 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
end
function KI.KIKernel(::MetalBackend, f, args...; kwargs...)
- kern = eval(quote
- @metal launch=false $(kwargs...) $(f)($(args...))
- end)
- KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
+ kern = eval(
+ quote
+ @metal launch = false $(kwargs...) $(f)($(args...))
+ end
+ )
+ return KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
end
-function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...)
+function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups = nothing, workgroupsize = nothing, kwargs...)
threadsPerThreadgroup = isnothing(workgroupsize) ? 1 : workgroupsize
threadgroupsPerGrid = isnothing(numworkgroups) ? 1 : numworkgroups
- obj.kern(args...; threads=threadsPerThreadgroup, groups=threadgroupsPerGrid, kwargs...)
+ return obj.kern(args...; threads = threadsPerThreadgroup, groups = threadgroupsPerGrid, kwargs...)
end
-function KI.kernel_max_work_group_size(::MetalBackend, kikern::KI.KIKernel{<:MetalBackend}; max_work_items::Int=typemax(Int))::Int
- Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
+function KI.kernel_max_work_group_size(::MetalBackend, kikern::KI.KIKernel{<:MetalBackend}; max_work_items::Int = typemax(Int))::Int
+ return Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
end
function KI.max_work_group_size(::MetalBackend)::Int
- Int(device().maxThreadsPerThreadgroup.width)
+ return Int(device().maxThreadsPerThreadgroup.width)
end
function KI.multiprocessor_count(::MetalBackend)::Int
- Metal.num_gpu_cores()
+ return Metal.num_gpu_cores()
end
diff --git a/src/broadcast.jl b/src/broadcast.jl
index 72ced3ed..e90f5826 100644
--- a/src/broadcast.jl
+++ b/src/broadcast.jl
@@ -66,8 +66,8 @@ end
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
function broadcast_cartesian_static(dest, bc, Is)
- i = KI.get_global_id().x
- stride = KI.get_global_size().x
+ i = KI.get_global_id().x
+ stride = KI.get_global_size().x
while 1 <= i <= length(dest)
I = @inbounds Is[i]
@inbounds dest[I] = bc[I]
@@ -91,8 +91,8 @@ end
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
function broadcast_linear(dest, bc)
- i = KI.get_global_id().x
- stride = KI.get_global_size().x
+ i = KI.get_global_id().x
+ stride = KI.get_global_size().x
while 1 <= i <= length(dest)
@inbounds dest[i] = bc[i]
i += stride
@@ -150,8 +150,8 @@ end
else
## COV_EXCL_START
function broadcast_cartesian(dest, bc)
- i = KI.get_global_id().x
- stride = KI.get_global_size().x
+ i = KI.get_global_id().x
+ stride = KI.get_global_size().x
while 1 <= i <= length(dest)
I = @inbounds CartesianIndices(dest)[i]
@inbounds dest[I] = bc[I]
diff --git a/src/device/random.jl b/src/device/random.jl
index 383862d3..36b471d2 100644
--- a/src/device/random.jl
+++ b/src/device/random.jl
@@ -89,8 +89,8 @@ end
@inbounds global_random_counters()[simdgroupId]
elseif field === :ctr2
globalId = KI.get_global_id().x +
- (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
- (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
+ (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
+ (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
globalId % UInt32
end::UInt32
end
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 3e83e9a7..82a488a6 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -197,9 +197,9 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
if length(Rother) >= serial_mapreduce_threshold(device(R))
kernel = KI.KIKernel(backend, serial_mapreduce_kernel, f, op, init, Val(Rreduce), Val(Rother), R, A)
- threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(Rother))
+ threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items = length(Rother))
groups = cld(length(Rother), threads)
- kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups=groups, workgroupsize=threads)
+ kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups = groups, workgroupsize = threads)
return R
end
@@ -224,7 +224,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# we might not be able to launch all those threads to reduce each slice in one go.
# that's why each threads also loops across their inputs, processing multiple values
# so that we can span the entire reduction dimension using a single item group.
- kernel = KI.KIKernel(backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+ kernel = KI.KIKernel(
+ backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
# how many threads do we want?
@@ -263,7 +264,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
# we can cover the dimensions to reduce using a single group
kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
- numworkgroups=groups, workgroupsize=threads)
+ numworkgroups = groups, workgroupsize = threads
+ )
else
# we need multiple steps to cover all values to reduce
partial = similar(R, (size(R)..., reduce_groups))
@@ -274,12 +276,15 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
end
# NOTE: we can't use the previously-compiled kernel, since the type of `partial`
# might not match the original output container (e.g. if that was a view).
- KI.KIKernel(backend, partial_mapreduce_device,
+ KI.KIKernel(
+ backend, partial_mapreduce_device,
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
- Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
+ Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A
+ )(
f, op, init, Val(threads), Val(Rreduce), Val(Rother),
Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
- numworkgroups=groups, workgroupsize=threads)
+ numworkgroups = groups, workgroupsize = threads
+ )
GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
end
diff --git a/test/kernelabstractions.jl b/test/kernelabstractions.jl
index 221ee680..6f9d5a2c 100644
--- a/test/kernelabstractions.jl
+++ b/test/kernelabstractions.jl
@@ -7,6 +7,6 @@ Testsuite.testsuite(()->MetalBackend(), "Metal", Metal, MtlArray, Metal.MtlDevic
"Convert", # depends on https://github.com/JuliaGPU/Metal.jl/issues/69
"SpecialFunctions", # no equivalent Metal intrinsics for gamma, erf, etc
"sparse", # not supported yet
- "CPU synchronization",
- "fallback test: callable types",
+ "CPU synchronization",
+ "fallback test: callable types",
]))
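For readers following along, here is a minimal sketch (not part of the suggested changes) of how the KernelIntrinsics-style launch path touched by this diff fits together end to end. The `vadd!` kernel, the array sizes, and the way `KI` is brought into scope are illustrative assumptions only.

```julia
# Minimal usage sketch of the KernelIntrinsics ("KI") launch path exercised by the
# diff above. Assumptions: `KI` is bound to the KernelIntrinsics module the same way
# src/MetalKernels.jl binds it; `vadd!`, the array sizes, and element type are made up.
using Metal

function vadd!(c, a, b)
    i = KI.get_global_id().x          # 1-based global work-item index (as in src/broadcast.jl)
    stride = KI.get_global_size().x   # total number of launched work items
    while 1 <= i <= length(c)
        @inbounds c[i] = a[i] + b[i]  # grid-stride loop covers the whole array
        i += stride
    end
    return nothing
end

a = MtlArray(rand(Float32, 1024)); b = MtlArray(rand(Float32, 1024)); c = similar(a)

backend = MetalBackend()
# Compile without launching, as KI.KIKernel does via `@metal launch=false`.
kern = KI.KIKernel(backend, vadd!, c, a, b)
# Choose a launch configuration the same way src/mapreduce.jl does.
threads = KI.kernel_max_work_group_size(backend, kern; max_work_items = length(c))
groups = cld(length(c), threads)
# Launch: arguments are passed again at call time along with the configuration.
kern(c, a, b; numworkgroups = groups, workgroupsize = threads)
```

The grid-stride `while` loop keeps the kernel correct even when `groups * threads` is smaller than `length(c)`, which is the same pattern the broadcast and mapreduce kernels in this diff rely on.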
Force-pushed from 6ea1ed2 to 9ac3d49
Metal Benchmarks
| Benchmark suite | Current: dbabf6f | Previous: 7a00e02 | Ratio |
|---|---|---|---|
| latency/precompile | 28785892834 ns | 24988042500 ns | 1.15 |
| latency/ttfp | 2281610541 ns | 2127632334 ns | 1.07 |
| latency/import | 1370464854.5 ns | 1222851792 ns | 1.12 |
| integration/metaldevrt | 939833 ns | 927875 ns | 1.01 |
| integration/byval/slices=1 | 1627375 ns | 1620292 ns | 1.00 |
| integration/byval/slices=3 | 10568646 ns | 8536708.5 ns | 1.24 |
| integration/byval/reference | 1618459 ns | 1612333 ns | 1.00 |
| integration/byval/slices=2 | 2669292 ns | 2645042 ns | 1.01 |
| kernel/indexing | 685125 ns | 679583 ns | 1.01 |
| kernel/indexing_checked | 687166 ns | 693562.5 ns | 0.99 |
| kernel/launch | 11500 ns | 12125 ns | 0.95 |
| array/construct | 6209 ns | 6167 ns | 1.01 |
| array/broadcast | 651437.5 ns | 662458 ns | 0.98 |
| array/random/randn/Float32 | 779291 ns | 676959 ns | 1.15 |
| array/random/randn!/Float32 | 635583 ns | 632812.5 ns | 1.00 |
| array/random/rand!/Int64 | 564208 ns | 568500 ns | 0.99 |
| array/random/rand!/Float32 | 593938 ns | 602917 ns | 0.99 |
| array/random/rand/Int64 | 763062.5 ns | 753958.5 ns | 1.01 |
| array/random/rand/Float32 | 598000 ns | 630583 ns | 0.95 |
| array/accumulate/Int64/1d | 1414125 ns | 1355125 ns | 1.04 |
| array/accumulate/Int64/dims=1 | 1925749.5 ns | 1889750 ns | 1.02 |
| array/accumulate/Int64/dims=2 | 2335542 ns | 2192333 ns | 1.07 |
| array/accumulate/Int64/dims=1L | 11694791 ns | 11536666 ns | 1.01 |
| array/accumulate/Int64/dims=2L | 10047042 ns | 9881020.5 ns | 1.02 |
| array/accumulate/Float32/1d | 1260791.5 ns | 1204667 ns | 1.05 |
| array/accumulate/Float32/dims=1 | 1738125 ns | 1616041 ns | 1.08 |
| array/accumulate/Float32/dims=2 | 2110292 ns | 1935833 ns | 1.09 |
| array/accumulate/Float32/dims=1L | 9985334 ns | 9856584 ns | 1.01 |
| array/accumulate/Float32/dims=2L | 8223708 ns | 7308562 ns | 1.13 |
| array/reductions/reduce/Int64/1d | 2841396 ns | 1359750 ns | 2.09 |
| array/reductions/reduce/Int64/dims=1 | 1515833.5 ns | 1174500 ns | 1.29 |
| array/reductions/reduce/Int64/dims=2 | 1738667 ns | 1282354 ns | 1.36 |
| array/reductions/reduce/Int64/dims=1L | 1905146 ns | 2086875 ns | 0.91 |
| array/reductions/reduce/Int64/dims=2L | 5492833.5 ns | 3540312.5 ns | 1.55 |
| array/reductions/reduce/Float32/1d | 2828313 ns | 1033041 ns | 2.74 |
| array/reductions/reduce/Float32/dims=1 | 1242250 ns | 892792 ns | 1.39 |
| array/reductions/reduce/Float32/dims=2 | 1297666.5 ns | 845187.5 ns | 1.54 |
| array/reductions/reduce/Float32/dims=1L | 1371729.5 ns | 1374416.5 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 2864375 ns | 1883167 ns | 1.52 |
| array/reductions/mapreduce/Int64/1d | 2843333.5 ns | 1350646 ns | 2.11 |
| array/reductions/mapreduce/Int64/dims=1 | 1536208.5 ns | 1153667 ns | 1.33 |
| array/reductions/mapreduce/Int64/dims=2 | 1762021 ns | 1268854.5 ns | 1.39 |
| array/reductions/mapreduce/Int64/dims=1L | 1878563 ns | 2090792 ns | 0.90 |
| array/reductions/mapreduce/Int64/dims=2L | 5553896 ns | 3437250 ns | 1.62 |
| array/reductions/mapreduce/Float32/1d | 2883187.5 ns | 1076708 ns | 2.68 |
| array/reductions/mapreduce/Float32/dims=1 | 1257021 ns | 883375 ns | 1.42 |
| array/reductions/mapreduce/Float32/dims=2 | 1309834 ns | 809458.5 ns | 1.62 |
| array/reductions/mapreduce/Float32/dims=1L | 1382041 ns | 1342917 ns | 1.03 |
| array/reductions/mapreduce/Float32/dims=2L | 2870562.5 ns | 1879000 ns | 1.53 |
| array/private/copyto!/gpu_to_gpu | 641145.5 ns | 673209 ns | 0.95 |
| array/private/copyto!/cpu_to_gpu | 817875 ns | 825542 ns | 0.99 |
| array/private/copyto!/gpu_to_cpu | 802228.5 ns | 819667 ns | 0.98 |
| array/private/iteration/findall/int | 1758291 ns | 1644375 ns | 1.07 |
| array/private/iteration/findall/bool | 1567771 ns | 1487875 ns | 1.05 |
| array/private/iteration/findfirst/int | 2876437 ns | 1997458 ns | 1.44 |
| array/private/iteration/findfirst/bool | 2772625 ns | 1848771 ns | 1.50 |
| array/private/iteration/scalar | 3819083 ns | 5166791.5 ns | 0.74 |
| array/private/iteration/logical | 4177146 ns | 2604979 ns | 1.60 |
| array/private/iteration/findmin/1d | 2962021 ns | 2057958.5 ns | 1.44 |
| array/private/iteration/findmin/2d | 2001313 ns | 1622521 ns | 1.23 |
| array/private/copy | 617084 ns | 618062.5 ns | 1.00 |
| array/shared/copyto!/gpu_to_gpu | 84667 ns | 83583 ns | 1.01 |
| array/shared/copyto!/cpu_to_gpu | 83875 ns | 81875 ns | 1.02 |
| array/shared/copyto!/gpu_to_cpu | 83917 ns | 78833 ns | 1.06 |
| array/shared/iteration/findall/int | 1790625 ns | 1610667 ns | 1.11 |
| array/shared/iteration/findall/bool | 1583083 ns | 1511396 ns | 1.05 |
| array/shared/iteration/findfirst/int | 2718542 ns | 1423167 ns | 1.91 |
| array/shared/iteration/findfirst/bool | 2572979 ns | 1430792 ns | 1.80 |
| array/shared/iteration/scalar | 153270.5 ns | 151500 ns | 1.01 |
| array/shared/iteration/logical | 3947146 ns | 2532417 ns | 1.56 |
| array/shared/iteration/findmin/1d | 2740166.5 ns | 1572771 ns | 1.74 |
| array/shared/iteration/findmin/2d | 1999167 ns | 1627458 ns | 1.23 |
| array/shared/copy | 250625 ns | 255750 ns | 0.98 |
| array/permutedims/4d | 2855187.5 ns | 2444937.5 ns | 1.17 |
| array/permutedims/2d | 1228375 ns | 1178854 ns | 1.04 |
| array/permutedims/3d | 1738187 ns | 1722146 ns | 1.01 |
| metal/synchronization/stream | 14208 ns | 14834 ns | 0.96 |
| metal/synchronization/context | 14750 ns | 15125 ns | 0.98 |
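Reading the table: the Ratio column is Current ÷ Previous (consistent with the values above, e.g. 28785892834 ns ÷ 24988042500 ns ≈ 1.15 for latency/precompile), so ratios above 1.00 mean the operation got slower with this PR. The largest regressions are in the reduction benchmarks, e.g. array/reductions/reduce/Float32/1d at 2.74×.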
This comment was automatically generated by a workflow using github-action-benchmark.
Force-pushed from b4cc06a to 22e754e
Not marked as a draft so that the benchmarks also run.