Skip to content

Commit e412534

Browse files
committed
Support keys() and unsafe_cached_load()
1 parent 37c2c4b commit e412534

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

lib/level-zero/device.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ Base.length(iter::ZeDevices) = length(iter.handles)
204204

205205
Base.IteratorSize(::ZeDevices) = Base.HasLength()
206206

207+
Base.keys(iter::ZeDevices) = 1:length(iter)
208+
207209
function Base.show(io::IO, ::MIME"text/plain", iter::ZeDevices)
208210
print(io, "ZeDevice iterator for $(length(iter)) devices")
209211
if !isempty(iter)

src/context.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ See also: [`device`](@ref), [`devices`](@ref)
103103
function device!(drv::ZeDevice)
104104
task_local_storage(:ZeDevice, drv)
105105
end
106-
device!(i::Int) = device!(devices(driver())[i])
106+
function device!(i::Int)
107+
devs = devices(driver())
108+
if i < 1 || i > length(devs)
109+
throw(ArgumentError("Invalid device index $i (must be between 1 and $(length(devs)))"))
110+
end
111+
device!(devs[i])
112+
end
107113

108114
const global_contexts = Dict{ZeDriver,ZeContext}()
109115

src/device/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ end
195195
end
196196
end
197197

198+
@device_function @inline function unsafe_cached_load(ptr::LLVMPtr{T,A}, i::Integer, align::Val) where {T,A}
199+
# For SPIR-V/Level Zero, we don't have explicit cache control intrinsics like CUDA's __ldg
200+
# So we fall back to a regular unsafe_load. The SPIR-V compiler may still apply
201+
# appropriate optimizations based on context.
202+
unsafe_load(ptr, i, align)
203+
end
204+
198205
@device_function @inline function const_arrayref(A::oneDeviceArray{T}, index::Integer) where {T}
199206
# simplified bounds check (see `arrayset`)
200207
#@boundscheck checkbounds(A, index)

0 commit comments

Comments
 (0)