@@ -3,11 +3,13 @@ module ROCKernels
33export ROCBackend
44
55import AMDGPU
6+ import AMDGPU: rocconvert, hipfunction
67import AMDGPU. Device: @device_override
7- using AMDGPU: GPUArrays, rocSPARSE
8+ using AMDGPU: GPUArrays, rocSPARSE, HIP
89
910import Adapt
1011import KernelAbstractions as KA
12+ import KernelAbstractions. KernelIntrinsics as KI
1113import LLVM
1214
1315using StaticArraysCore: MArray
@@ -127,32 +129,60 @@ function KA.mkcontext(kernel::KA.Kernel{ROCBackend}, I, _ndrange, iterspace, ::D
127129 metadata = KA. CompilerMetadata {KA.ndrange(kernel), Dynamic} (I, _ndrange, iterspace)
128130end
129131
# Convert a host-side kernel argument into its device-side representation
# (e.g. ROCArray -> ROCDeviceArray) using AMDGPU's standard conversion.
KI.argconvert(::ROCBackend, arg) = rocconvert(arg)
133+
"""
    KI.kernel_function(::ROCBackend, f, tt=Tuple{}; name=nothing, kwargs...)

Compile `f` for argument types `tt` into a HIP kernel via `hipfunction` and
wrap the compiled kernel in a `KI.Kernel` tagged with `ROCBackend`.
Extra `kwargs` are forwarded to the compiler.
"""
function KI.kernel_function(::ROCBackend, f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F, TT}
    kern = hipfunction(f, tt; name, kwargs...)
    return KI.Kernel{ROCBackend, typeof(kern)}(ROCBackend(), kern)
end
138+
"""
    (obj::KI.Kernel{ROCBackend})(args...; numworkgroups=1, workgroupsize=1)

Launch the compiled HIP kernel. `workgroupsize` is the per-workgroup item
count and `numworkgroups` the number of workgroups; both default to 1.
Returns `nothing`.
"""
function (obj::KI.Kernel{ROCBackend})(args...; numworkgroups=1, workgroupsize=1)
    KI.check_launch_args(numworkgroups, workgroupsize)

    groupsize = workgroupsize
    # AMDGPU's launch API takes the *total* grid size in work-items,
    # not the workgroup count, hence the elementwise multiply.
    gridsize = groupsize .* numworkgroups

    obj.kern(args...; groupsize, gridsize)
    return nothing
end
131148
"""
    KI.kernel_max_work_group_size(kikern; max_work_items=typemax(Int)) -> Int

Query the occupancy-derived maximum workgroup size for this specific
compiled kernel, capped at `max_work_items`.
"""
function KI.kernel_max_work_group_size(kikern::KI.Kernel{<:ROCBackend}; max_work_items::Int=typemax(Int))::Int
    (; groupsize) = AMDGPU.launch_configuration(kikern.kern; max_block_size=max_work_items)
    return Int(groupsize)
end
# Device-wide upper bound on workgroup size (hipDeviceAttributeMaxThreadsPerBlock)
# for the current HIP device; independent of any particular kernel.
function KI.max_work_group_size(::ROCBackend)::Int
    return Int(HIP.attribute(HIP.device(), HIP.hipDeviceAttributeMaxThreadsPerBlock))
end
# Number of compute units (multiprocessors) on the current HIP device.
function KI.multiprocessor_count(::ROCBackend)::Int
    return Int(HIP.attribute(HIP.device(), HIP.hipDeviceAttributeMultiprocessorCount))
end
161+
# Indexing.
## COV_EXCL_START

# 1-based work-item index within the workgroup, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_local_id()
    return (;
        x = Int(AMDGPU.Device.workitemIdx().x),
        y = Int(AMDGPU.Device.workitemIdx().y),
        z = Int(AMDGPU.Device.workitemIdx().z),
    )
end
135167
# 1-based workgroup index within the grid, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_group_id()
    return (;
        x = Int(AMDGPU.Device.workgroupIdx().x),
        y = Int(AMDGPU.Device.workgroupIdx().y),
        z = Int(AMDGPU.Device.workgroupIdx().z),
    )
end
139171
# 1-based global work-item index per dimension:
# (group - 1) * group_size + local_id, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_global_id()
    return (;
        x = Int((AMDGPU.Device.workgroupIdx().x - 1) * AMDGPU.Device.blockDim().x + AMDGPU.Device.workitemIdx().x),
        y = Int((AMDGPU.Device.workgroupIdx().y - 1) * AMDGPU.Device.blockDim().y + AMDGPU.Device.workitemIdx().y),
        z = Int((AMDGPU.Device.workgroupIdx().z - 1) * AMDGPU.Device.blockDim().z + AMDGPU.Device.workitemIdx().z),
    )
end
145175
# Work-items per workgroup in each dimension, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_local_size()
    return (;
        x = Int(AMDGPU.Device.workgroupDim().x),
        y = Int(AMDGPU.Device.workgroupDim().y),
        z = Int(AMDGPU.Device.workgroupDim().z),
    )
end
149179
# Number of workgroups in the grid per dimension, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_num_groups()
    return (;
        x = Int(AMDGPU.Device.gridGroupDim().x),
        y = Int(AMDGPU.Device.gridGroupDim().y),
        z = Int(AMDGPU.Device.gridGroupDim().z),
    )
end
153183
# Total work-items in the grid per dimension, as a (x, y, z) named tuple of Int.
@device_override @inline function KI.get_global_size()
    return (;
        x = Int(AMDGPU.Device.gridItemDim().x),
        y = Int(AMDGPU.Device.gridItemDim().y),
        z = Int(AMDGPU.Device.gridItemDim().z),
    )
end
157187
158188@device_override @inline function KA. __validindex (ctx)
166196
167197# Shared memory.
168198
# Statically-sized workgroup-local (LDS) memory, exposed as a ROCDeviceArray.
# Allocates prod(Dims) elements of T in the Local address space under the
# fixed identifier :shmem (the KI API has no per-allocation id).
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
    ptr = AMDGPU.Device.alloc_special(Val(:shmem), T, Val(AMDGPU.AS.Local), Val(prod(Dims)))
    return AMDGPU.ROCDeviceArray(Dims, ptr)
end
173203
@@ -177,12 +207,13 @@ end
177207
178208# Other.
179209
# Workgroup-wide execution barrier: all work-items in the workgroup must
# reach this point before any proceeds.
@device_override @inline function KI.barrier()
    AMDGPU.Device.sync_workgroup()
end
183213
# Device-side printing is not yet implemented for the ROCm backend;
# deliberately a silent no-op rather than an error so generic kernels
# that print on other backends still run.
@device_override @inline function KI._print(args...)
    # TODO: implement via AMDGPU's device printf support
end
## COV_EXCL_STOP
187218
188219end
0 commit comments