Skip to content

Commit 5c6b8ca

Browse files
committed
Add KernelIntrinsics support
1 parent 89ad72f commit 5c6b8ca

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ SPIRVIntrinsics = {path = "lib/intrinsics"}
2828
Adapt = "4"
2929
GPUArrays = "11.2.1"
3030
GPUCompiler = "1.7.1"
31-
KernelAbstractions = "0.9.38"
31+
KernelAbstractions = "0.10"
3232
LLVM = "9.1"
3333
LinearAlgebra = "1"
3434
OpenCL_jll = "=2024.10.24"

src/OpenCLKernels.jl

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module OpenCLKernels
22

33
using ..OpenCL
4-
using ..OpenCL: @device_override, method_table
4+
using ..OpenCL: @device_override, method_table, kernel_convert, clfunction
55

66
import KernelAbstractions as KA
7+
import KernelAbstractions.KernelIntrinsics as KI
78

89
import StaticArrays
910

@@ -126,33 +127,61 @@ function (obj::KA.Kernel{OpenCLBackend})(args...; ndrange=nothing, workgroupsize
126127
return nothing
127128
end
128129

130+
# KernelIntrinsics argument-conversion hook for the OpenCL backend.
# Forwards `arg` unchanged to `kernel_convert` (brought into scope from the parent
# `OpenCL` module by the `using ..OpenCL: ...` line) and returns its result.
KI.argconvert(::OpenCLBackend, arg) = kernel_convert(arg)
131+
132+
# Build a KernelIntrinsics kernel object for the OpenCL backend.
#
# Arguments:
#   - `f`:  the function to turn into a kernel.
#   - `tt`: tuple type passed on to `clfunction` (defaults to `Tuple{}`).
# Keywords:
#   - `name`: forwarded to `clfunction` (defaults to `nothing`).
#   - any remaining keywords are forwarded to `clfunction` untouched.
#
# The value returned by `clfunction` is wrapped in a `KI.KIKernel` parameterized on
# `OpenCLBackend` and on the concrete type of that value; the wrapper is returned.
function KI.kernel_function(::OpenCLBackend, f::F, tt::TT=Tuple{}; name = nothing, kwargs...) where {F,TT}
133+
kern = clfunction(f, tt; name, kwargs...)
134+
KI.KIKernel{OpenCLBackend, typeof(kern)}(OpenCLBackend(), kern)
135+
end
136+
137+
# Launch a KernelIntrinsics kernel on the OpenCL backend.
#
# Keywords:
#   - `numworkgroups`: number of workgroups per dimension (defaults to 1).
#   - `workgroupsize`: work-items per workgroup per dimension (defaults to 1).
#
# Steps:
#   1. The two launch keywords are validated with `KI.check_launch_args`.
#   2. `workgroupsize` is splatted and padded with trailing 1s to three entries,
#      giving `local_size`; `numworkgroups` is padded the same way.
#   3. `global_size` is the elementwise product of `local_size` and the padded
#      `numworkgroups` (presumably because the wrapped kernel expects the total
#      work-item count per dimension rather than a group count; confirm against
#      the `clfunction` kernel's launch API).
#   4. The wrapped kernel `obj.kern` is called with `args` and both sizes.
#
# Always returns `nothing`.
function (obj::KI.KIKernel{OpenCLBackend})(args...; numworkgroups = 1, workgroupsize = 1)
138+
KI.check_launch_args(numworkgroups, workgroupsize)
139+
140+
local_size = (workgroupsize..., ntuple(_ -> 1, 3 - length(workgroupsize))...)
141+
142+
numworkgroups = (numworkgroups..., ntuple(_ -> 1, 3 - length(numworkgroups))...)
143+
global_size = local_size .* numworkgroups
144+
145+
obj.kern(args...; local_size, global_size)
146+
return nothing
147+
end
148+
149+
150+
# Return the maximum workgroup size usable for `kernel`, as an `Int`.
#
# Queries `cl.work_group_info` for the wrapped kernel's `fun` on the device returned
# by `cl.device()`, then returns the smaller of that info's `size` field and the
# `max_work_items` keyword (which defaults to `typemax(Int)`, i.e. no extra cap).
function KI.kernel_max_work_group_size(kernel::KI.KIKernel{<:OpenCLBackend}; max_work_items::Int=typemax(Int))::Int
151+
wginfo = cl.work_group_info(kernel.kern.fun, cl.device())
152+
Int(min(wginfo.size, max_work_items))
153+
end
154+
# Return the `max_work_group_size` property of the device returned by `cl.device()`,
# converted to `Int`. This is independent of any particular kernel.
function KI.max_work_group_size(::OpenCLBackend)::Int
155+
Int(cl.device().max_work_group_size)
156+
end
157+
# Return the `max_compute_units` property of the device returned by `cl.device()`,
# converted to `Int`; this is the value the OpenCL backend reports as its
# multiprocessor count.
function KI.multiprocessor_count(::OpenCLBackend)::Int
158+
Int(cl.device().max_compute_units)
159+
end
129160

130161
## Indexing Functions
131162

132-
@device_override @inline function KA.__index_Local_Linear(ctx)
133-
return get_local_id(1)
163+
# Override of `KI.get_local_id` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_local_id(1)`,
# `get_local_id(2)` and `get_local_id(3)`, each converted to `Int`.
@device_override @inline function KI.get_local_id()
164+
return (; x = Int(get_local_id(1)), y = Int(get_local_id(2)), z = Int(get_local_id(3)))
134165
end
135166

136-
@device_override @inline function KA.__index_Group_Linear(ctx)
137-
return get_group_id(1)
167+
# Override of `KI.get_group_id` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_group_id(1)`,
# `get_group_id(2)` and `get_group_id(3)`, each converted to `Int`.
@device_override @inline function KI.get_group_id()
168+
return (; x = Int(get_group_id(1)), y = Int(get_group_id(2)), z = Int(get_group_id(3)))
138169
end
139170

140-
@device_override @inline function KA.__index_Global_Linear(ctx)
141-
#return get_global_id(1) # JuliaGPU/OpenCL.jl#346
142-
I = KA.__index_Global_Cartesian(ctx)
143-
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
171+
# Override of `KI.get_global_id` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_global_id(1)`,
# `get_global_id(2)` and `get_global_id(3)`, each converted to `Int`.
@device_override @inline function KI.get_global_id()
172+
return (; x = Int(get_global_id(1)), y = Int(get_global_id(2)), z = Int(get_global_id(3)))
144173
end
145174

146-
@device_override @inline function KA.__index_Local_Cartesian(ctx)
147-
@inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id(1)]
175+
# Override of `KI.get_local_size` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_local_size(1)`,
# `get_local_size(2)` and `get_local_size(3)`, each converted to `Int`.
@device_override @inline function KI.get_local_size()
176+
return (; x = Int(get_local_size(1)), y = Int(get_local_size(2)), z = Int(get_local_size(3)))
148177
end
149178

150-
@device_override @inline function KA.__index_Group_Cartesian(ctx)
151-
@inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id(1)]
179+
# Override of `KI.get_num_groups` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_num_groups(1)`,
# `get_num_groups(2)` and `get_num_groups(3)`, each converted to `Int`.
@device_override @inline function KI.get_num_groups()
180+
return (; x = Int(get_num_groups(1)), y = Int(get_num_groups(2)), z = Int(get_num_groups(3)))
152181
end
153182

154-
@device_override @inline function KA.__index_Global_Cartesian(ctx)
155-
return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(1), get_local_id(1))
183+
# Override of `KI.get_global_size` registered through `@device_override`.
# Returns a named tuple `(; x, y, z)` whose fields are `get_global_size(1)`,
# `get_global_size(2)` and `get_global_size(3)`, each converted to `Int`.
@device_override @inline function KI.get_global_size()
184+
return (; x = Int(get_global_size(1)), y = Int(get_global_size(2)), z = Int(get_global_size(3)))
156185
end
157186

158187
@device_override @inline function KA.__validindex(ctx)
@@ -167,7 +196,7 @@ end
167196

168197
## Shared and Scratch Memory
169198

170-
@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
199+
# Override of `KI.localmemory` registered through `@device_override`.
# Requests a pointer from `OpenCL.emit_localmemory` for `prod(Dims)` elements of
# type `T`, then wraps that pointer in a `CLDeviceArray` with dimensions `Dims`
# and returns the array.
@device_override @inline function KI.localmemory(::Type{T}, ::Val{Dims}) where {T, Dims}
171200
ptr = OpenCL.emit_localmemory(T, Val(prod(Dims)))
172201
CLDeviceArray(Dims, ptr)
173202
end
@@ -179,11 +208,11 @@ end
179208

180209
## Synchronization and Printing
181210

182-
@device_override @inline function KA.__synchronize()
211+
# Override of `KI.barrier` registered through `@device_override`.
# Calls `work_group_barrier` with the `OpenCL.LOCAL_MEM_FENCE` and
# `OpenCL.GLOBAL_MEM_FENCE` flags combined via bitwise OR.
@device_override @inline function KI.barrier()
183212
work_group_barrier(OpenCL.LOCAL_MEM_FENCE | OpenCL.GLOBAL_MEM_FENCE)
184213
end
185214

186-
@device_override @inline function KA.__print(args...)
215+
# Override of `KI._print` registered through `@device_override`.
# Forwards all arguments unchanged to `OpenCL._print`.
@device_override @inline function KI._print(args...)
187216
OpenCL._print(args...)
188217
end
189218

0 commit comments

Comments
 (0)