|
48 | 48 | call_function(llvm_f, Core.LLVMPtr{T,AS.ThreadGroup}) |
49 | 49 | end |
50 | 50 | end |
51 | | - |
52 | | - |
53 | | -## device array wrapper extending small element types |
54 | | - |
55 | | -struct MtlLargerDeviceArray{T,N,A} <: DenseArray{T,N} |
56 | | - x::MtlDeviceArray{UInt32,N,A} |
57 | | -end |
58 | | - |
59 | | -Base.elsize(::Type{<:MtlLargerDeviceArray{T}}) where {T} = sizeof(UInt32) |
60 | | - |
61 | | -Base.size(g::MtlLargerDeviceArray) = size(g.x) |
62 | | -Base.sizeof(x::MtlLargerDeviceArray) = Base.elsize(x) * length(x) |
63 | | - |
64 | | -Base.pointer(x::MtlLargerDeviceArray{T,<:Any,A}) where {T,A} = |
65 | | - Base.unsafe_convert(Core.LLVMPtr{T,A}, x) |
66 | | -@inline function Base.pointer(x::MtlLargerDeviceArray{T,<:Any,A}, i::Integer) where {T,A} |
67 | | - Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) |
68 | | -end |
69 | | - |
70 | | -Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::MtlLargerDeviceArray{T,<:Any,A}) where {T,A} = |
71 | | - reinterpret(Core.LLVMPtr{T,A}, Base.unsafe_convert(Core.LLVMPtr{UInt32,A}, x.x)) |
72 | | - |
73 | | -Base.@propagate_inbounds Base.getindex(A::MtlLargerDeviceArray{T}, i1::Integer) where {T} = |
74 | | - arrayref(A, i1) |
75 | | -Base.@propagate_inbounds Base.setindex!(A::MtlLargerDeviceArray{T}, x, i1::Integer) where {T} = |
76 | | - arrayset(A, convert(T, x)::T, i1) |
77 | | - |
78 | | -# preserve the specific integer type when indexing device arrays, |
79 | | -# to avoid extending 32-bit hardware indices to 64-bit. |
80 | | -Base.to_index(::MtlLargerDeviceArray, i::Integer) = i |
81 | | - |
82 | | -# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. |
83 | | -# See also: https://github.com/JuliaLang/julia/pull/42289 |
84 | | -Base.@propagate_inbounds Base.getindex(A::MtlLargerDeviceArray, |
85 | | - I::Union{Integer, CartesianIndex}...) = |
86 | | - A[Base._to_linear_index(A, to_indices(A, I)...)] |
87 | | -Base.@propagate_inbounds Base.setindex!(A::MtlLargerDeviceArray, x, |
88 | | - I::Union{Integer, CartesianIndex}...) = |
89 | | - A[Base._to_linear_index(A, to_indices(A, I)...)] = x |
90 | | - |
91 | | -@inline function arrayref(A::MtlLargerDeviceArray{T}, index::Integer) where {T} |
92 | | - @boundscheck checkbounds(A, index) |
93 | | - align = Base.datatype_alignment(T) |
94 | | - unsafe_load(pointer(A), index, Val(align)) |
95 | | -end |
96 | | - |
97 | | -@inline function arrayset(A::MtlLargerDeviceArray{T}, x::T, index::Integer) where {T} |
98 | | - @boundscheck checkbounds(A, index) |
99 | | - align = Base.datatype_alignment(T) |
100 | | - unsafe_store!(pointer(A), x, index, Val(align)) |
101 | | - return A |
102 | | -end |
0 commit comments