|
1 | | -import .CuArrays: CuArray |
2 | | -import .CuArrays.CUDAdrv: CuPtr, synchronize |
3 | | -import .CuArrays.CUDAdrv.Mem: DeviceBuffer |
| 1 | +import .CUDA |
4 | 2 |
|
5 | | - |
6 | | -function Base.cconvert(::Type{MPIPtr}, buf::CuArray{T}) where T |
7 | | - Base.cconvert(CuPtr{T}, buf) # returns DeviceBuffer |
| 3 | +function Base.cconvert(::Type{MPIPtr}, buf::CUDA.CuArray{T}) where T |
| 4 | + Base.cconvert(CUDA.CuPtr{T}, buf) # returns DeviceBuffer |
8 | 5 | end |
9 | 6 |
|
10 | | -# CuArrays <= v1.3 |
11 | | -function Base.unsafe_convert(::Type{MPIPtr}, buf::DeviceBuffer) |
12 | | - reinterpret(MPIPtr, buf.ptr) |
13 | | -end |
14 | | -# CuArrays > v1.3 |
15 | | -function Base.unsafe_convert(::Type{MPIPtr}, X::CuArray{T}) where T |
16 | | - reinterpret(MPIPtr, Base.unsafe_convert(CuPtr{T}, X)) |
| 7 | +function Base.unsafe_convert(::Type{MPIPtr}, X::CUDA.CuArray{T}) where T |
| 8 | + reinterpret(MPIPtr, Base.unsafe_convert(CUDA.CuPtr{T}, X)) |
17 | 9 | end |
| 10 | + |
18 | 11 | # only need to define this for strided arrays: all others can be handled by generic machinery |
19 | | -function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CuArray,I} |
| 12 | +function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I} |
20 | 13 | X = parent(V) |
21 | | - pX = Base.unsafe_convert(CuPtr{T}, X) |
| 14 | + pX = Base.unsafe_convert(CUDA.CuPtr{T}, X) |
22 | 15 | pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) |
23 | 16 | return reinterpret(MPIPtr, pV) |
24 | 17 | end |
25 | 18 |
|
26 | | -function Buffer(arr::CuArray) |
| 19 | +function Buffer(arr::CUDA.CuArray) |
27 | 20 | Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) |
28 | 21 | end |
0 commit comments