Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ steps:
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
Expand All @@ -24,6 +22,18 @@ steps:
build.message !~ /\[only/ &&
build.message !~ /\[skip tests\]/ &&
build.message !~ /\[skip julia\]/
commands: |
julia -e 'println("--- :julia: Developing KernelAbstractions")
using Pkg
Pkg.add(url="https://github.com/christiangnrd/KernelAbstractions.jl", rev="intrinsics")'

julia -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop(; path=pwd())' || exit 3

julia -e 'println("+++ :julia: Running tests")
using Pkg
Pkg.test("Metal"; coverage=true)'
timeout_in_minutes: 60
matrix:
setup:
Expand Down Expand Up @@ -99,6 +109,12 @@ steps:
- JuliaCI/julia#v1:
version: "1.12"
command: |
julia --project -e '
using Pkg

println("--- :julia: Developing KernelAbstractions")
Pkg.add(url="https://github.com/christiangnrd/KernelAbstractions.jl", rev="intrinsics")'

julia --project=perf -e '
using Pkg

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ ExprTools = "0.1"
GPUArrays = "11.2.1"
GPUCompiler = "1.7.1"
GPUToolbox = "0.1, 0.2, 0.3, 1"
KernelAbstractions = "0.9.38"
KernelAbstractions = "0.10"
LLVM = "7.2, 8, 9"
LLVMDowngrader_jll = "0.6"
LinearAlgebra = "1"
Expand Down
1 change: 1 addition & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
import ObjectiveC: is_macos, darwin_version, macos_version
import KernelAbstractions
import KernelAbstractions: KernelIntrinsics as KI
using ScopedValues

include("version.jl")
Expand Down
59 changes: 41 additions & 18 deletions src/MetalKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ..Metal
using ..Metal: @device_override, DefaultStorageMode, SharedStorage

import KernelAbstractions as KA
import KernelAbstractions: KernelIntrinsics as KI

using StaticArrays: MArray

Expand Down Expand Up @@ -133,35 +134,58 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
return nothing
end

function KI.KIKernel(::MetalBackend, f, args...; kwargs...)
kern = eval(quote
@metal launch=false $(kwargs...) $(f)($(args...))
end)
KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
end

function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups=nothing, workgroupsize=nothing)
threadsPerThreadgroup = isnothing(workgroupsize) ? 1 : workgroupsize
threadgroupsPerGrid = isnothing(numworkgroups) ? 1 : numworkgroups

obj.kern(args...; threads=threadsPerThreadgroup, groups=threadgroupsPerGrid)
end


function KI.kernel_max_work_group_size(::B, kikern::KI.KIKernel{B}; max_work_items::Int=typemax(Int)) where B<:MetalBackend
min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items)
end
function KI.max_work_group_size(::MetalBackend)
device().maxThreadsPerThreadgroup.width
end
function KI.multiprocessor_count(::MetalBackend)
Metal.num_gpu_cores()
end



## indexing

## COV_EXCL_START
@device_override @inline function KA.__index_Local_Linear(ctx)
return thread_position_in_threadgroup().x
@device_override @inline function KI.get_local_id()
return (; x = Int(thread_position_in_threadgroup().x), y = Int(thread_position_in_threadgroup().y), z = Int(thread_position_in_threadgroup().z))
end

@device_override @inline function KA.__index_Group_Linear(ctx)
return threadgroup_position_in_grid().x
@device_override @inline function KI.get_group_id()
return (; x = Int(threadgroup_position_in_grid().x), y = Int(threadgroup_position_in_grid().y), z = Int(threadgroup_position_in_grid().z))
end

@device_override @inline function KA.__index_Global_Linear(ctx)
I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x, thread_position_in_threadgroup().x)
# TODO: This is unfortunate, can we get the linear index cheaper
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
@device_override @inline function KI.get_global_id()
return (; x = Int(thread_position_in_grid().x), y = Int(thread_position_in_grid().y), z = Int(thread_position_in_grid().z))
end

@device_override @inline function KA.__index_Local_Cartesian(ctx)
@inbounds KA.workitems(KA.__iterspace(ctx))[thread_position_in_threadgroup().x]
@device_override @inline function KI.get_local_size()
return (; x = Int(threads_per_threadgroup().x), y = Int(threads_per_threadgroup().y), z = Int(threads_per_threadgroup().z))
end

@device_override @inline function KA.__index_Group_Cartesian(ctx)
@inbounds KA.blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid().x]
@device_override @inline function KI.get_num_groups()
return (; x = Int(threadgroups_per_grid().x), y = Int(threadgroups_per_grid().y), z = Int(threadgroups_per_grid().z))
end

@device_override @inline function KA.__index_Global_Cartesian(ctx)
return @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x,
thread_position_in_threadgroup().x)
@device_override @inline function KI.get_global_size()
return (; x = Int(threads_per_grid().x), y = Int(threads_per_grid().y), z = Int(threads_per_grid().z))
end

@device_override @inline function KA.__validindex(ctx)
Expand All @@ -177,8 +201,7 @@ end

## shared memory

@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims},
::Val{Id}) where {T, Dims, Id}
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
ptr = Metal.emit_threadgroup_memory(T, Val(prod(Dims)))
MtlDeviceArray(Dims, ptr)
end
Expand All @@ -190,7 +213,7 @@ end

## other

@device_override @inline function KA.__synchronize()
@device_override @inline function KI.barrier()
threadgroup_barrier(Metal.MemoryFlagDevice | Metal.MemoryFlagThreadGroup)
end

Expand Down
22 changes: 11 additions & 11 deletions src/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray,
Rdim, Rpre, Rpost, Rother, neutral, init,
::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive}
threads = threads_per_threadgroup_3d().x
thread = thread_position_in_threadgroup_3d().x
threads = KI.get_local_size().x
thread = KI.get_local_id().x

temp = MtlThreadGroupArray(T, (Int32(2) * maxthreads,))
temp = KI.localmemory(T, (Int32(2) * maxthreads,))

i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
i = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x
j = (KI.get_group_id().z - Int32(1)) * KI.get_num_groups().y + KI.get_group_id().y

if j > length(Rother)
return
Expand All @@ -29,7 +29,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
offset = one(thread)
d = threads >> 0x1
while d > zero(d)
threadgroup_barrier(MemoryFlagThreadGroup)
KI.barrier()
@inbounds if thread <= d
ai = offset * (thread << 0x1 - 0x1)
bi = offset * (thread << 0x1)
Expand All @@ -46,7 +46,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
d = one(thread)
while d < threads
offset >>= 0x1
threadgroup_barrier(MemoryFlagThreadGroup)
KI.barrier()
@inbounds if thread <= d
ai = offset * (thread << 0x1 - 0x1)
bi = offset * (thread << 0x1)
Expand All @@ -58,7 +58,7 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
d <<= 0x1
end

threadgroup_barrier(MemoryFlagThreadGroup)
KI.barrier()

@inbounds if i <= length(Rdim)
val = if inclusive
Expand All @@ -76,10 +76,10 @@ function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArr
end

function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates::AbstractArray, Rdim, Rpre, Rpost, Rother, init)
block = threadgroup_position_in_grid_3d().x
block = KI.get_group_id().x

i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x
j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y
i = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x
j = (KI.get_group_id().z - Int32(1)) * KI.get_num_groups().y + KI.get_group_id().y

@inbounds if i <= length(Rdim) && j <= length(Rother)
I = Rother[j]
Expand Down
12 changes: 6 additions & 6 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ end
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
function broadcast_cartesian_static(dest, bc, Is)
i = thread_position_in_grid().x
stride = threads_per_grid().x
i = KI.get_global_id().x
stride = KI.get_global_size().x
while 1 <= i <= length(dest)
I = @inbounds Is[i]
@inbounds dest[I] = bc[I]
Expand All @@ -91,8 +91,8 @@ end
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
function broadcast_linear(dest, bc)
i = thread_position_in_grid().x
stride = threads_per_grid().x
i = KI.get_global_id().x
stride = KI.get_global_size().x
while 1 <= i <= length(dest)
@inbounds dest[i] = bc[i]
i += stride
Expand Down Expand Up @@ -150,8 +150,8 @@ end
else
## COV_EXCL_START
function broadcast_cartesian(dest, bc)
i = thread_position_in_grid().x
stride = threads_per_grid().x
i = KI.get_global_id().x
stride = KI.get_global_size().x
while 1 <= i <= length(dest)
I = @inbounds CartesianIndices(dest)[i]
@inbounds dest[I] = bc[I]
Expand Down
6 changes: 3 additions & 3 deletions src/device/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ end
elseif field === :ctr1
@inbounds global_random_counters()[simdgroupId]
elseif field === :ctr2
globalId = thread_position_in_grid().x +
(thread_position_in_grid().y - 1i32) * threads_per_grid().x +
(thread_position_in_grid().z - 1i32) * threads_per_grid().x * threads_per_grid().y
globalId = KI.get_global_id().x +
(KI.get_global_id().y - 1i32) * KI.get_global_size().x +
(KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
globalId % UInt32
end::UInt32
end
Expand Down
2 changes: 1 addition & 1 deletion src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function Base.findall(bools::WrappedMtlArray{Bool})

if n > 0
function kernel(ys::MtlDeviceArray, bools, indices)
i = (threadgroup_position_in_grid().x - Int32(1)) * threads_per_threadgroup().x + thread_position_in_threadgroup().x
i = (KI.get_group_id().x - Int32(1)) * KI.get_local_size().x + KI.get_local_id().x

@inbounds if i <= length(bools) && bools[i]
i′ = CartesianIndices(bools)[i]
Expand Down
Loading
Loading