@@ -15,18 +15,24 @@ import Adapt
 export oneAPIBackend

 struct oneAPIBackend <: KA.GPU
+    prefer_blocks::Bool
+    always_inline::Bool
 end

-KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
-KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
-KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
+oneAPIBackend(; prefer_blocks=false, always_inline=false) = oneAPIBackend(prefer_blocks, always_inline)
+
+@inline KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims)
+@inline KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), zero(T))
+@inline KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where T = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), one(T))

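The keyword constructor and the new `unified` flag are easiest to see in use. A quick usage sketch against this diff (not part of the commit itself); `unified=true` selects a host-visible `SharedBuffer` instead of the default `DeviceBuffer`:

```julia
using KernelAbstractions, oneAPI
const KA = KernelAbstractions

backend = oneAPIBackend(prefer_blocks=true)              # both flags default to false
a = KA.allocate(backend, Float32, (1024,))               # device-local DeviceBuffer
b = KA.zeros(backend, Float32, (32, 32); unified=true)   # host-visible SharedBuffer, zero-filled
```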
 KA.get_backend(::oneArray) = oneAPIBackend()
 # TODO: should be non-blocking
-KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
+KA.synchronize(::oneAPIBackend) = oneAPI.oneL0.synchronize()
 KA.supports_float64(::oneAPIBackend) = false # TODO: check if this is device dependent

-Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
+KA.functional(::oneAPIBackend) = oneAPI.functional()
+
+Adapt.adapt_storage(::oneAPIBackend, a::AbstractArray) = Adapt.adapt(oneArray, a)
 Adapt.adapt_storage(::oneAPIBackend, a::oneArray) = a
 Adapt.adapt_storage(::KA.CPU, a::oneArray) = convert(Array, a)

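Widening the host-side adapt rule from `Array` to `AbstractArray` lets other host containers move over as well. A small illustrative round-trip (not from the commit):

```julia
using Adapt, KernelAbstractions, oneAPI

x = rand(Float32, 16, 16)
dx = adapt(oneAPIBackend(), x)              # Matrix -> oneArray (rule now covers any AbstractArray)
hx = adapt(KernelAbstractions.CPU(), dx)    # oneArray -> Array back on the host
@assert hx == x
```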
@@ -39,6 +45,24 @@ function KA.copyto!(::oneAPIBackend, A, B)
 end


+## Device Operations
+
+function KA.ndevices(::oneAPIBackend)
+    return length(oneAPI.devices())
+end
+
+function KA.device(::oneAPIBackend)::Int
+    dev = oneAPI.device()
+    devs = oneAPI.devices()
+    idx = findfirst(==(dev), devs)
+    return idx === nothing ? 1 : idx
+end
+
+function KA.device!(backend::oneAPIBackend, id::Int)
+    oneAPI.device!(id)
+end
+
+
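A hedged sketch of how the new device hooks compose (assumes at least one visible Level Zero device; `KA.device` reports a 1-based index into `oneAPI.devices()`, and the integer form of `oneAPI.device!` is what this diff relies on):

```julia
using KernelAbstractions, oneAPI
const KA = KernelAbstractions

backend = oneAPIBackend()
for i in 1:KA.ndevices(backend)
    KA.device!(backend, i)          # make device i current
    @show i, KA.device(backend)     # should round-trip back to i
end
```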
 ## Kernel Launch

4367
4468function KA. mkcontext (kernel:: KA.Kernel{oneAPIBackend} , _ndrange, iterspace)
@@ -83,14 +107,42 @@ function threads_to_workgroupsize(threads, ndrange)
 end

 function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    backend = KA.backend(obj)
+
     ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = KA.mkcontext(obj, ndrange, iterspace)
-    kernel = @oneapi launch=false obj.f(ctx, args...)
+
+    # If the kernel is statically sized we can tell the compiler about that
+    if KA.workgroupsize(obj) <: KA.StaticSize
+        # TODO: maxthreads
+        # maxthreads = prod(KA.get(KA.workgroupsize(obj)))
+    else
+        # maxthreads = nothing
+    end
+
+    kernel = @oneapi launch=false always_inline=backend.always_inline obj.f(ctx, args...)

     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
         items = oneAPI.launch_configuration(kernel)
+
+        if backend.prefer_blocks
+            # Prefer blocks over threads:
+            # reducing the workgroup size (items) increases the number of workgroups (blocks).
+            # We use a simple heuristic here, since we lack full occupancy information
+            # (max_blocks) from launch_configuration.
+
+            # If the total range is large enough, full workgroups are fine.
+            # If the range is small, we may want to reduce 'items' to create more blocks
+            # and fill the GPU. (Simplified logic compared to CUDA.jl, which uses explicit
+            # occupancy calculators.)
+            total_items = prod(ndrange)
+            if total_items < items * 16  # heuristic factor
+                # force at least a few blocks if possible by reducing items per block
+                target_blocks = 16  # target at least 16 blocks
+                items = max(1, min(items, cld(total_items, target_blocks)))
+            end
+        end
+
         workgroupsize = threads_to_workgroupsize(items, ndrange)
         iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
         ctx = KA.mkcontext(obj, ndrange, iterspace)
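To make the heuristic concrete, here is its arithmetic with hypothetical numbers: suppose `launch_configuration` suggests 256 work-items per group, but the whole range is only 1024 items.

```julia
items, total_items = 256, 1024   # hypothetical values, not measured
target_blocks = 16
if total_items < items * 16      # 1024 < 4096: the range is "small", so trade threads for blocks
    items = max(1, min(items, cld(total_items, target_blocks)))  # cld(1024, 16) == 64
end
@assert items == 64              # 1024 ÷ 64 == 16 workgroups instead of 4 full ones
```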

 ## Other

+Adapt.adapt_storage(to::KA.ConstAdaptor, a::oneDeviceArray) = Base.Experimental.Const(a)
+
 KA.argconvert(::KA.Kernel{oneAPIBackend}, arg) = kernel_convert(arg)

+function KA.priority!(::oneAPIBackend, prio::Symbol)
+    if !(prio in (:high, :normal, :low))
+        error("priority must be one of :high, :normal, :low")
+    end
+
+    priority_enum = if prio == :high
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_HIGH
+    elseif prio == :low
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_LOW
+    else
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_NORMAL
+    end
+
+    ctx = oneAPI.context()
+    dev = oneAPI.device()
+
+    # Update the cached queue.
+    # We synchronize the current queue first to ensure safety.
+    current_queue = oneAPI.global_queue(ctx, dev)
+    oneAPI.oneL0.synchronize(current_queue)
+
+    # Replace the queue in task_local_storage.
+    # The key used by global_queue is (:ZeCommandQueue, ctx, dev).
+    new_queue = oneAPI.oneL0.ZeCommandQueue(ctx, dev;
+                                            flags=oneAPI.oneL0.ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
+                                            priority=priority_enum)
+
+    task_local_storage((:ZeCommandQueue, ctx, dev), new_queue)
+
+    return nothing
+end
+
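Usage is then straightforward (a sketch; whether elevated priority is actually honored depends on the Level Zero driver):

```julia
using KernelAbstractions, oneAPI
const KA = KernelAbstractions

backend = oneAPIBackend()
KA.priority!(backend, :high)    # rebuilds the task-local queue with high priority
# ... subsequent launches on this task go through the high-priority queue ...
KA.priority!(backend, :normal)  # back to the default
```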
 end
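Putting the pieces together, a minimal end-to-end check of the new launch path. This is a sketch assuming a working oneAPI.jl installation; `saxpy!` is a toy kernel written for illustration, not part of the commit:

```julia
using KernelAbstractions, oneAPI
const KA = KernelAbstractions

@kernel function saxpy!(y, a, x)
    i = @index(Global)
    @inbounds y[i] = a * x[i] + y[i]
end

backend = oneAPIBackend(prefer_blocks=true, always_inline=true)
x = KA.ones(backend, Float32, (2^16,))
y = KA.zeros(backend, Float32, (2^16,))
saxpy!(backend)(y, 2f0, x; ndrange=length(y))  # dynamic size: the heuristic above picks the workgroup
KA.synchronize(backend)
@assert all(Array(y) .== 2f0)
```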