Skip to content

Commit 3411efc

Browse files
authored
Add buffer objects for collective operations (#335)
* Add ChunkBuffers for collective operations * fix deprecations, get tests passing * remove old deprecations * update docs * update tests for deprecations
1 parent 502717a commit 3411efc

20 files changed

+783
-513
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ makedocs(
5151
"library.md",
5252
"environment.md",
5353
"comm.md",
54+
"buffers.md",
5455
"pointtopoint.md",
5556
"collective.md",
5657
"onesided.md",

docs/src/advanced.md

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,6 @@
66
MPI.free
77
```
88

9-
## Buffers
10-
11-
```@docs
12-
MPI.Buffer
13-
MPI.Buffer_send
14-
MPI.MPIPtr
15-
```
16-
179
## Datatype objects
1810

1911
```@docs

docs/src/buffers.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Buffers
2+
3+
Buffers are used for sending and receiving data. MPI.jl provides the following buffer types:
4+
5+
```@docs
6+
MPI.IN_PLACE
7+
MPI.Buffer
8+
MPI.Buffer_send
9+
MPI.UBuffer
10+
MPI.VBuffer
11+
MPI.RBuffer
12+
MPI.MPIPtr
13+
```

docs/src/collective.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ MPI.Barrier
1010

1111
```@docs
1212
MPI.Bcast!
13+
MPI.bcast
1314
```
1415

1516
## Gather/Scatter
@@ -31,9 +32,7 @@ MPI.Gatherv
3132

3233
```@docs
3334
MPI.Scatter!
34-
MPI.Scatter
3535
MPI.Scatterv!
36-
MPI.Scatterv
3736
```
3837

3938
### All-to-all

src/MPI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Libdl, Serialization
44
using Requires
55
using DocStringExtensions
66

7-
export mpiexec
7+
export mpiexec, UBuffer, VBuffer
88

99
function serialize(x)
1010
s = IOBuffer()

src/buffers.jl

Lines changed: 227 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,42 @@ Additionally, certain sentinel values can be used, e.g. `MPI_IN_PLACE` or `MPI_B
5757
"""
5858
MPIPtr
5959

60+
# MPI_IN_PLACE
61+
62+
struct InPlace
63+
end
64+
Base.cconvert(::Type{MPIPtr}, ::InPlace) = MPI_IN_PLACE
65+
66+
67+
"""
68+
MPI.IN_PLACE
69+
70+
A sentinel value that can be passed as a buffer argument for certain collective operations
71+
to use the same buffer for send and receive operations.
72+
73+
- [`Scatter!`](@ref) and [`Scatterv!`](@ref): can be used as the `recvbuf` argument on the
74+
root process.
75+
76+
- [`Gather!`](@ref) and [`Gatherv!`](@ref): can be used as the `sendbuf` argument on the
77+
root process.
78+
79+
- [`Allgather!`](@ref), [`Allgatherv!`](@ref), [`Alltoall!`](@ref) and
80+
[`Alltoallv!`](@ref): can be used as the `sendbuf` argument on all processes.
81+
82+
- [`Reduce!`](@ref) (root only), [`Allreduce!`](@ref), [`Scan!`](@ref) and
83+
[`Exscan!`](@ref): can be used as `sendbuf` argument.
84+
85+
"""
86+
const IN_PLACE = InPlace()
87+
88+
# TODO: MPI_BOTTOM
89+
6090

6191
"""
6292
MPI.Buffer
6393
64-
An MPI buffer for communication operations.
94+
An MPI buffer for communication with a single rank. It is used for point-to-point
95+
communication and some collective operations.
6596
6697
# Fields
6798
$(DocStringExtensions.FIELDS)
@@ -83,6 +114,10 @@ and `datatype`. Methods are provided for
83114
- `SubArray`s of an `Array` or `CUDA.CuArray` where the layout is contiguous, sequential or
84115
blocked.
85116
117+
# See also
118+
119+
- [`Buffer_send`](@ref)
120+
86121
"""
87122
struct Buffer{A}
88123
"""a Julia object referencing a region of memory to be used for communication. It is
@@ -125,6 +160,9 @@ function Buffer(sub::SubArray{T,N,P,I,false}) where {T,N,P,I<:Tuple{Vararg{Union
125160
Buffer(parent(sub), Cint(1), datatype)
126161
end
127162

163+
Buffer(::InPlace) = Buffer(IN_PLACE, 0, DATATYPE_NULL)
164+
Buffer(::Nothing) = Buffer(nothing, 0, DATATYPE_NULL)
165+
128166
"""
129167
Buffer_send(data)
130168
@@ -133,7 +171,194 @@ Construct a [`Buffer`](@ref) object for a send operation from `data`, allowing c
133171
"""
134172
Buffer_send(data) = isbits(data) ? Buffer(Ref(data)) : Buffer(data)
135173
Buffer_send(str::String) = Buffer(str, sizeof(str), MPI.CHAR)
174+
Buffer_send(::InPlace) = Buffer(InPlace())
175+
Buffer_send(::Nothing) = Buffer(nothing)
176+
177+
178+
179+
180+
181+
182+
"""
183+
MPI.UBuffer
184+
185+
An MPI buffer for chunked collective communication, where all chunks are of uniform size.
186+
187+
# Fields
188+
$(DocStringExtensions.FIELDS)
189+
190+
# Usage
191+
192+
UBuffer(data, count::Integer, nchunks::Union{Nothing, Integer}, datatype::Datatype)
193+
194+
Generic constructor.
195+
196+
UBuffer(data, count::Integer)
197+
198+
Construct a `UBuffer` backed by `data`, where `count` is the number of elements in each chunk.
199+
200+
# See also
201+
202+
- [`VBuffer`](@ref): similar, but supports chunks of non-uniform sizes.
203+
"""
204+
struct UBuffer{A}
205+
"""A Julia object referencing a region of memory to be used for communication. It is
206+
required that the object can be `cconvert`ed to an [`MPIPtr`](@ref)."""
207+
data::A
208+
209+
"""The number of elements of `datatype` in each chunk."""
210+
count::Cint
211+
212+
"""The maximum number of chunks stored in the buffer. This is used only for
213+
validation, and can be set to `nothing` to disable checks."""
214+
nchunks::Union{Nothing,Cint}
215+
216+
"""The [`MPI.Datatype`](@ref) stored in the buffer."""
217+
datatype::Datatype
218+
end
219+
UBuffer(data, count::Integer, nchunks::Union{Integer, Nothing}, datatype::Datatype) =
220+
UBuffer(data, Cint(count), nchunks isa Integer ? Cint(nchunks) : nothing, datatype)
221+
222+
function UBuffer(arr::AbstractArray, count::Integer)
223+
@assert stride(arr, 1) == 1
224+
UBuffer(arr, count, div(length(arr), count), Datatype(eltype(arr)))
225+
end
226+
Base.similar(buf::UBuffer) =
227+
UBuffer(similar(buf.data), buf.count, buf.nchunks, buf.datatype)
228+
229+
UBuffer(::Nothing) = UBuffer(nothing, 0, nothing, DATATYPE_NULL)
230+
UBuffer(::InPlace) = UBuffer(IN_PLACE, 0, nothing, DATATYPE_NULL)
231+
232+
233+
234+
"""
235+
MPI.VBuffer
236+
237+
An MPI buffer for chunked collective communication, where chunks can be of different sizes and at different offsets.
238+
239+
240+
# Fields
241+
$(DocStringExtensions.FIELDS)
242+
243+
# Usage
244+
245+
VBuffer(data, counts[, displs[, datatype]])
246+
247+
Construct a `VBuffer` backed by `data`, where `counts[j]` is the number of elements in the
248+
`j`th chunk, and `displs[j]` is the 0-based displacement. In other words, the `j`th chunk
249+
occurs in indices `displs[j]+1:displs[j]+counts[j]`.
250+
251+
The default value for `displs[j] = sum(counts[1:j-1])`.
252+
253+
# See also
254+
255+
- [`UBuffer`](@ref) when chunks are all of the same size.
256+
"""
257+
struct VBuffer{A}
258+
"""A Julia object referencing a region of memory to be used for communication. It is
259+
required that the object can be `cconvert`ed to an [`MPIPtr`](@ref)."""
260+
data::A
261+
262+
"""An array containing the length of each chunk."""
263+
counts::Vector{Cint}
264+
265+
"""An array containing the (0-based) displacements of each chunk."""
266+
displs::Vector{Cint}
267+
268+
"""The [`MPI.Datatype`](@ref) stored in the buffer."""
269+
datatype::Datatype
270+
end
271+
VBuffer(data, counts, displs, datatype::Datatype) =
272+
VBuffer(data, convert(Vector{Cint}, counts),
273+
convert(Vector{Cint}, displs), datatype)
274+
VBuffer(data, counts, displs) =
275+
VBuffer(data, counts, displs, Datatype(eltype(data)))
276+
277+
function VBuffer(arr::AbstractArray, counts)
278+
@assert stride(arr,1) == 1
279+
counts = convert(Vector{Cint}, counts)
280+
displs = similar(counts)
281+
d = zero(Cint)
282+
for i in eachindex(displs)
283+
displs[i] = d
284+
d += counts[i]
285+
end
286+
@assert length(arr) >= d
287+
VBuffer(arr, counts, displs, Datatype(eltype(arr)))
288+
end
289+
290+
VBuffer(::Nothing) = VBuffer(nothing, Cint[], Cint[], DATATYPE_NULL)
291+
VBuffer(::InPlace) = VBuffer(IN_PLACE, Cint[], Cint[], DATATYPE_NULL)
292+
293+
294+
"""
295+
MPI.RBuffer
296+
297+
An MPI buffer for reduction operations ([`MPI.Reduce!`](@ref), [`MPI.Allreduce!`](@ref), [`MPI.Scan!`](@ref), [`MPI.Exscan!`](@ref)).
298+
299+
# Fields
300+
$(DocStringExtensions.FIELDS)
136301
302+
# Usage
303+
304+
RBuffer(senddata, recvdata[, count, datatype])
305+
306+
Generic constructor.
307+
308+
RBuffer(senddata, recvdata)
309+
310+
Construct a `Buffer` backed by `senddata` and `recvdata`, automatically determining the
311+
appropriate `count` and `datatype`.
312+
313+
- `senddata` can be [`MPI.IN_PLACE`](@ref)
314+
- `recvdata` can be `nothing` on a non-root node with [`MPI.Reduce!`](@ref)
315+
"""
316+
struct RBuffer{S,R}
317+
"""A Julia object referencing a region of memory to be used for the send buffer. It is
318+
required that the object can be `cconvert`ed to an [`MPIPtr`](@ref)."""
319+
senddata::S
320+
321+
"""A Julia object referencing a region of memory to be used for the receive buffer. It is
322+
required that the object can be `cconvert`ed to an [`MPIPtr`](@ref)."""
323+
recvdata::R
324+
325+
"""the number of elements of `datatype` in the buffer. Note that this may not
326+
correspond to the number of elements in the array if derived types are used."""
327+
count::Cint
328+
329+
"""the [`MPI.Datatype`](@ref) stored in the buffer."""
330+
datatype::Datatype
331+
end
332+
333+
RBuffer(senddata, recvdata, count::Integer, datatype::Datatype) =
334+
RBuffer(senddata, recvdata, Cint(count), datatype)
335+
336+
function RBuffer(senddata::AbstractArray{T}, recvdata::AbstractArray{T}) where {T}
337+
@assert (count = length(senddata)) == length(recvdata)
338+
@assert stride(senddata,1) == stride(recvdata,1) == 1
339+
RBuffer(senddata, recvdata, count, Datatype(T))
340+
end
341+
function RBuffer(::InPlace, recvdata::AbstractArray{T}) where {T}
342+
count = length(recvdata)
343+
@assert stride(recvdata,1) == 1
344+
RBuffer(IN_PLACE, recvdata, count, Datatype(T))
345+
end
346+
function RBuffer(senddata::AbstractArray{T}, recvdata::Nothing) where {T}
347+
count = length(senddata)
348+
@assert stride(senddata,1) == 1
349+
RBuffer(senddata, nothing, count, Datatype(T))
350+
end
351+
352+
function RBuffer(senddata::Ref{T}, recvdata::Ref{T}) where {T}
353+
RBuffer(senddata, recvdata, 1, Datatype(T))
354+
end
355+
function RBuffer(senddata::InPlace, recvdata::Ref{T}) where {T}
356+
RBuffer(IN_PLACE, recvdata, 1, Datatype(T))
357+
end
358+
function RBuffer(senddata::Ref{T}, recvdata::Nothing) where {T}
359+
RBuffer(senddata, nothing, 1, Datatype(T))
360+
end
137361

138362

139-
const BUFFER_NULL = Buffer(C_NULL, 0, DATATYPE_NULL)
363+
Base.eltype(rbuf::RBuffer) = eltype(rbuf.senddata)
364+
Base.eltype(rbuf::RBuffer{InPlace}) = eltype(rbuf.recvdata)

0 commit comments

Comments
 (0)