Skip to content

Commit 0a30948

Browse files
kose-ysimonbyrne
authored andcommitted
CUDA-aware MPI (#286)
Buffers are passed using an `MPIPtr` type, so CuArrays can be passed using implicit conversion.
1 parent 14adebf commit 0a30948

File tree

7 files changed

+229
-154
lines changed

7 files changed

+229
-154
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.9.0"
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
88
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
10+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1011
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1112
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
1213

src/MPI.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MPI
22

33
using Libdl, Serialization
4+
using Requires
45

56
macro mpichk(expr)
67
@assert expr isa Expr && expr.head == :call && expr.args[1] == :ccall
@@ -24,7 +25,6 @@ end
2425
primitive type SentinelPtr
2526
Sys.WORD_SIZE
2627
end
27-
Base.cconvert(::Type{Ptr{T}}, sptr::SentinelPtr) where {T} = reinterpret(Ptr{T}, sptr)
2828

2929
function _doc_external(fname)
3030
"""
@@ -65,6 +65,8 @@ function __init__()
6565
if filesize(dlpath(libmpi)) != libmpi_size
6666
error("MPI library has changed, re-run Pkg.build(\"MPI\")")
6767
end
68+
69+
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
6870
end
6971

7072
end

src/collective.jl

Lines changed: 138 additions & 103 deletions
Large diffs are not rendered by default.

src/cuda.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import .CuArrays: CuArray
2+
import .CuArrays.CUDAdrv: CuPtr, synchronize
3+
import .CuArrays.CUDAdrv.Mem: DeviceBuffer
4+
5+
function Base.cconvert(::Type{MPIPtr}, buf::CuArray{T}) where T
6+
Base.cconvert(CuPtr{T}, buf) # returns DeviceBuffer
7+
end
8+
9+
function Base.unsafe_convert(::Type{MPIPtr}, buf::DeviceBuffer)
10+
reinterpret(MPIPtr, buf.ptr)
11+
end

src/datatypes.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}}
1515

1616
MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr}
1717

18+
if sizeof(Ptr{Cvoid}) == 8
19+
primitive type MPIPtr 64 end
20+
else
21+
primitive type MPIPtr 32 end
22+
end
23+
24+
MPIPtr(x::Cint) where T = reinterpret(MPIPtr, x)
25+
26+
Base.cconvert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T = Base.cconvert(Ptr{T}, x)
27+
28+
function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T
29+
ptr = Base.unsafe_convert(Ptr{T}, x)
30+
reinterpret(MPIPtr, ptr)
31+
end
32+
33+
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
34+
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)
35+
1836
fieldoffsets(::Type{T}) where {T} = Int[fieldoffset(T, i) for i in 1:length(fieldnames(T))]
1937

2038
# Define a function mpitype(T) that returns the MPI datatype code for

src/pointtopoint.jl

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Base: eltype
12

23
# definition of the `Status` struct.
34
# TODO: this bakes in a lot of assumptions about ordering and padding
@@ -94,12 +95,12 @@ end
9495
Complete a blocking send of `count` elements of type `datatype` from `buf` to MPI
9596
rank `dest` of communicator `comm` using the message tag `tag`
9697
"""
97-
function Send(buf::MPIBuffertype{T}, count::Integer, datatype::Union{Datatype, MPI_Datatype},
98-
dest::Integer, tag::Integer, comm::Comm) where T
98+
function Send(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype},
99+
dest::Integer, tag::Integer, comm::Comm)
99100
# int MPI_Send(const void* buf, int count, MPI_Datatype datatype, int dest,
100101
# int tag, MPI_Comm comm)
101102
@mpichk ccall((:MPI_Send, libmpi), Cint,
102-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cint, MPI_Comm),
103+
(MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm),
103104
buf, count, datatype, dest, tag, comm)
104105
end
105106

@@ -110,18 +111,18 @@ end
110111
Complete a blocking send of `count` elements of `buf` to MPI rank `dest`
111112
of communicator `comm` using with the message tag `tag`
112113
"""
113-
function Send(buf::MPIBuffertype{T}, count::Integer, dest::Integer,
114-
tag::Integer, comm::Comm) where T
115-
Send(buf, count, mpitype(T), dest, tag, comm)
114+
function Send(buf, count::Integer, dest::Integer,
115+
tag::Integer, comm::Comm)
116+
Send(buf, count, mpitype(eltype(buf)), dest, tag, comm)
116117
end
117118

118119
"""
119-
Send(buf::Array{T}, dest::Integer, tag::Integer, comm::Comm) where T
120+
Send(buf::AbstractArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
120121
121122
Complete a blocking send of `buf` to MPI rank `dest` of communicator `comm`
122123
using with the message tag `tag`
123124
"""
124-
function Send(buf::Array{T}, dest::Integer, tag::Integer, comm::Comm) where T
125+
function Send(buf::AbstractArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
125126
Send(buf, length(buf), dest, tag, comm)
126127
end
127128

@@ -167,13 +168,13 @@ MPI rank `dest` of communicator `comm` using with the message tag `tag`
167168
168169
Returns the commication `Request` for the nonblocking send.
169170
"""
170-
function Isend(buf::MPIBuffertype{T}, count::Integer, datatype::Union{Datatype, MPI_Datatype},
171-
dest::Integer, tag::Integer, comm::Comm) where T
171+
function Isend(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype},
172+
dest::Integer, tag::Integer, comm::Comm)
172173
req = Request()
173174
# int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, int dest,
174175
# int tag, MPI_Comm comm, MPI_Request *request)
175176
@mpichk ccall((:MPI_Isend, libmpi), Cint,
176-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}),
177+
(MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}),
177178
buf, count, datatype, dest, tag, comm, req)
178179
req.buffer = buf
179180
return req
@@ -188,9 +189,9 @@ of communicator `comm` using with the message tag `tag`
188189
189190
Returns the commication `Request` for the nonblocking send.
190191
"""
191-
function Isend(buf::MPIBuffertype{T}, count::Integer,
192-
dest::Integer, tag::Integer, comm::Comm) where T
193-
Isend(buf, count, mpitype(T), dest, tag, comm)
192+
function Isend(buf, count::Integer,
193+
dest::Integer, tag::Integer, comm::Comm)
194+
Isend(buf, count, mpitype(eltype(buf)), dest, tag, comm)
194195
end
195196

196197
"""
@@ -201,8 +202,8 @@ using with the message tag `tag`
201202
202203
Returns the commication `Request` for the nonblocking send.
203204
"""
204-
function Isend(buf::Array{T}, dest::Integer, tag::Integer, comm::Comm) where T
205-
Isend(buf, length(buf), dest, tag, comm)
205+
function Isend(buf::AbstractArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
206+
Isend(buf, length(buf), mpitype(T), dest, tag, comm)
206207
end
207208

208209
"""
@@ -216,7 +217,7 @@ Returns the commication `Request` for the nonblocking send.
216217
function Isend(buf::SubArray{T}, dest::Integer, tag::Integer,
217218
comm::Comm) where T
218219
@assert Base.iscontiguous(buf)
219-
Isend(buf, length(buf), dest, tag, comm)
220+
Isend(buf, length(buf), mpitype(T), dest, tag, comm)
220221
end
221222

222223
"""
@@ -254,13 +255,13 @@ from MPI rank `src` of communicator `comm` using with the message tag `tag`
254255
255256
Returns the `Status` of the receive
256257
"""
257-
function Recv!(buf::MPIBuffertype{T}, count::Integer, datatype::Union{Datatype,MPI_Datatype}, src::Integer,
258-
tag::Integer, comm::Comm) where T
258+
function Recv!(buf, count::Integer, datatype::Union{Datatype,MPI_Datatype}, src::Integer,
259+
tag::Integer, comm::Comm)
259260
stat_ref = Ref{Status}()
260261
# int MPI_Recv(void* buf, int count, MPI_Datatype datatype, int source,
261262
# int tag, MPI_Comm comm, MPI_Status *status)
262263
@mpichk ccall((:MPI_Recv, libmpi), Cint,
263-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{Status}),
264+
(MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{Status}),
264265
buf, count, datatype, src, tag, comm, stat_ref)
265266
return stat_ref[]
266267
end
@@ -274,9 +275,9 @@ Completes a blocking receive of up to `count` elements into `buf` from MPI rank
274275
275276
Returns the `Status` of the receive
276277
"""
277-
function Recv!(buf::MPIBuffertype{T}, count::Integer, src::Integer,
278-
tag::Integer, comm::Comm) where T
279-
Recv!(buf, count, mpitype(T), src, tag, comm)
278+
function Recv!(buf, count::Integer, src::Integer,
279+
tag::Integer, comm::Comm)
280+
Recv!(buf, count, mpitype(eltype(buf)), src, tag, comm)
280281
end
281282

282283

@@ -288,7 +289,7 @@ Completes a blocking receive into `buf` from MPI rank `src` of communicator
288289
289290
Returns the `Status` of the receive
290291
"""
291-
function Recv!(buf::Array{T}, src::Integer, tag::Integer, comm::Comm) where T
292+
function Recv!(buf::AbstractArray{T}, src::Integer, tag::Integer, comm::Comm) where T
292293
Recv!(buf, length(buf), src, tag, comm)
293294
end
294295

@@ -329,13 +330,13 @@ from MPI rank `src` of communicator `comm` using with the message tag `tag`
329330
330331
Returns the communication `Request` for the nonblocking receive.
331332
"""
332-
function Irecv!(buf::MPIBuffertype{T}, count::Integer, datatype::Union{Datatype, MPI_Datatype},
333+
function Irecv!(buf, count::Integer, datatype::Union{Datatype, MPI_Datatype},
333334
src::Integer, tag::Integer, comm::Comm) where T
334335
req = Request()
335336
# int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, int source,
336337
# int tag, MPI_Comm comm, MPI_Request *request)
337338
@mpichk ccall((:MPI_Irecv, libmpi), Cint,
338-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}),
339+
(MPIPtr, Cint, MPI_Datatype, Cint, Cint, MPI_Comm, Ptr{MPI_Request}),
339340
buf, count, datatype, src, tag, comm, req)
340341
req.buffer = buf
341342
return req
@@ -350,9 +351,9 @@ from MPI rank `src` of communicator `comm` using with the message tag `tag`
350351
351352
Returns the communication `Request` for the nonblocking receive.
352353
"""
353-
function Irecv!(buf::MPIBuffertype{T}, count::Integer,
354-
src::Integer, tag::Integer, comm::Comm) where T
355-
Irecv!(buf, count, mpitype(T), src, tag, comm)
354+
function Irecv!(buf, count::Integer,
355+
src::Integer, tag::Integer, comm::Comm)
356+
Irecv!(buf, count, mpitype(eltype(buf)), src, tag, comm)
356357
end
357358

358359
"""
@@ -363,9 +364,9 @@ Starts a nonblocking receive into `buf` from MPI rank `src` of communicator
363364
364365
Returns the communication `Request` for the nonblocking receive.
365366
"""
366-
function Irecv!(buf::Array{T}, src::Integer, tag::Integer,
367+
function Irecv!(buf::AbstractArray{T}, src::Integer, tag::Integer,
367368
comm::Comm) where T
368-
Irecv!(buf, length(buf), src, tag, comm)
369+
Irecv!(buf, length(buf), mpitype(T), src, tag, comm)
369370
end
370371

371372
"""
@@ -380,7 +381,7 @@ Returns the communication `Request` for the nonblocking receive.
380381
function Irecv!(buf::SubArray{T}, src::Integer, tag::Integer,
381382
comm::Comm) where T
382383
@assert Base.iscontiguous(buf)
383-
Irecv!(buf, length(buf), src, tag, comm)
384+
Irecv!(buf, length(buf), mpitype(T), src, tag, comm)
384385
end
385386

386387
function irecv(src::Integer, tag::Integer, comm::Comm)

src/window.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ This is a collective call over `comm`.
2121
2222
[`MPI.free`](@ref) should be called on the `Win` object once operations have been completed.
2323
"""
24-
function Win_create(base::Array{T}, comm::Comm; infokws...) where T
24+
function Win_create(base::AbstractArray{T}, comm::Comm; infokws...) where T
2525
win = Win()
2626
# int MPI_Win_create(void *base, MPI_Aint size, int disp_unit, MPI_Info info,
2727
# MPI_Comm comm, MPI_Win *win)
2828
@mpichk ccall((:MPI_Win_create, libmpi), Cint,
29-
(Ptr{T}, Cptrdiff_t, Cint, MPI_Info, MPI_Comm, Ptr{MPI_Win}),
29+
(MPIPtr, Cptrdiff_t, Cint, MPI_Info, MPI_Comm, Ptr{MPI_Win}),
3030
base, Cptrdiff_t(length(base)*sizeof(T)), sizeof(T), Info(infokws...), comm, win)
3131
refcount_inc()
3232
finalizer(free, win)
@@ -106,17 +106,17 @@ function Win_shared_query(win::Win, owner_rank::Int)
106106
out_len[], out_sizeT[], out_baseptr[]
107107
end
108108

109-
function Win_attach(win::Win, base::Array{T}) where T
109+
function Win_attach(win::Win, base::AbstractArray{T}) where T
110110
# int MPI_Win_attach(MPI_Win win, void *base, MPI_Aint size)
111111
@mpichk ccall((:MPI_Win_attach, libmpi), Cint,
112-
(MPI_Win, Ptr{T}, Cptrdiff_t),
112+
(MPI_Win, MPIPtr, Cptrdiff_t),
113113
win, base, Cptrdiff_t(sizeof(base)))
114114
end
115115

116-
function Win_detach(win::Win, base::Array{T}) where T
116+
function Win_detach(win::Win, base::AbstractArray{T}) where T
117117
# int MPI_Win_detach(MPI_Win win, const void *base)
118118
@mpichk ccall((:MPI_Win_detach, libmpi), Cint,
119-
(MPI_Win, Ptr{T}),
119+
(MPI_Win, MPIPtr),
120120
win, base)
121121
end
122122

@@ -147,66 +147,73 @@ function Win_unlock(rank::Integer, win::Win)
147147
@mpichk ccall((:MPI_Win_unlock, libmpi), Cint, (Cint, MPI_Win), rank, win)
148148
end
149149

150-
function Get(origin_buffer::MPIBuffertype{T}, count::Integer, target_rank::Integer, target_disp::Integer, win::Win) where T
150+
function Get(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win)
151+
T = eltype(origin_buffer)
151152
# int MPI_Get(void *origin_addr, int origin_count,
152153
# MPI_Datatype origin_datatype, int target_rank,
153154
# MPI_Aint target_disp, int target_count,
154155
# MPI_Datatype target_datatype, MPI_Win win)
155156
@mpichk ccall((:MPI_Get, libmpi), Cint,
156-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
157+
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
157158
origin_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), win)
158159
end
159-
function Get(origin_buffer::Array{T}, target_rank::Integer, win::Win) where T
160+
function Get(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T
160161
count = length(origin_buffer)
161162
Get(origin_buffer, count, target_rank, 0, win)
162163
end
163164
function Get(origin_value::Ref{T}, target_rank::Integer, win::Win) where T
164165
Get(origin_value, 1, target_rank, 0, win)
165166
end
166167

167-
function Put(origin_buffer::MPIBuffertype{T}, count::Integer, target_rank::Integer, target_disp::Integer, win::Win) where T
168+
function Put(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win)
168169
# int MPI_Put(const void *origin_addr, int origin_count,
169170
# MPI_Datatype origin_datatype, int target_rank,
170171
# MPI_Aint target_disp, int target_count,
171172
# MPI_Datatype target_datatype, MPI_Win win)
173+
T = eltype(origin_buffer)
172174
@mpichk ccall((:MPI_Put, libmpi), Cint,
173-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
175+
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
174176
origin_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), win)
175177
end
176-
function Put(origin_buffer::Array{T}, target_rank::Integer, win::Win) where T
178+
function Put(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T
177179
count = length(origin_buffer)
178180
Put(origin_buffer, count, target_rank, 0, win)
179181
end
180182
function Put(origin_value::Ref{T}, target_rank::Integer, win::Win) where T
181183
Put(origin_value, 1, target_rank, 0, win)
182184
end
183185

184-
function Fetch_and_op(sourceval::MPIBuffertype{T}, returnval::MPIBuffertype{T}, target_rank::Integer, target_disp::Integer, op::Op, win::Win) where T
186+
function Fetch_and_op(sourceval, returnval, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
185187
# int MPI_Fetch_and_op(const void *origin_addr, void *result_addr,
186188
# MPI_Datatype datatype, int target_rank, MPI_Aint target_disp,
187189
# MPI_Op op, MPI_Win win)
190+
@assert eltype(sourceval) == eltype(returnval)
191+
T = eltype(sourceval)
188192
@mpichk ccall((:MPI_Fetch_and_op, libmpi), Cint,
189-
(Ptr{T}, Ptr{T}, MPI_Datatype, Cint, Cptrdiff_t, MPI_Op, MPI_Win),
193+
(MPIPtr, MPIPtr, MPI_Datatype, Cint, Cptrdiff_t, MPI_Op, MPI_Win),
190194
sourceval, returnval, mpitype(T), target_rank, target_disp, op, win)
191195
end
192196

193-
function Accumulate(origin_buffer::MPIBuffertype{T}, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win) where T
197+
function Accumulate(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
194198
# int MPI_Accumulate(const void *origin_addr, int origin_count,
195199
# MPI_Datatype origin_datatype, int target_rank,
196200
# MPI_Aint target_disp, int target_count,
197201
# MPI_Datatype target_datatype, MPI_Op op, MPI_Win win)
202+
T = eltype(origin_buffer)
198203
@mpichk ccall((:MPI_Accumulate, libmpi), Cint,
199-
(Ptr{T}, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
204+
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
200205
origin_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), op, win)
201206
end
202207

203-
function Get_accumulate(origin_buffer::MPIBuffertype{T}, result_buffer::MPIBuffertype{T}, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win) where T
208+
function Get_accumulate(origin_buffer, result_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
204209
# int MPI_Get_accumulate(const void *origin_addr, int origin_count,
205210
# MPI_Datatype origin_datatype, void *result_addr,
206211
# int result_count, MPI_Datatype result_datatype,
207212
# int target_rank, MPI_Aint target_disp, int target_count,
208213
# MPI_Datatype target_datatype, MPI_Op op, MPI_Win win)
214+
@assert eltype(origin_buffer) == eltype(result_buffer)
215+
T = eltype(origin_buffer)
209216
@mpichk ccall((:MPI_Get_accumulate, libmpi), Cint,
210-
(Ptr{T}, Cint, MPI_Datatype, Ptr{T}, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
217+
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
211218
origin_buffer, count, mpitype(T), result_buffer, count, mpitype(T), target_rank, Cptrdiff_t(target_disp), count, mpitype(T), op, win)
212219
end

0 commit comments

Comments
 (0)