diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 1d44ef482..ca721853f 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -7,8 +7,6 @@ steps:
         plugins:
           - JuliaCI/julia#v1:
               version: "{{matrix.julia}}"
-          - JuliaCI/julia-test#v1:
-              test_args: "--quickfail"
           - JuliaCI/julia-coverage#v1:
               codecov: true
               dirs:
@@ -24,6 +22,18 @@ steps:
            build.message !~ /\[only/ &&
            build.message !~ /\[skip tests\]/ &&
            build.message !~ /\[skip julia\]/
+        commands: |
+          julia -e 'println("--- :julia: Developing KernelAbstractions")
+                    using Pkg
+                    Pkg.add(url="https://github.com/christiangnrd/KernelAbstractions.jl", rev="intrinsics")'
+
+          julia -e 'println("--- :julia: Instantiating project")
+                    using Pkg
+                    Pkg.develop(; path=pwd())' || exit 3
+
+          julia -e 'println("+++ :julia: Running tests")
+                    using Pkg
+                    Pkg.test("Metal"; coverage=true)'
         timeout_in_minutes: 60
         matrix:
           setup:
@@ -99,6 +109,12 @@ steps:
          - JuliaCI/julia#v1:
              version: "1.12"
       command: |
+        julia --project -e '
+          using Pkg
+
+          println("--- :julia: Developing KernelAbstractions")
+          Pkg.add(url="https://github.com/christiangnrd/KernelAbstractions.jl", rev="intrinsics")'
+
         julia --project=perf -e '
           using Pkg
 
diff --git a/Project.toml b/Project.toml
index a2e5b1ebb..865e56cf3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -42,7 +42,7 @@ ExprTools = "0.1"
 GPUArrays = "11.2.1"
 GPUCompiler = "1.7.1"
 GPUToolbox = "0.1, 0.2, 0.3, 1"
-KernelAbstractions = "0.9.38"
+KernelAbstractions = "0.10"
 LLVM = "7.2, 8, 9"
 LLVMDowngrader_jll = "0.6"
 LinearAlgebra = "1"
diff --git a/src/Metal.jl b/src/Metal.jl
index 90d859d92..8ea712b7c 100644
--- a/src/Metal.jl
+++ b/src/Metal.jl
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
 using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
 import ObjectiveC: is_macos, darwin_version, macos_version
 import KernelAbstractions
+import KernelAbstractions: KernelIntrinsics as KI
 using ScopedValues
 
 include("version.jl")
diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl
index 4171c6614..4c73ceadb 100644
--- a/src/MetalKernels.jl
+++ b/src/MetalKernels.jl
@@ -4,6 +4,7 @@ using ..Metal
 using ..Metal: @device_override, DefaultStorageMode, SharedStorage
 
 import KernelAbstractions as KA
+import KernelAbstractions: KernelIntrinsics as KI
 
 using StaticArrays: MArray
 
@@ -133,35 +134,58 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
     return nothing
 end
 
+function KI.KIKernel(::MetalBackend, f, args...; kwargs...)
+    kern = eval(quote
+        @metal launch=false $(kwargs...) $(f)($(args...))
+    end)
+    KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
+end
+
+function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups=nothing, workgroupsize=nothing, kwargs...)
+    threadsPerThreadgroup = isnothing(workgroupsize) ? 1 : workgroupsize
+    threadgroupsPerGrid = isnothing(numworkgroups) ? 1 : numworkgroups
+
+    obj.kern(args...; threads=threadsPerThreadgroup, groups=threadgroupsPerGrid, kwargs...)
+end
+
+
+function KI.kernel_max_work_group_size(::MetalBackend, kikern::KI.KIKernel{<:MetalBackend}; max_work_items::Int=typemax(Int))::Int
+    Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
+end
+function KI.max_work_group_size(::MetalBackend)::Int
+    Int(device().maxThreadsPerThreadgroup.width)
+end
+function KI.multiprocessor_count(::MetalBackend)::Int
+    Metal.num_gpu_cores()
+end
+
+
 ## indexing
 
 ## COV_EXCL_START
 
-@device_override @inline function KA.__index_Local_Linear(ctx)
-    return thread_position_in_threadgroup().x
+@device_override @inline function KI.get_local_id()
+    return (; x = Int(thread_position_in_threadgroup().x), y = Int(thread_position_in_threadgroup().y), z = Int(thread_position_in_threadgroup().z))
 end
 
-@device_override @inline function KA.__index_Group_Linear(ctx)
-    return threadgroup_position_in_grid().x
+@device_override @inline function KI.get_group_id()
+    return (; x = Int(threadgroup_position_in_grid().x), y = Int(threadgroup_position_in_grid().y), z = Int(threadgroup_position_in_grid().z))
 end
 
-@device_override @inline function KA.__index_Global_Linear(ctx)
-    I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x, thread_position_in_threadgroup().x)
-    # TODO: This is unfortunate, can we get the linear index cheaper
-    @inbounds LinearIndices(KA.__ndrange(ctx))[I]
+@device_override @inline function KI.get_global_id()
+    return (; x = Int(thread_position_in_grid().x), y = Int(thread_position_in_grid().y), z = Int(thread_position_in_grid().z))
 end
 
-@device_override @inline function KA.__index_Local_Cartesian(ctx)
-    @inbounds KA.workitems(KA.__iterspace(ctx))[thread_position_in_threadgroup().x]
+@device_override @inline function KI.get_local_size()
    return (; x = Int(threads_per_threadgroup().x), y = Int(threads_per_threadgroup().y), z = Int(threads_per_threadgroup().z))
 end
 
-@device_override @inline function KA.__index_Group_Cartesian(ctx)
-    @inbounds KA.blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid().x]
+@device_override @inline function KI.get_num_groups()
+    return (; x = Int(threadgroups_per_grid().x), y = Int(threadgroups_per_grid().y), z = Int(threadgroups_per_grid().z))
 end
 
-@device_override @inline function KA.__index_Global_Cartesian(ctx)
-    return @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x,
-                               thread_position_in_threadgroup().x)
+@device_override @inline function KI.get_global_size()
+    return (; x = Int(threads_per_grid().x), y = Int(threads_per_grid().y), z = Int(threads_per_grid().z))
 end
 
 @device_override @inline function KA.__validindex(ctx)
@@ -177,8 +201,7 @@ end
 
 ## shared memory
 
-@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims},
-                                                  ::Val{Id}) where {T, Dims, Id}
+@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
     ptr = Metal.emit_threadgroup_memory(T, Val(prod(Dims)))
     MtlDeviceArray(Dims, ptr)
 end
@@ -190,7 +213,7 @@ end
 
 ## other
 
-@device_override @inline function KA.__synchronize()
+@device_override @inline function KI.barrier()
     threadgroup_barrier(Metal.MemoryFlagDevice | Metal.MemoryFlagThreadGroup)
 end
 
diff --git a/src/accumulate.jl b/src/accumulate.jl
index 31e2dc4fe..5dfb013bd 100644
--- a/src/accumulate.jl
+++ b/src/accumulate.jl
@@ -2,13 +2,13 @@
 function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
                       Rdim, Rpre, Rpost, Rother, neutral, init,
                       ::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
-    threads = threads_per_threadgroup_3d().x
-    thread = thread_position_in_threadgroup_3d().x
+    threads = UInt32(KI.get_local_size().x)
+    thread = UInt32(KI.get_local_id().x)
 
-    temp = MtlThreadGroupArray(T, (Int32(2) * maxthreads,))
+    temp = KI.localmemory(T, (UInt32(2) * maxthreads,))
 
-    i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
-    j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
+    i = UInt32((UInt32(KI.get_group_id().x) - UInt32(1)) * UInt32(KI.get_local_size().x) + UInt32(KI.get_local_id().x))
+    j = UInt32((UInt32(KI.get_group_id().z) - UInt32(1)) * UInt32(KI.get_num_groups().y) + UInt32(KI.get_group_id().y))
 
     if j > length(Rother)
         return
@@ -29,7 +29,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
     offset = one(thread)
     d = threads >> 0x1
     while d > zero(d)
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         @inbounds if thread <= d
             ai = offset * (thread << 0x1 - 0x1)
             bi = offset * (thread << 0x1)
@@ -46,7 +46,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
     d = one(thread)
     while d < threads
         offset >>= 0x1
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         @inbounds if thread <= d
             ai = offset * (thread << 0x1 - 0x1)
             bi = offset * (thread << 0x1)
@@ -58,7 +58,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
         d <<= 0x1
     end
 
-    threadgroup_barrier(MemoryFlagThreadGroup)
+    KI.barrier()
 
     @inbounds if i <= length(Rdim)
         val = if inclusive
@@ -76,10 +76,10 @@ function partial_scan(op::Function, output::AbstractArr
 end
 
 function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates::AbstractArray, Rdim, Rpre, Rpost, Rother, init)
-    block = threadgroup_position_in_grid_3d().x
+    block = UInt32(KI.get_group_id().x)
 
-    i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
-    j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
+    i = UInt32((UInt32(KI.get_group_id().x) - UInt32(1)) * UInt32(KI.get_local_size().x) + UInt32(KI.get_local_id().x))
+    j = UInt32((UInt32(KI.get_group_id().z) - UInt32(1)) * UInt32(KI.get_num_groups().y) + UInt32(KI.get_group_id().y))
 
     @inbounds if i <= length(Rdim) && j <= length(Rother)
         I = Rother[j]
diff --git a/src/broadcast.jl b/src/broadcast.jl
index 26979706a..72ced3edd 100644
--- a/src/broadcast.jl
+++ b/src/broadcast.jl
@@ -66,8 +66,8 @@ end
     if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
         ## COV_EXCL_START
         function broadcast_cartesian_static(dest, bc, Is)
-            i = thread_position_in_grid().x
-            stride = threads_per_grid().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
             while 1 <= i <= length(dest)
                 I = @inbounds Is[i]
                 @inbounds dest[I] = bc[I]
@@ -91,8 +91,8 @@ end
            (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
         ## COV_EXCL_START
         function broadcast_linear(dest, bc)
-            i = thread_position_in_grid().x
-            stride = threads_per_grid().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
             while 1 <= i <= length(dest)
                 @inbounds dest[i] = bc[i]
                 i += stride
@@ -150,8 +150,8 @@ end
     else
         ## COV_EXCL_START
         function broadcast_cartesian(dest, bc)
-            i = thread_position_in_grid().x
-            stride = threads_per_grid().x
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
             while 1 <= i <= length(dest)
                 I = @inbounds CartesianIndices(dest)[i]
                 @inbounds dest[I] = bc[I]
diff --git a/src/device/random.jl b/src/device/random.jl
index 979adbe36..383862d32 100644
--- a/src/device/random.jl
+++ b/src/device/random.jl
@@ -88,9 +88,9 @@ end
     elseif field === :ctr1
         @inbounds global_random_counters()[simdgroupId]
     elseif field === :ctr2
-        globalId = thread_position_in_grid().x +
-                   (thread_position_in_grid().y - 1i32) * threads_per_grid().x +
-                   (thread_position_in_grid().z - 1i32) * threads_per_grid().x * threads_per_grid().y
+        globalId = KI.get_global_id().x +
+                   (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
+                   (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
         globalId % UInt32
     end::UInt32
 end
diff --git a/src/indexing.jl b/src/indexing.jl
index d5cc8bc96..da3530f4c 100644
--- a/src/indexing.jl
+++ b/src/indexing.jl
@@ -33,7 +33,7 @@ function Base.findall(bools::WrappedMtlArray{Bool})
 
     if n > 0
         function kernel(ys::MtlDeviceArray, bools, indices)
-            i = (threadgroup_position_in_grid().x - Int32(1)) * threads_per_threadgroup().x + thread_position_in_threadgroup().x
+            i = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x
 
             @inbounds if i <= length(bools) && bools[i]
                 i′ = CartesianIndices(bools)[i]
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 78b3806ac..3e83e9a79 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -20,7 +20,7 @@ end
 @inline function reduce_group(op, val::T, neutral, shuffle::Val{true}, ::Val{maxthreads}) where {T, maxthreads}
     # shared mem for partial sums
     assume(threads_per_simdgroup() == 32)
-    shared = MtlThreadGroupArray(T, 32)
+    shared = KI.localmemory(T, 32)
 
     wid = simdgroup_index_in_threadgroup()
     lane = thread_index_in_simdgroup()
@@ -34,10 +34,10 @@ end
     end
 
     # wait for all partial reductions
-    threadgroup_barrier(MemoryFlagThreadGroup)
+    KI.barrier()
 
     # read from shared memory only if that warp existed
-    val = if thread_index_in_threadgroup() <= fld1(threads_per_threadgroup().x, 32)
+    val = if KI.get_local_id().x <= fld1(KI.get_local_size().x, 32)
         @inbounds shared[lane]
     else
         neutral
@@ -52,17 +52,17 @@ end
 
 # Reduce a value across a group, using local memory for communication
 @inline function reduce_group(op, val::T, neutral, shuffle::Val{false}, ::Val{maxthreads}) where {T, maxthreads}
-    threads = threads_per_threadgroup().x
-    thread = thread_position_in_threadgroup().x
+    threads = KI.get_local_size().x
+    thread = KI.get_local_id().x
 
     # local mem for a complete reduction
-    shared = MtlThreadGroupArray(T, (maxthreads,))
+    shared = KI.localmemory(T, (maxthreads,))
     @inbounds shared[thread] = val
 
     # perform a reduction
     d = 1
     while d < threads
-        threadgroup_barrier(MemoryFlagThreadGroup)
+        KI.barrier()
         index = 2 * d * (thread-1) + 1
         @inbounds if index <= threads
             other_val = if index + d <= threads
@@ -94,9 +94,9 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
                                   ::Val{Rother}, ::Val{Rlen}, ::Val{grain}, shuffle, R, As...) where {Rreduce, Rother, Rlen, grain}
     # decompose the 1D hardware indices into separate ones for reduction (across items
     # and possibly groups if it doesn't fit) and other elements (remaining groups)
-    localIdx_reduce = thread_position_in_threadgroup().x
-    localDim_reduce = threads_per_threadgroup().x * grain
-    groupIdx_reduce, groupIdx_other = fldmod1(threadgroup_position_in_grid().x, Rlen)
+    localIdx_reduce = KI.get_local_id().x
+    localDim_reduce = KI.get_local_size().x * grain
+    groupIdx_reduce, groupIdx_other = fldmod1(KI.get_group_id().x, Rlen)
 
     # group-based indexing into the values outside of the reduction dimension
     # (that means we can safely synchronize items within this group)
@@ -141,7 +141,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
 end
 
 function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
-    grid_idx = thread_position_in_grid().x
+    grid_idx = KI.get_global_id().x
 
     @inbounds if grid_idx <= length(Rother)
         Iother = Rother[grid_idx]
@@ -166,11 +166,12 @@ end
 
 ## COV_EXCL_STOP
 
-serial_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+serial_mapreduce_threshold(dev) = KI.max_work_group_size(MetalBackend()) * KI.multiprocessor_count(MetalBackend())
 
 function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                  A::Union{AbstractArray,Broadcast.Broadcasted};
                                  init=nothing) where {F, OP, T}
+    backend = MetalBackend()
     Base.check_reducedims(R, A)
     length(A) == 0 && return R # isempty(::Broadcasted) iterates
 
@@ -195,10 +196,10 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
 
     # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
     if length(Rother) >= serial_mapreduce_threshold(device(R))
-        kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
-        threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
+        kernel = KI.KIKernel(backend, serial_mapreduce_kernel, f, op, init, Val(Rreduce), Val(Rother), R, A)
+        threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(Rother))
         groups = cld(length(Rother), threads)
-        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
+        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; numworkgroups=groups, workgroupsize=threads)
         return R
     end
 
@@ -223,17 +224,17 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each threads also loops across their inputs, processing multiple values
    # so that we can span the entire reduction dimension using a single item group.
-    kernel = @metal launch=false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+    kernel = KI.KIKernel(backend, partial_mapreduce_device, f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                                           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
 
     # how many threads do we want?
     #
     # threads in a group work together to reduce values across the reduction dimensions;
     # we want as many as possible to improve algorithm efficiency and execution occupancy.
-    wanted_threads = shuffle ? nextwarp(kernel.pipeline, length(Rreduce)) : length(Rreduce)
+    wanted_threads = shuffle ? nextwarp(kernel.kern.pipeline, length(Rreduce)) : length(Rreduce)
     function compute_threads(max_threads)
         if wanted_threads > max_threads
-            shuffle ? prevwarp(kernel.pipeline, max_threads) : max_threads
+            shuffle ? prevwarp(kernel.kern.pipeline, max_threads) : max_threads
         else
             wanted_threads
         end
@@ -243,7 +244,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
     # kernel below, causing errors
     # reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
-    reduce_threads = compute_threads(512)
+    reduce_threads = compute_threads(KI.kernel_max_work_group_size(backend, kernel))
 
     # how many groups should we launch?
     #
@@ -262,7 +263,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         # we can cover the dimensions to reduce using a single group
         kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-               threads, groups)
+               numworkgroups=groups, workgroupsize=threads)
     else
         # we need multiple steps to cover all values to reduce
         partial = similar(R, (size(R)..., reduce_groups))
@@ -273,9 +274,12 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         end
         # NOTE: we can't use the previously-compiled kernel, since the type of `partial`
         #       might not match the original output container (e.g. if that was a view).
-        @metal threads groups partial_mapreduce_device(
+        KI.KIKernel(backend, partial_mapreduce_device,
             f, op, init, Val(threads), Val(Rreduce), Val(Rother),
-            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A)(
+            f, op, init, Val(threads), Val(Rreduce), Val(Rother),
+            Val(UInt64(length(Rother))), Val(grain), Val(shuffle), partial, A;
+            numworkgroups=groups, workgroupsize=threads)
 
         GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end
diff --git a/test/kernelabstractions.jl b/test/kernelabstractions.jl
index 116205d3a..221ee680d 100644
--- a/test/kernelabstractions.jl
+++ b/test/kernelabstractions.jl
@@ -7,4 +7,6 @@ Testsuite.testsuite(()->MetalBackend(), "Metal", Metal, MtlArray, Metal.MtlDevic
     "Convert",             # depends on https://github.com/JuliaGPU/Metal.jl/issues/69
     "SpecialFunctions",    # no equivalent Metal intrinsics for gamma, erf, etc
     "sparse",              # not supported yet
+    "CPU synchronization",
+    "fallback test: callable types",
 ]))
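---

Note (not part of the patch): below is a minimal, illustrative sketch of how a kernel can be written and launched through the `KernelIntrinsics` entry points this diff wires up (`KI.get_global_id`, `KI.KIKernel`, `KI.kernel_max_work_group_size`, and the `numworkgroups`/`workgroupsize` launch keywords). It assumes the `intrinsics` branch of KernelAbstractions.jl installed in the Buildkite changes above; the kernel, array, and variable names are hypothetical and chosen only for illustration, mirroring the usage in `src/mapreduce.jl` and `src/broadcast.jl`.

```julia
# Illustrative sketch only: a hypothetical element-wise kernel launched via the
# KI.KIKernel path added in src/MetalKernels.jl. Assumes the `intrinsics` branch
# of KernelAbstractions.jl is installed, as done in the pipeline above.
using Metal
import KernelAbstractions: KernelIntrinsics as KI

function vadd_kernel(c, a, b)
    # 1-based global thread index, as used by the broadcast.jl kernels above
    i = KI.get_global_id().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return
end

a = MtlArray(rand(Float32, 1024))
b = MtlArray(rand(Float32, 1024))
c = similar(a)

backend = MetalBackend()
# compile without launching, then size the launch from the pipeline limits,
# mirroring the serial_mapreduce_kernel usage in src/mapreduce.jl
kernel = KI.KIKernel(backend, vadd_kernel, c, a, b)
threads = KI.kernel_max_work_group_size(backend, kernel; max_work_items=length(c))
groups = cld(length(c), threads)
kernel(c, a, b; numworkgroups=groups, workgroupsize=threads)
```

Since `KI.KIKernel` in `src/MetalKernels.jl` simply wraps `@metal launch=false`, the `workgroupsize`/`numworkgroups` keywords map directly onto Metal's `threads`/`groups` launch parameters.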