@@ -4,6 +4,7 @@ using ..Metal
44using .. Metal: @device_override , DefaultStorageMode, SharedStorage
55
66import KernelAbstractions as KA
7+ import KernelAbstractions: KernelIntrinsics as KI
78
89using StaticArrays: MArray
910
@@ -133,35 +134,58 @@ function (obj::KA.Kernel{MetalBackend})(args...; ndrange=nothing, workgroupsize=
133134 return nothing
134135end
135136
137+ function KI. KIKernel (:: MetalBackend , f, args... ; kwargs... )
138+ kern = eval (quote
139+ @metal launch= false $ (kwargs... ) $ (f)($ (args... ))
140+ end )
141+ KI. KIKernel {MetalBackend, typeof(kern)} (MetalBackend (), kern)
142+ end
143+
144+ function (obj:: KI.KIKernel{MetalBackend} )(args... ; numworkgroups= nothing , workgroupsize= nothing )
145+ threadsPerThreadgroup = isnothing (workgroupsize) ? 1 : workgroupsize
146+ threadgroupsPerGrid = isnothing (numworkgroups) ? 1 : numworkgroups
147+
148+ obj. kern (args... ; threads= threadsPerThreadgroup, groups= threadgroupsPerGrid)
149+ end
150+
151+
152+ function KI. kernel_max_work_group_size (:: B , kikern:: KI.KIKernel{B} ; max_work_items:: Int = typemax (Int)) where B<: MetalBackend
153+ min (kikern. kern. pipeline. maxTotalThreadsPerThreadgroup, max_work_items)
154+ end
155+ function KI. max_work_group_size (:: MetalBackend )
156+ device (). maxThreadsPerThreadgroup. width
157+ end
158+ function KI. multiprocessor_count (:: MetalBackend )
159+ Metal. num_gpu_cores ()
160+ end
161+
162+
136163
137164# # indexing
138165
139166# # COV_EXCL_START
140- @device_override @inline function KA . __index_Local_Linear (ctx )
141- return thread_position_in_threadgroup (). x
167+ @device_override @inline function KI . get_local_id ( )
168+ return (; x = Int ( thread_position_in_threadgroup (). x), y = Int ( thread_position_in_threadgroup () . y), z = Int ( thread_position_in_threadgroup () . z))
142169end
143170
144- @device_override @inline function KA . __index_Group_Linear (ctx )
145- return threadgroup_position_in_grid (). x
171+ @device_override @inline function KI . get_group_id ( )
172+ return (; x = Int ( threadgroup_position_in_grid (). x), y = Int ( threadgroup_position_in_grid () . y), z = Int ( threadgroup_position_in_grid () . z))
146173end
147174
148- @device_override @inline function KA. __index_Global_Linear (ctx)
149- I = @inbounds KA. expand (KA. __iterspace (ctx), threadgroup_position_in_grid (). x, thread_position_in_threadgroup (). x)
150- # TODO : This is unfortunate, can we get the linear index cheaper
151- @inbounds LinearIndices (KA. __ndrange (ctx))[I]
175+ @device_override @inline function KI. get_global_id ()
176+ return (; x = Int (thread_position_in_grid (). x), y = Int (thread_position_in_grid (). y), z = Int (thread_position_in_grid (). z))
152177end
153178
154- @device_override @inline function KA . __index_Local_Cartesian (ctx )
155- @inbounds KA . workitems (KA . __iterspace (ctx))[ thread_position_in_threadgroup () . x]
179+ @device_override @inline function KI . get_local_size ( )
180+ return (; x = Int ( threads_per_threadgroup () . x), y = Int ( threads_per_threadgroup () . y), z = Int ( threads_per_threadgroup () . z))
156181end
157182
158- @device_override @inline function KA . __index_Group_Cartesian (ctx )
159- @inbounds KA . blocks (KA . __iterspace (ctx))[ threadgroup_position_in_grid () . x]
183+ @device_override @inline function KI . get_num_groups ( )
184+ return (; x = Int ( threadgroups_per_grid () . x), y = Int ( threadgroups_per_grid () . y), z = Int ( threadgroups_per_grid () . z))
160185end
161186
162- @device_override @inline function KA. __index_Global_Cartesian (ctx)
163- return @inbounds KA. expand (KA. __iterspace (ctx), threadgroup_position_in_grid (). x,
164- thread_position_in_threadgroup (). x)
187+ @device_override @inline function KI. get_global_size ()
188+ return (; x = Int (threads_per_grid (). x), y = Int (threads_per_grid (). y), z = Int (threads_per_grid (). z))
165189end
166190
167191@device_override @inline function KA. __validindex (ctx)
177201
178202# # shared memory
179203
180- @device_override @inline function KA. SharedMemory (:: Type{T} , :: Val{Dims} ,
181- :: Val{Id} ) where {T, Dims, Id}
204+ @device_override @inline function KI. localmemory (:: Type{T} , :: Val{Dims} ) where {T, Dims}
182205 ptr = Metal. emit_threadgroup_memory (T, Val (prod (Dims)))
183206 MtlDeviceArray (Dims, ptr)
184207end
190213
191214# # other
192215
193- @device_override @inline function KA . __synchronize ()
216+ @device_override @inline function KI . barrier ()
194217 threadgroup_barrier (Metal. MemoryFlagDevice | Metal. MemoryFlagThreadGroup)
195218end
196219
0 commit comments