Skip to content

Commit d77cee6

Browse files
authored
Adapt ROCDeviceArray to DenseArray interface (#692)
* Adapt ROCDeviceArray to DenseArray interface
* Disable opaque pointers CI for now
1 parent 12fcce0 commit d77cee6

File tree

5 files changed

+52
-32
lines changed

5 files changed

+52
-32
lines changed

.buildkite/pipeline.yml

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,26 +62,26 @@ steps:
6262
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
6363
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
6464

65-
- label: "Julia 1.11 opaque pointers"
66-
plugins:
67-
- JuliaCI/julia#v1:
68-
version: "1.11"
69-
- JuliaCI/julia-test#v1:
70-
- JuliaCI/julia-coverage#v1:
71-
codecov: true
72-
agents:
73-
queue: "juliagpu"
74-
rocm: "*"
75-
rocmgpu: "*"
76-
if: build.message !~ /\[skip tests\]/
77-
command: "julia --project -e 'using Pkg; Pkg.update()'"
78-
timeout_in_minutes: 180
79-
env:
80-
JULIA_LLVM_ARGS: "-opaque-pointers"
81-
JULIA_NUM_THREADS: 4
82-
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
83-
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
84-
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
65+
# - label: "Julia 1.11 opaque pointers"
66+
# plugins:
67+
# - JuliaCI/julia#v1:
68+
# version: "1.11"
69+
# - JuliaCI/julia-test#v1:
70+
# - JuliaCI/julia-coverage#v1:
71+
# codecov: true
72+
# agents:
73+
# queue: "juliagpu"
74+
# rocm: "*"
75+
# rocmgpu: "*"
76+
# if: build.message !~ /\[skip tests\]/
77+
# command: "julia --project -e 'using Pkg; Pkg.update()'"
78+
# timeout_in_minutes: 180
79+
# env:
80+
# JULIA_LLVM_ARGS: "-opaque-pointers"
81+
# JULIA_NUM_THREADS: 4
82+
# JULIA_AMDGPU_CORE_MUST_LOAD: "1"
83+
# JULIA_AMDGPU_HIP_MUST_LOAD: "1"
84+
# JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
8585

8686
- label: "GPU-less environment"
8787
plugins:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
3535

3636
[compat]
3737
AbstractFFTs = "1.0"
38-
AcceleratedKernels = "0.1.0"
38+
AcceleratedKernels = "0.1, 0.2"
3939
Adapt = "4"
4040
Atomix = "0.1"
4141
CEnum = "0.4, 0.5"

src/ROCKernels.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ struct ROCBackend <: KA.GPU end
1717
Adapt.adapt_storage(::ROCBackend, a::Array) = Adapt.adapt(AMDGPU.ROCArray, a)
1818
Adapt.adapt_storage(::ROCBackend, a::AMDGPU.ROCArray) = a
1919
Adapt.adapt_storage(::KA.CPU, a::AMDGPU.ROCArray) = convert(Array, a)
20-
Adapt.adapt_storage(::KA.ConstAdaptor, a::AMDGPU.ROCDeviceArray{T}) where T =
21-
AMDGPU.ROCDeviceArray(a.shape, LLVM.Interop.addrspacecast(Core.LLVMPtr{T,AMDGPU.Device.AS.Constant}, a.ptr))
20+
function Adapt.adapt_storage(::KA.ConstAdaptor, a::AMDGPU.ROCDeviceArray{T}) where T
21+
ptr = LLVM.Interop.addrspacecast(Core.LLVMPtr{T,AMDGPU.Device.AS.Constant}, a.ptr)
22+
AMDGPU.ROCDeviceArray(a.dims, ptr)
23+
end
2224

2325
KA.argconvert(::KA.Kernel{ROCBackend}, arg) = AMDGPU.rocconvert(arg)
2426
KA.get_backend(::AMDGPU.ROCArray) = ROCBackend()

src/device/gcn/array.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ ROCDeviceArray
2525
# NOTE: we can't support the typical `tuple or series of integer` style
2626
# construction, because we're currently requiring a trailing pointer argument.
2727

28-
struct ROCDeviceArray{T,N,A} <: AbstractArray{T,N}
29-
shape::Dims{N}
28+
struct ROCDeviceArray{T,N,A} <: DenseArray{T,N}
29+
# NOTE: `dims` is a mandatory name for things like `strides(xd)` to work.
30+
dims::Dims{N}
3031
ptr::LLVMPtr{T,A}
3132
len::Int
3233

3334
# inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
34-
function ROCDeviceArray{T,N,A}(shape::Dims{N}, ptr::LLVMPtr{T,A}) where {T,A,N}
35-
new(shape, ptr, prod(shape))
35+
function ROCDeviceArray{T,N,A}(dims::Dims{N}, ptr::LLVMPtr{T,A}) where {T,A,N}
36+
new(dims, ptr, prod(dims))
3637
end
3738
end
3839

@@ -65,7 +66,8 @@ Base.pointer(a::ROCDeviceArray, i::Integer) =
6566
pointer(a) + (i - 1) * Base.elsize(a) # TODO use _memory_offset(a, i)
6667

6768
Base.elsize(::Type{<:ROCDeviceArray{T}}) where {T} = sizeof(T)
68-
Base.size(g::ROCDeviceArray) = g.shape
69+
Base.size(g::ROCDeviceArray) = g.dims
70+
Base.sizeof(x::ROCDeviceArray) = Base.elsize(x) * length(x)
6971
Base.length(g::ROCDeviceArray) = g.len
7072

7173
# conversions
@@ -96,14 +98,14 @@ Base.IndexStyle(::Type{<:ROCDeviceArray}) = Base.IndexLinear()
9698
# comparisons
9799

98100
Base.isequal(a1::R1, a2::R2) where {R1<:ROCDeviceArray,R2<:ROCDeviceArray} =
99-
R1 == R2 && a1.shape == a2.shape && a1.ptr == a2.ptr
101+
R1 == R2 && a1.dims == a2.dims && a1.ptr == a2.ptr
100102

101103
# other
102104

103105
Base.show(io::IO, a::ROCDeviceVector) =
104106
print(io, "$(length(a))-element device array at $(pointer(a))")
105107
Base.show(io::IO, a::ROCDeviceArray) =
106-
print(io, "$(join(a.shape, '×')) device array at $(pointer(a))")
108+
print(io, "$(join(a.dims, '×')) device array at $(pointer(a))")
107109

108110
Base.show(io::IO, a::SubArray{T,D,P,I,F}) where {T,D,P<:ROCDeviceVector,I,F} =
109111
print(io, "$(length(a.indices[1]))-element device array view(::$P at $(pointer(parent(a))), $(a.indices[1])) with eltype $T")
@@ -113,12 +115,12 @@ Base.show(io::IO, a::SubArray{T,D,P,I,F}) where {T,D,P<:ROCDeviceArray,I,F} =
113115
Base.show(io::IO, a::S) where S<:AnyROCDeviceVector =
114116
print(io, "$(length(a))-element device array wrapper $S at $(pointer(parent(a)))")
115117
Base.show(io::IO, a::S) where S<:AnyROCDeviceArray =
116-
print(io, "$(join(parent(a).shape, '×')) device array wrapper $S at $(pointer(parent(a)))")
118+
print(io, "$(join(parent(a).dims, '×')) device array wrapper $S at $(pointer(parent(a)))")
117119

118120
Base.show(io::IO, mime::MIME"text/plain", a::S) where S<:AnyROCDeviceArray = show(io, a)
119121

120122
@inline function Base.unsafe_view(A::ROCDeviceVector{T}, I::Vararg{Base.ViewIndex,1}) where {T}
121-
ptr = pointer(A) + (I[1].start-1)*sizeof(T)
123+
ptr = pointer(A) + (I[1].start - 1) * sizeof(T)
122124
len = I[1].stop - I[1].start + 1
123125

124126
return ROCDeviceArray(len, ptr)

test/device/array.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
@testset "ROCDeviceArray array interface" begin
2+
x = ROCArray(zeros(Int, 1, 2, 3, 4))
3+
xd = rocconvert(x)
4+
5+
@test typeof(xd) <: DenseArray
6+
@test hasproperty(xd, :dims)
7+
8+
@test size(xd) == size(x)
9+
@test length(xd) == length(x)
10+
@test strides(xd) == strides(x)
11+
@test Base.elsize(xd) == Base.elsize(x)
12+
@test sizeof(xd) == sizeof(x)
13+
14+
@test Int(pointer(xd)) == Int(pointer(x))
15+
end
16+
117
@testset "ROCDeviceArray" begin
218
RA = ROCArray(rand(4,4))
319
RD = rocconvert(RA)

0 commit comments

Comments
 (0)