Skip to content

Commit a5aa1b6

Browse files
committed
Add KernelIntrinsics support
1 parent 89ad72f commit a5aa1b6

File tree

2 files changed

+50
-20
lines changed

2 files changed

+50
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ SPIRVIntrinsics = {path = "lib/intrinsics"}
2828
Adapt = "4"
2929
GPUArrays = "11.2.1"
3030
GPUCompiler = "1.7.1"
31-
KernelAbstractions = "0.9.38"
31+
KernelAbstractions = "0.10"
3232
LLVM = "9.1"
3333
LinearAlgebra = "1"
3434
OpenCL_jll = "=2024.10.24"

src/OpenCLKernels.jl

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module OpenCLKernels
22

33
using ..OpenCL
4-
using ..OpenCL: @device_override, method_table
4+
using ..OpenCL: @device_override, method_table, kernel_convert, clfunction
55

66
import KernelAbstractions as KA
7+
import KernelAbstractions.KernelIntrinsics as KI
78

89
import StaticArrays
910

@@ -126,33 +127,62 @@ function (obj::KA.Kernel{OpenCLBackend})(args...; ndrange=nothing, workgroupsize
126127
return nothing
127128
end
128129

130+
KI.argconvert(::OpenCLBackend, arg) = kernel_convert(arg)
131+
132+
function KI.kernel_function(::OpenCLBackend, f::F, tt::TT=Tuple{}; name = nothing, kwargs...) where {F,TT}
133+
kern = clfunction(f, tt; name, kwargs...)
134+
KI.Kernel{OpenCLBackend, typeof(kern)}(OpenCLBackend(), kern)
135+
end
136+
137+
function (obj::KI.Kernel{OpenCLBackend})(args...; numworkgroups = 1, workgroupsize = 1)
138+
KI.check_launch_args(numworkgroups, workgroupsize)
139+
140+
local_size = (workgroupsize..., ntuple(_ -> 1, 3 - length(workgroupsize))...)
141+
142+
numworkgroups = (numworkgroups..., ntuple(_ -> 1, 3 - length(numworkgroups))...)
143+
global_size = local_size .* numworkgroups
144+
145+
obj.kern(args...; local_size, global_size)
146+
return nothing
147+
end
148+
149+
150+
function KI.kernel_max_work_group_size(kernel::KI.Kernel{<:OpenCLBackend}; max_work_items::Int=typemax(Int))::Int
151+
wginfo = cl.work_group_info(kernel.kern.fun, cl.device())
152+
Int(min(wginfo.size, max_work_items))
153+
end
154+
function KI.max_work_group_size(::OpenCLBackend)::Int
155+
Int(cl.device().max_work_group_size)
156+
end
157+
function KI.multiprocessor_count(::OpenCLBackend)::Int
158+
Int(cl.device().max_compute_units)
159+
end
129160

130161
## Indexing Functions
162+
## COV_EXCL_START
131163

132-
@device_override @inline function KA.__index_Local_Linear(ctx)
133-
return get_local_id(1)
164+
@device_override @inline function KI.get_local_id()
165+
return (; x = Int(get_local_id(1)), y = Int(get_local_id(2)), z = Int(get_local_id(3)))
134166
end
135167

136-
@device_override @inline function KA.__index_Group_Linear(ctx)
137-
return get_group_id(1)
168+
@device_override @inline function KI.get_group_id()
169+
return (; x = Int(get_group_id(1)), y = Int(get_group_id(2)), z = Int(get_group_id(3)))
138170
end
139171

140-
@device_override @inline function KA.__index_Global_Linear(ctx)
141-
#return get_global_id(1) # JuliaGPU/OpenCL.jl#346
142-
I = KA.__index_Global_Cartesian(ctx)
143-
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
172+
@device_override @inline function KI.get_global_id()
173+
return (; x = Int(get_global_id(1)), y = Int(get_global_id(2)), z = Int(get_global_id(3)))
144174
end
145175

146-
@device_override @inline function KA.__index_Local_Cartesian(ctx)
147-
@inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id(1)]
176+
@device_override @inline function KI.get_local_size()
177+
return (; x = Int(get_local_size(1)), y = Int(get_local_size(2)), z = Int(get_local_size(3)))
148178
end
149179

150-
@device_override @inline function KA.__index_Group_Cartesian(ctx)
151-
@inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id(1)]
180+
@device_override @inline function KI.get_num_groups()
181+
return (; x = Int(get_num_groups(1)), y = Int(get_num_groups(2)), z = Int(get_num_groups(3)))
152182
end
153183

154-
@device_override @inline function KA.__index_Global_Cartesian(ctx)
155-
return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(1), get_local_id(1))
184+
@device_override @inline function KI.get_global_size()
185+
return (; x = Int(get_global_size(1)), y = Int(get_global_size(2)), z = Int(get_global_size(3)))
156186
end
157187

158188
@device_override @inline function KA.__validindex(ctx)
@@ -167,7 +197,7 @@ end
167197

168198
## Shared and Scratch Memory
169199

170-
@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
200+
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
171201
ptr = OpenCL.emit_localmemory(T, Val(prod(Dims)))
172202
CLDeviceArray(Dims, ptr)
173203
end
@@ -179,14 +209,14 @@ end
179209

180210
## Synchronization and Printing
181211

182-
@device_override @inline function KA.__synchronize()
212+
@device_override @inline function KI.barrier()
183213
work_group_barrier(OpenCL.LOCAL_MEM_FENCE | OpenCL.GLOBAL_MEM_FENCE)
184214
end
185215

186-
@device_override @inline function KA.__print(args...)
216+
@device_override @inline function KI._print(args...)
187217
OpenCL._print(args...)
188218
end
189-
219+
## COV_EXCL_STOP
190220

191221
## Other
192222

0 commit comments

Comments
 (0)