@@ -19,16 +19,17 @@ struct oneAPIBackend <: KA.GPU
1919 always_inline:: Bool
2020end
2121
22- oneAPIBackend (; prefer_blocks= false , always_inline= false ) = oneAPIBackend (prefer_blocks, always_inline)
22+ oneAPIBackend (; prefer_blocks = false , always_inline = false ) = oneAPIBackend (prefer_blocks, always_inline)
2323
24- @inline KA. allocate (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where T = oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims)
25- @inline KA. zeros (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where T = fill! (oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims), zero (T))
26- @inline KA. ones (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where T = fill! (oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims), one (T))
24+ @inline KA. allocate (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where {T} = oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims)
25+ @inline KA. zeros (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where {T} = fill! (oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims), zero (T))
26+ @inline KA. ones (:: oneAPIBackend , :: Type{T} , dims:: Tuple ; unified:: Bool = false ) where {T} = fill! (oneArray {T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer} (undef, dims), one (T))
2727
2828KA. get_backend (:: oneArray ) = oneAPIBackend ()
2929# TODO should be non-blocking
3030KA. synchronize (:: oneAPIBackend ) = oneAPI. oneL0. synchronize ()
3131KA. supports_float64 (:: oneAPIBackend ) = false # TODO : Check if this is device dependent
32+ KA. supports_unified (:: oneAPIBackend ) = true
3233
3334KA. functional (:: oneAPIBackend ) = oneAPI. functional ()
3435
@@ -59,7 +60,7 @@ function KA.device(::oneAPIBackend)::Int
5960end
6061
6162function KA. device! (backend:: oneAPIBackend , id:: Int )
62- oneAPI. device! (id)
63+ return oneAPI. device! (id)
6364end
6465
6566
@@ -121,7 +122,7 @@ function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize
121122 # maxthreads = nothing
122123 end
123124
124- kernel = @oneapi launch= false always_inline= backend. always_inline obj. f (ctx, args... )
125+ kernel = @oneapi launch = false always_inline = backend. always_inline obj. f (ctx, args... )
125126
126127 # figure out the optimal workgroupsize automatically
127128 if KA. workgroupsize (obj) <: KA.DynamicSize && workgroupsize === nothing
@@ -137,9 +138,9 @@ function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize
137138 # (Simplified logic compared to CUDA.jl which uses explicit occupancy calculators)
138139 total_items = prod (ndrange)
139140 if total_items < items * 16 # Heuristic factor
140- # Force at least a few blocks if possible by reducing items per block
141- target_blocks = 16 # Target at least 16 blocks
142- items = max (1 , min (items, cld (total_items, target_blocks)))
141+ # Force at least a few blocks if possible by reducing items per block
142+ target_blocks = 16 # Target at least 16 blocks
143+ items = max (1 , min (items, cld (total_items, target_blocks)))
143144 end
144145 end
145146
@@ -251,7 +252,8 @@ function KA.priority!(::oneAPIBackend, prio::Symbol)
251252 # Replace the queue in task_local_storage
252253 # The key used by global_queue is (:ZeCommandQueue, ctx, dev)
253254
254- new_queue = oneAPI. oneL0. ZeCommandQueue (ctx, dev;
255+ new_queue = oneAPI. oneL0. ZeCommandQueue (
256+ ctx, dev;
255257 flags = oneAPI. oneL0. ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
256258 priority = priority_enum
257259 )
0 commit comments