
Commit 37c2c4b

Update KA API

1 parent a00fad6 commit 37c2c4b

File tree: 1 file changed

src/oneAPIKernels.jl: 94 additions & 6 deletions

@@ -15,18 +15,24 @@ import Adapt
 export oneAPIBackend

 struct oneAPIBackend <: KA.GPU
+    prefer_blocks::Bool
+    always_inline::Bool
 end

-KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
-KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
-KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
+oneAPIBackend(; prefer_blocks=false, always_inline=false) = oneAPIBackend(prefer_blocks, always_inline)
+
+@inline KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims)
+@inline KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), zero(T))
+@inline KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), one(T))

 KA.get_backend(::oneArray) = oneAPIBackend()
 # TODO should be non-blocking
-KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
+KA.synchronize(::oneAPIBackend) = oneAPI.oneL0.synchronize()
 KA.supports_float64(::oneAPIBackend) = false # TODO: Check if this is device dependent

-Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
+KA.functional(::oneAPIBackend) = oneAPI.functional()
+
+Adapt.adapt_storage(::oneAPIBackend, a::AbstractArray) = Adapt.adapt(oneArray, a)
 Adapt.adapt_storage(::oneAPIBackend, a::oneArray) = a
 Adapt.adapt_storage(::KA.CPU, a::oneArray) = convert(Array, a)

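Usage sketch for this hunk (not part of the commit): constructing the backend with the new keyword constructor and allocating through the new `unified` flag. The array shapes are arbitrary, and this assumes a functional oneAPI.jl installation.

    using oneAPI
    import KernelAbstractions as KA

    # Both new fields default to false via the keyword constructor.
    backend = oneAPIBackend(prefer_blocks=true, always_inline=true)

    # Device-local allocation (oneL0.DeviceBuffer backing):
    A = KA.allocate(backend, Float32, (1024, 1024))

    # Host-visible unified allocation (oneL0.SharedBuffer backing):
    B = KA.zeros(backend, Float32, (1024, 1024); unified=true)
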

@@ -39,6 +45,24 @@ function KA.copyto!(::oneAPIBackend, A, B)
 end


+## Device Operations
+
+function KA.ndevices(::oneAPIBackend)
+    return length(oneAPI.devices())
+end
+
+function KA.device(::oneAPIBackend)::Int
+    dev = oneAPI.device()
+    devs = oneAPI.devices()
+    idx = findfirst(==(dev), devs)
+    return idx === nothing ? 1 : idx
+end
+
+function KA.device!(backend::oneAPIBackend, id::Int)
+    oneAPI.device!(id)
+end
+
+
 ## Kernel Launch

 function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, _ndrange, iterspace)
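A minimal sketch exercising the three new device operations; device count and ordering depend on the machine, so the loop below is illustrative only.

    using oneAPI
    import KernelAbstractions as KA

    backend = oneAPIBackend()

    n = KA.ndevices(backend)      # number of visible Level Zero devices
    for id in 1:n
        KA.device!(backend, id)   # select device `id`
        @assert KA.device(backend) == id
    end
    KA.device!(backend, 1)        # restore the first device
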
@@ -83,14 +107,42 @@ function threads_to_workgroupsize(threads, ndrange)
 end

 function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    backend = KA.backend(obj)
+
     ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = KA.mkcontext(obj, ndrange, iterspace)
-    kernel = @oneapi launch=false obj.f(ctx, args...)
+
+    # If the kernel is statically sized, we can tell the compiler about that
+    if KA.workgroupsize(obj) <: KA.StaticSize
+        # TODO: maxthreads
+        # maxthreads = prod(KA.get(KA.workgroupsize(obj)))
+    else
+        # maxthreads = nothing
+    end
+
+    kernel = @oneapi launch=false always_inline=backend.always_inline obj.f(ctx, args...)

     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
         items = oneAPI.launch_configuration(kernel)
+
+        if backend.prefer_blocks
+            # Prefer blocks over threads: reducing the workgroup size (items)
+            # increases the number of workgroups (blocks). We use a simple
+            # heuristic here since launch_configuration lacks full occupancy
+            # info (max_blocks).
+
+            # If the total range is large enough, full workgroups are fine.
+            # If the range is small, reduce 'items' to create more blocks and
+            # fill the GPU. (Simplified logic compared to CUDA.jl, which uses
+            # explicit occupancy calculators.)
+            total_items = prod(ndrange)
+            if total_items < items * 16 # heuristic factor
+                # force at least a few blocks if possible by reducing items per block
+                target_blocks = 16 # target at least 16 blocks
+                items = max(1, min(items, cld(total_items, target_blocks)))
+            end
+        end
+
         workgroupsize = threads_to_workgroupsize(items, ndrange)
         iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
         ctx = KA.mkcontext(obj, ndrange, iterspace)
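To make the prefer_blocks heuristic concrete, here is a standalone trace with hypothetical numbers: launch_configuration suggesting 256 items and a small 1_000-element ndrange.

    items       = 256     # hypothetical suggestion from launch_configuration
    total_items = 1_000   # prod(ndrange) for a small launch

    # 1_000 < 256 * 16 = 4_096, so the small-range branch fires:
    if total_items < items * 16
        target_blocks = 16
        items = max(1, min(items, cld(total_items, target_blocks)))  # cld(1_000, 16) = 63
    end
    # items is now 63, giving cld(1_000, 63) = 16 workgroups instead of
    # the cld(1_000, 256) = 4 that full-sized workgroups would produce.
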
@@ -171,6 +223,42 @@ end

 ## Other

+Adapt.adapt_storage(to::KA.ConstAdaptor, a::oneDeviceArray) = Base.Experimental.Const(a)
+
 KA.argconvert(::KA.Kernel{oneAPIBackend}, arg) = kernel_convert(arg)

+function KA.priority!(::oneAPIBackend, prio::Symbol)
+    if !(prio in (:high, :normal, :low))
+        error("priority must be one of :high, :normal, :low")
+    end
+
+    priority_enum = if prio == :high
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_HIGH
+    elseif prio == :low
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_LOW
+    else
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_NORMAL
+    end
+
+    ctx = oneAPI.context()
+    dev = oneAPI.device()
+
+    # Update the cached queue
+    # We synchronize the current queue first to ensure safety
+    current_queue = oneAPI.global_queue(ctx, dev)
+    oneAPI.oneL0.synchronize(current_queue)
+
+    # Replace the queue in task_local_storage
+    # The key used by global_queue is (:ZeCommandQueue, ctx, dev)
+    new_queue = oneAPI.oneL0.ZeCommandQueue(ctx, dev;
+        flags = oneAPI.oneL0.ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
+        priority = priority_enum
+    )
+
+    task_local_storage((:ZeCommandQueue, ctx, dev), new_queue)
+
+    return nothing
+end
+
 end
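A usage sketch for the new queue-priority control. Whether the priority actually changes scheduling depends on the Level Zero driver, and the queue swap only affects the calling task's cached queue.

    using oneAPI
    import KernelAbstractions as KA

    backend = oneAPIBackend()

    KA.priority!(backend, :high)    # recreate this task's queue at high priority
    # ... submit latency-sensitive kernels here ...
    KA.priority!(backend, :normal)  # restore the default priority
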
