diff --git a/Project.toml b/Project.toml index 7e3d66252..ac4123773 100644 --- a/Project.toml +++ b/Project.toml @@ -34,15 +34,19 @@ Requires = "~0.5, 1.0" Serialization = "1" Sockets = "1" julia = "1.6" +oneAPI = "2.1" [extensions] AMDGPUExt = "AMDGPU" CUDAExt = "CUDA" +OneAPIExt = "oneAPI" [extras] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" diff --git a/ext/OneAPIExt.jl b/ext/OneAPIExt.jl new file mode 100644 index 000000000..6891aaf43 --- /dev/null +++ b/ext/OneAPIExt.jl @@ -0,0 +1,27 @@ +module OneAPIExt + +import MPI +isdefined(Base, :get_extension) ? (import oneAPI) : (import ..oneAPI) +import MPI: MPIPtr, Buffer, Datatype + +function Base.cconvert(::Type{MPIPtr}, A::oneAPI.oneArray{T}) where T + A +end + +function Base.unsafe_convert(::Type{MPIPtr}, X::oneAPI.oneArray{T}) where T + reinterpret(MPIPtr, Base.unsafe_convert(oneAPI.ZePtr{T}, X)) +end + +# only need to define this for strided arrays: all others can be handled by generic machinery +function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:oneAPI.oneArray,I} + X = parent(V) + pX = Base.unsafe_convert(oneAPI.ZePtr{T}, X) + pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) + return reinterpret(MPIPtr, pV) +end + +function Buffer(arr::oneAPI.oneArray) + Buffer(arr, Cint(length(arr)), Datatype(eltype(arr))) +end + +end # OneAPIExt