11module OpenCLKernels
22
33using .. OpenCL
4- using .. OpenCL: @device_override , method_table
4+ using .. OpenCL: @device_override , method_table, kernel_convert, clfunction
55
66import KernelAbstractions as KA
7+ import KernelAbstractions. KernelIntrinsics as KI
78
89import StaticArrays
910
@@ -126,33 +127,62 @@ function (obj::KA.Kernel{OpenCLBackend})(args...; ndrange=nothing, workgroupsize
126127 return nothing
127128end
128129
# Convert a launch argument for the OpenCL backend by delegating to `kernel_convert`.
function KI.argconvert(::OpenCLBackend, arg)
    return kernel_convert(arg)
end
131+
# Compile `f` for the argument tuple type `tt` through `clfunction` and wrap the
# result in a `KI.Kernel` parameterised on `OpenCLBackend` and the compiled
# kernel's type. `name` and all remaining keyword arguments are forwarded to
# `clfunction` unchanged.
function KI.kernel_function(
        ::OpenCLBackend, f::F, tt::TT = Tuple{}; name = nothing, kwargs...
    ) where {F, TT}
    compiled = clfunction(f, tt; name = name, kwargs...)
    return KI.Kernel{OpenCLBackend, typeof(compiled)}(OpenCLBackend(), compiled)
end
136+
# Launch the wrapped kernel `obj.kern` with `args`.
#
# `numworkgroups` and `workgroupsize` are first checked by `KI.check_launch_args`,
# then each is padded with trailing 1s to three entries. The wrapped kernel is
# called with `local_size` set to the padded work-group size and `global_size`
# set to its element-wise product with the padded work-group count.
# Always returns `nothing`.
function (obj::KI.Kernel{OpenCLBackend})(args...; numworkgroups = 1, workgroupsize = 1)
    KI.check_launch_args(numworkgroups, workgroupsize)

    # Extend a size specification to exactly three dimensions by appending 1s.
    pad3(dims) = (dims..., ntuple(_ -> 1, 3 - length(dims))...)

    items_per_group = pad3(workgroupsize)
    group_counts = pad3(numworkgroups)

    obj.kern(
        args...;
        local_size = items_per_group,
        global_size = items_per_group .* group_counts,
    )
    return nothing
end
148+
149+
# Return, as an `Int`, the smaller of `max_work_items` and the `size` field of
# `cl.work_group_info` queried for this kernel's function on `cl.device()`.
function KI.kernel_max_work_group_size(
        kernel::KI.Kernel{<:OpenCLBackend}; max_work_items::Int = typemax(Int)
    )::Int
    kernel_limit = cl.work_group_info(kernel.kern.fun, cl.device()).size
    return Int(min(kernel_limit, max_work_items))
end
# Return the `max_work_group_size` property of `cl.device()` as an `Int`.
function KI.max_work_group_size(::OpenCLBackend)::Int
    dev = cl.device()
    return Int(dev.max_work_group_size)
end
# Return the `max_compute_units` property of `cl.device()` as an `Int`.
function KI.multiprocessor_count(::OpenCLBackend)::Int
    dev = cl.device()
    return Int(dev.max_compute_units)
end
129160
130161# # Indexing Functions
162+ # # COV_EXCL_START
131163
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_local_id` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_local_id()
    x = Int(get_local_id(1))
    y = Int(get_local_id(2))
    z = Int(get_local_id(3))
    return (; x, y, z)
end
135167
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_group_id` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_group_id()
    x = Int(get_group_id(1))
    y = Int(get_group_id(2))
    z = Int(get_group_id(3))
    return (; x, y, z)
end
139171
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_global_id` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_global_id()
    x = Int(get_global_id(1))
    y = Int(get_global_id(2))
    z = Int(get_global_id(3))
    return (; x, y, z)
end
145175
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_local_size` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_local_size()
    x = Int(get_local_size(1))
    y = Int(get_local_size(2))
    z = Int(get_local_size(3))
    return (; x, y, z)
end
149179
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_num_groups` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_num_groups()
    x = Int(get_num_groups(1))
    y = Int(get_num_groups(2))
    z = Int(get_num_groups(3))
    return (; x, y, z)
end
153183
# Device-side override: return a named tuple `(x, y, z)` holding the result of
# `get_global_size` for dimensions 1, 2 and 3, each converted to `Int`.
@device_override @inline function KI.get_global_size()
    x = Int(get_global_size(1))
    y = Int(get_global_size(2))
    z = Int(get_global_size(3))
    return (; x, y, z)
end
157187
158188@device_override @inline function KA. __validindex (ctx)
167197
168198# # Shared and Scratch Memory
169199
# Device-side override: obtain a pointer from `OpenCL.emit_localmemory` for
# `prod(Dims)` elements of type `T` and wrap it in a `CLDeviceArray` of shape `Dims`.
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
    local_ptr = OpenCL.emit_localmemory(T, Val(prod(Dims)))
    return CLDeviceArray(Dims, local_ptr)
end
@@ -179,14 +209,14 @@ end
179209
180210# # Synchronization and Printing
181211
# Device-side override: call `work_group_barrier` with both the local and the
# global memory fence flags set.
@device_override @inline function KI.barrier()
    fence_flags = OpenCL.LOCAL_MEM_FENCE | OpenCL.GLOBAL_MEM_FENCE
    return work_group_barrier(fence_flags)
end
185215
# Device-side override: forward all arguments to `OpenCL._print`.
@device_override @inline function KI._print(args...)
    return OpenCL._print(args...)
end
189-
219+ # # COV_EXCL_STOP
190220
191221# # Other
192222
0 commit comments