11module OpenCLKernels
22
33using .. OpenCL
4- using .. OpenCL: @device_override , method_table
4+ using .. OpenCL: @device_override , method_table, kernel_convert, clfunction
55
66import KernelAbstractions as KA
7+ import KernelAbstractions. KernelIntrinsics as KI
78
89import StaticArrays
910
@@ -126,33 +127,61 @@ function (obj::KA.Kernel{OpenCLBackend})(args...; ndrange=nothing, workgroupsize
126127 return nothing
127128end
128129
130+ KI. argconvert (:: OpenCLBackend , arg) = kernel_convert (arg)
131+
132+ function KI. kernel_function (:: OpenCLBackend , f:: F , tt:: TT = Tuple{}; name = nothing , kwargs... ) where {F,TT}
133+ kern = clfunction (f, tt; name, kwargs... )
134+ KI. KIKernel {OpenCLBackend, typeof(kern)} (OpenCLBackend (), kern)
135+ end
136+
137+ function (obj:: KI.KIKernel{OpenCLBackend} )(args... ; numworkgroups = 1 , workgroupsize = 1 )
138+ KI. check_launch_args (numworkgroups, workgroupsize)
139+
140+ local_size = (workgroupsize... , ntuple (_ -> 1 , 3 - length (workgroupsize))... )
141+
142+ numworkgroups = (numworkgroups... , ntuple (_ -> 1 , 3 - length (numworkgroups))... )
143+ global_size = local_size .* numworkgroups
144+
145+ obj. kern (args... ; local_size, global_size)
146+ return nothing
147+ end
148+
149+
150+ function KI. kernel_max_work_group_size (kernel:: KI.KIKernel{<:OpenCLBackend} ; max_work_items:: Int = typemax (Int)):: Int
151+ wginfo = cl. work_group_info (kernel. kern. fun, cl. device ())
152+ Int (min (wginfo. size, max_work_items))
153+ end
154+ function KI. max_work_group_size (:: OpenCLBackend ):: Int
155+ Int (cl. device (). max_work_group_size)
156+ end
157+ function KI. multiprocessor_count (:: OpenCLBackend ):: Int
158+ Int (cl. device (). max_compute_units)
159+ end
129160
130161# # Indexing Functions
131162
132- @device_override @inline function KA . __index_Local_Linear (ctx )
133- return get_local_id (1 )
163+ @device_override @inline function KI . get_local_id ( )
164+ return (; x = Int ( get_local_id (1 )), y = Int ( get_local_id ( 2 )), z = Int ( get_local_id ( 3 )) )
134165end
135166
136- @device_override @inline function KA . __index_Group_Linear (ctx )
137- return get_group_id (1 )
167+ @device_override @inline function KI . get_group_id ( )
168+ return (; x = Int ( get_group_id (1 )), y = Int ( get_group_id ( 2 )), z = Int ( get_group_id ( 3 )) )
138169end
139170
140- @device_override @inline function KA. __index_Global_Linear (ctx)
141- # return get_global_id(1) # JuliaGPU/OpenCL.jl#346
142- I = KA. __index_Global_Cartesian (ctx)
143- @inbounds LinearIndices (KA. __ndrange (ctx))[I]
171+ @device_override @inline function KI. get_global_id ()
172+ return (; x = Int (get_global_id (1 )), y = Int (get_global_id (2 )), z = Int (get_global_id (3 )))
144173end
145174
146- @device_override @inline function KA . __index_Local_Cartesian (ctx )
147- @inbounds KA . workitems (KA . __iterspace (ctx))[ get_local_id ( 1 )]
175+ @device_override @inline function KI . get_local_size ( )
176+ return (; x = Int ( get_local_size ( 1 )), y = Int ( get_local_size ( 2 )), z = Int ( get_local_size ( 3 )))
148177end
149178
150- @device_override @inline function KA . __index_Group_Cartesian (ctx )
151- @inbounds KA . blocks (KA . __iterspace (ctx))[ get_group_id ( 1 )]
179+ @device_override @inline function KI . get_num_groups ( )
180+ return (; x = Int ( get_num_groups ( 1 )), y = Int ( get_num_groups ( 2 )), z = Int ( get_num_groups ( 3 )))
152181end
153182
154- @device_override @inline function KA . __index_Global_Cartesian (ctx )
155- return @inbounds KA . expand (KA . __iterspace (ctx), get_group_id ( 1 ), get_local_id ( 1 ))
183+ @device_override @inline function KI . get_global_size ( )
184+ return (; x = Int ( get_global_size ( 1 )), y = Int ( get_global_size ( 2 )), z = Int ( get_global_size ( 3 ) ))
156185end
157186
158187@device_override @inline function KA. __validindex (ctx)
167196
168197# # Shared and Scratch Memory
169198
170- @device_override @inline function KA . SharedMemory (:: Type{T} , :: Val{Dims} , :: Val{Id} ) where {T, Dims, Id }
199+ @device_override @inline function KI . localmemory (:: Type{T} , :: Val{Dims} ) where {T, Dims}
171200 ptr = OpenCL. emit_localmemory (T, Val (prod (Dims)))
172201 CLDeviceArray (Dims, ptr)
173202end
@@ -179,11 +208,11 @@ end
179208
180209# # Synchronization and Printing
181210
182- @device_override @inline function KA . __synchronize ()
211+ @device_override @inline function KI . barrier ()
183212 work_group_barrier (OpenCL. LOCAL_MEM_FENCE | OpenCL. GLOBAL_MEM_FENCE)
184213end
185214
186- @device_override @inline function KA . __print (args... )
215+ @device_override @inline function KI . _print (args... )
187216 OpenCL. _print (args... )
188217end
189218
0 commit comments