@@ -25,14 +25,15 @@ ROCDeviceArray
2525# NOTE: we can't support the typical `tuple or series of integer` style
2626# construction, because we're currently requiring a trailing pointer argument.
2727
28- struct ROCDeviceArray{T,N,A} <: AbstractArray{T,N}
29- shape:: Dims{N}
28+ struct ROCDeviceArray{T,N,A} <: DenseArray{T,N}
29+ # NOTE: `dims` is a mandatory name for things like `strides(xd)` to work.
30+ dims:: Dims{N}
3031 ptr:: LLVMPtr{T,A}
3132 len:: Int
3233
3334 # inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
34- function ROCDeviceArray {T,N,A} (shape :: Dims{N} , ptr:: LLVMPtr{T,A} ) where {T,A,N}
35- new (shape , ptr, prod (shape ))
35+ function ROCDeviceArray {T,N,A} (dims :: Dims{N} , ptr:: LLVMPtr{T,A} ) where {T,A,N}
36+ new (dims , ptr, prod (dims ))
3637 end
3738end
3839
@@ -65,7 +66,8 @@ Base.pointer(a::ROCDeviceArray, i::Integer) =
6566 pointer (a) + (i - 1 ) * Base. elsize (a) # TODO use _memory_offset(a, i)
6667
6768Base. elsize (:: Type{<:ROCDeviceArray{T}} ) where {T} = sizeof (T)
68- Base. size (g:: ROCDeviceArray ) = g. shape
69+ Base. size (g:: ROCDeviceArray ) = g. dims
70+ Base. sizeof (x:: ROCDeviceArray ) = Base. elsize (x) * length (x)
6971Base. length (g:: ROCDeviceArray ) = g. len
7072
7173# conversions
@@ -96,14 +98,14 @@ Base.IndexStyle(::Type{<:ROCDeviceArray}) = Base.IndexLinear()
9698# comparisons
9799
98100Base. isequal (a1:: R1 , a2:: R2 ) where {R1<: ROCDeviceArray ,R2<: ROCDeviceArray } =
99- R1 == R2 && a1. shape == a2. shape && a1. ptr == a2. ptr
101+ R1 == R2 && a1. dims == a2. dims && a1. ptr == a2. ptr
100102
101103# other
102104
103105Base. show (io:: IO , a:: ROCDeviceVector ) =
104106 print (io, " $(length (a)) -element device array at $(pointer (a)) " )
105107Base. show (io:: IO , a:: ROCDeviceArray ) =
106- print (io, " $(join (a. shape , ' ×' )) device array at $(pointer (a)) " )
108+ print (io, " $(join (a. dims , ' ×' )) device array at $(pointer (a)) " )
107109
108110Base. show (io:: IO , a:: SubArray{T,D,P,I,F} ) where {T,D,P<: ROCDeviceVector ,I,F} =
109111 print (io, " $(length (a. indices[1 ])) -element device array view(::$P at $(pointer (parent (a))) , $(a. indices[1 ]) ) with eltype $T " )
@@ -113,12 +115,12 @@ Base.show(io::IO, a::SubArray{T,D,P,I,F}) where {T,D,P<:ROCDeviceArray,I,F} =
113115Base. show (io:: IO , a:: S ) where S<: AnyROCDeviceVector =
114116 print (io, " $(length (a)) -element device array wrapper $S at $(pointer (parent (a))) " )
115117Base. show (io:: IO , a:: S ) where S<: AnyROCDeviceArray =
116- print (io, " $(join (parent (a). shape , ' ×' )) device array wrapper $S at $(pointer (parent (a))) " )
118+ print (io, " $(join (parent (a). dims , ' ×' )) device array wrapper $S at $(pointer (parent (a))) " )
117119
118120Base. show (io:: IO , mime:: MIME"text/plain" , a:: S ) where S<: AnyROCDeviceArray = show (io, a)
119121
120122@inline function Base. unsafe_view (A:: ROCDeviceVector{T} , I:: Vararg{Base.ViewIndex,1} ) where {T}
121- ptr = pointer (A) + (I[1 ]. start- 1 ) * sizeof (T)
123+ ptr = pointer (A) + (I[1 ]. start - 1 ) * sizeof (T)
122124 len = I[1 ]. stop - I[1 ]. start + 1
123125
124126 return ROCDeviceArray (len, ptr)
0 commit comments