Skip to content

Commit 792451c

Browse files
committed
KernelIntrinsics
1 parent 7a00e02 commit 792451c

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ ExprTools = "0.1"
4242
GPUArrays = "11.2.1"
4343
GPUCompiler = "1.7.1"
4444
GPUToolbox = "0.1, 0.2, 0.3, 1"
45-
KernelAbstractions = "0.9.38"
45+
KernelAbstractions = "0.10"
4646
LLVM = "7.2, 8, 9"
4747
LLVMDowngrader_jll = "0.6"
4848
LinearAlgebra = "1"

src/MetalKernels.jl

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ..Metal
44
using ..Metal: @device_override, DefaultStorageMode, SharedStorage
55

66
import KernelAbstractions as KA
7+
import KernelAbstractions: KernelIntrinsics as KI
78

89
using StaticArrays: MArray
910

@@ -133,35 +134,58 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
133134
return nothing
134135
end
135136

137+
function KI.KIKernel(::MetalBackend, f, args...; kwargs...)
138+
kern = eval(quote
139+
@metal launch=false $(kwargs...) $(f)($(args...))
140+
end)
141+
KI.KIKernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
142+
end
143+
144+
function (obj::KI.KIKernel{MetalBackend})(args...; numworkgroups=nothing, workgroupsize=nothing)
145+
threadsPerThreadgroup = isnothing(workgroupsize) ? 1 : workgroupsize
146+
threadgroupsPerGrid = isnothing(numworkgroups) ? 1 : numworkgroups
147+
148+
obj.kern(args...; threads=threadsPerThreadgroup, groups=threadgroupsPerGrid)
149+
end
150+
151+
152+
function KI.kernel_max_work_group_size(::B, kikern::KI.KIKernel{B}; max_work_items::Int=typemax(Int)) where B<:MetalBackend
153+
min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items)
154+
end
155+
function KI.max_work_group_size(::MetalBackend)
156+
device().maxThreadsPerThreadgroup.width
157+
end
158+
function KI.multiprocessor_count(::MetalBackend)
159+
Metal.num_gpu_cores()
160+
end
161+
162+
136163

137164
## indexing
138165

139166
## COV_EXCL_START
140-
@device_override @inline function KA.__index_Local_Linear(ctx)
141-
return thread_position_in_threadgroup().x
167+
@device_override @inline function KI.get_local_id()
168+
return (; x = Int(thread_position_in_threadgroup().x), y = Int(thread_position_in_threadgroup().y), z = Int(thread_position_in_threadgroup().z))
142169
end
143170

144-
@device_override @inline function KA.__index_Group_Linear(ctx)
145-
return threadgroup_position_in_grid().x
171+
@device_override @inline function KI.get_group_id()
172+
return (; x = Int(threadgroup_position_in_grid().x), y = Int(threadgroup_position_in_grid().y), z = Int(threadgroup_position_in_grid().z))
146173
end
147174

148-
@device_override @inline function KA.__index_Global_Linear(ctx)
149-
I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x, thread_position_in_threadgroup().x)
150-
# TODO: This is unfortunate, can we get the linear index cheaper
151-
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
175+
@device_override @inline function KI.get_global_id()
176+
return (; x = Int(thread_position_in_grid().x), y = Int(thread_position_in_grid().y), z = Int(thread_position_in_grid().z))
152177
end
153178

154-
@device_override @inline function KA.__index_Local_Cartesian(ctx)
155-
@inbounds KA.workitems(KA.__iterspace(ctx))[thread_position_in_threadgroup().x]
179+
@device_override @inline function KI.get_local_size()
180+
return (; x = Int(threads_per_threadgroup().x), y = Int(threads_per_threadgroup().y), z = Int(threads_per_threadgroup().z))
156181
end
157182

158-
@device_override @inline function KA.__index_Group_Cartesian(ctx)
159-
@inbounds KA.blocks(KA.__iterspace(ctx))[threadgroup_position_in_grid().x]
183+
@device_override @inline function KI.get_num_groups()
184+
return (; x = Int(threadgroups_per_grid().x), y = Int(threadgroups_per_grid().y), z = Int(threadgroups_per_grid().z))
160185
end
161186

162-
@device_override @inline function KA.__index_Global_Cartesian(ctx)
163-
return @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid().x,
164-
thread_position_in_threadgroup().x)
187+
@device_override @inline function KI.get_global_size()
188+
return (; x = Int(threads_per_grid().x), y = Int(threads_per_grid().y), z = Int(threads_per_grid().z))
165189
end
166190

167191
@device_override @inline function KA.__validindex(ctx)
@@ -177,8 +201,7 @@ end
177201

178202
## shared memory
179203

180-
@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims},
181-
::Val{Id}) where {T, Dims, Id}
204+
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
182205
ptr = Metal.emit_threadgroup_memory(T, Val(prod(Dims)))
183206
MtlDeviceArray(Dims, ptr)
184207
end
@@ -190,7 +213,7 @@ end
190213

191214
## other
192215

193-
@device_override @inline function KA.__synchronize()
216+
@device_override @inline function KI.barrier()
194217
threadgroup_barrier(Metal.MemoryFlagDevice | Metal.MemoryFlagThreadGroup)
195218
end
196219

0 commit comments

Comments
 (0)