Skip to content

Commit 5be8042

Browse files
committed
tweak MPIPtr, modify and enable CUDA tests, add some docs
- A few small tweaks to the MPIPtr implementation (which we can also use for SubArray contiguous checks) - Modify the tests to support CuArrays where possible - Enable CuArray tests on the buildbot - Add some docs
1 parent 0a30948 commit 5be8042

26 files changed

+504
-550
lines changed

.gitlab-ci.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@ include:
2020
- popd
2121
script:
2222
- export JULIA_MPI_PATH="${HOME}/mpi"
23+
- export JULIA_PROJECT="test/cudaenv"
2324
- ${JULIA_MPI_PATH}/bin/ompi_info
2425
- julia -e 'using InteractiveUtils;
2526
versioninfo()'
26-
- julia --project -e 'using Pkg;
27-
Pkg.build();
28-
Pkg.test(coverage=true)'
27+
- julia --color=yes -e 'using Pkg;
28+
Pkg.develop(PackageSpec(path=pwd()));
29+
Pkg.instantiate();
30+
Pkg.build()'
31+
- julia --color=yes test/runtests.jl
32+
2933
.gputest:
3034
extends: .projecttest
3135
variables:

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ julia = "1"
1616

1717
[extras]
1818
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
19-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2019
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2120
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2221

2322
[targets]
24-
test = ["DoubleFloats", "LinearAlgebra", "Pkg", "Test"]
23+
test = ["DoubleFloats", "Pkg", "Test"]

docs/src/usage.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ Hello world, I am rank 2 of 3
2222
Hello world, I am rank 1 of 3
2323
```
2424

25+
## CUDA-aware MPI support
26+
27+
If your MPI implementation has been compiled with CUDA support, then `CuArray`s (from the
28+
[CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) package) can be passed directly as
29+
send and receive buffers for point-to-point and collective operations (they may also work
30+
with one-sided operations, but these are not often supported).
31+
2532
## Finalizers
2633

2734
In order to ensure MPI routines are called in the correct order at finalization time,

src/MPI.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ function deserialize(x)
2222
Serialization.deserialize(s)
2323
end
2424

25-
primitive type SentinelPtr
26-
Sys.WORD_SIZE
25+
primitive type SentinelPtr Sys.WORD_SIZE
2726
end
2827

28+
primitive type MPIPtr Sys.WORD_SIZE
29+
end
30+
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
31+
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)
32+
33+
2934
function _doc_external(fname)
3035
"""
3136
- `$fname` man page: [OpenMPI](https://www.open-mpi.org/doc/current/man3/$fname.3.php), [MPICH](https://www.mpich.org/static/docs/latest/www3/$fname.html)

src/collective.jl

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ function Bcast!(buffer::AbstractArray{T}, root::Integer, comm::Comm) where T
3535
Bcast!(buffer, length(buffer), root, comm)
3636
end
3737

38-
function Bcast!(buffer::SubArray{T}, root::Integer, comm::Comm) where T
39-
@assert Base.iscontiguous(buffer)
40-
Bcast!(buffer, length(buffer), root, comm)
41-
end
42-
4338
#=
4439
function Bcast{T}(obj::T, root::Integer, comm::Comm)
4540
buf = [T]
@@ -112,25 +107,20 @@ To specify the output buffer, see [`Reduce!`](@ref).
112107
113108
To perform the reduction in place, see [`Reduce_in_place!`](@ref).
114109
"""
115-
function Reduce(sendbuf::MPIBuffertype{T}, count::Integer,
110+
function Reduce(sendbuf, count::Integer,
116111
op, root::Integer, comm::Comm) where T
117112
isroot = Comm_rank(comm) == root
118-
recvbuf = Array{T}(undef, isroot ? count : 0)
113+
recvbuf = similar(sendbuf, isroot ? count : 0)
119114
Reduce!(sendbuf, recvbuf, count, op, root, comm)
120115
end
121116

122117
function Reduce(sendbuf::AbstractArray{T,N}, op,
123-
root::Integer, comm::Comm) where {T,N}
118+
root::Integer, comm::Comm) where {T,N}
124119
isroot = Comm_rank(comm) == root
125120
recvbuf = similar(sendbuf, isroot ? size(sendbuf) : Tuple(zeros(Int, ndims(sendbuf))))
126121
Reduce!(sendbuf, recvbuf, length(sendbuf), op, root, comm)
127122
end
128123

129-
function Reduce(sendbuf::SubArray{T}, op, root::Integer, comm::Comm) where T
130-
@assert Base.iscontiguous(sendbuf)
131-
Reduce(sendbuf, length(sendbuf), op, root, comm)
132-
end
133-
134124
function Reduce(object::T, op
135125
, root::Integer, comm::Comm) where T
136126
isroot = Comm_rank(comm) == root
@@ -238,12 +228,6 @@ output buffer in all processes of the group.
238228
239229
To specify the output buffer or perform the operation in place, see [`Allreduce!`](@ref).
240230
"""
241-
function Allreduce(sendbuf::MPIBuffertype{T}, op, comm::Comm) where T
242-
243-
recvbuf = similar(sendbuf)
244-
Allreduce!(sendbuf, recvbuf, length(recvbuf), op, comm)
245-
end
246-
247231
function Allreduce(sendbuf::AbstractArray{T, N}, op, comm::Comm) where {T, N}
248232
recvbuf = similar(sendbuf, size(sendbuf))
249233
Allreduce!(sendbuf, recvbuf, length(sendbuf), op, comm)
@@ -257,14 +241,6 @@ function Allreduce(obj::T, op, comm::Comm) where T
257241
outref[]
258242
end
259243

260-
# Deprecation warning for lowercase allreduce that was used until v. 0.7.2
261-
# Should be removed at some point in the future
262-
function allreduce(sendbuf::MPIBuffertype{T}, op,
263-
comm::Comm) where T
264-
@warn "`allreduce` is deprecated, use `Allreduce` instead."
265-
Allreduce(sendbuf, op, comm)
266-
end
267-
268244

269245
"""
270246
Scatter!(sendbuf, recvbuf, count, root, comm)
@@ -456,10 +432,10 @@ Each process sends the first `count` elements of the buffer `sendbuf` to the
456432
`root` process. The `root` allocates the output buffer and stores elements in
457433
rank order.
458434
"""
459-
function Gather(sendbuf::MPIBuffertype{T}, count::Integer,
435+
function Gather(sendbuf, count::Integer,
460436
root::Integer, comm::Comm) where T
461437
isroot = Comm_rank(comm) == root
462-
recvbuf = Array{T}(undef, isroot ? Comm_size(comm) * count : 0)
438+
recvbuf = similar(sendbuf, isroot ? Comm_size(comm) * count : 0)
463439
Gather!(sendbuf, recvbuf, count, root, comm)
464440
end
465441

@@ -469,11 +445,6 @@ function Gather(sendbuf::AbstractArray{T}, root::Integer, comm::Comm) where T
469445
Gather!(sendbuf, recvbuf, length(sendbuf), root, comm)
470446
end
471447

472-
function Gather(sendbuf::SubArray{T}, root::Integer, comm::Comm) where T
473-
@assert Base.iscontiguous(sendbuf)
474-
Gather(sendbuf, length(sendbuf), root, comm)
475-
end
476-
477448
function Gather(object::T, root::Integer, comm::Comm) where T
478449
isroot = Comm_rank(comm) == root
479450
sendbuf = T[object]
@@ -549,25 +520,17 @@ function Allgather!(buf, count::Integer,
549520
end
550521

551522
"""
552-
Allgather(sendbuf, count, comm)
523+
Allgather(sendbuf[, count=length(sendbuf)], comm)
553524
554525
Each process sends the first `count` elements of `sendbuf` to the
555526
other processes, who store the results in rank order allocating
556527
the output buffer.
557528
"""
558-
function Allgather(sendbuf::MPIBuffertype{T}, count::Integer,
559-
comm::Comm) where T
560-
recvbuf = Array{T}(undef, Comm_size(comm) * count)
529+
function Allgather(sendbuf, count::Integer, comm::Comm)
530+
recvbuf = similar(sendbuf, Comm_size(comm) * count)
561531
Allgather!(sendbuf, recvbuf, count, comm)
562532
end
563-
564-
function Allgather(sendbuf::AbstractArray{T}, comm::Comm) where T
565-
recvbuf = similar(sendbuf, Comm_size(comm) * length(sendbuf))
566-
Allgather!(sendbuf, recvbuf, length(sendbuf), comm)
567-
end
568-
569-
function Allgather(sendbuf::SubArray{T}, comm::Comm) where T
570-
@assert Base.iscontiguous(sendbuf)
533+
function Allgather(sendbuf::AbstractArray, comm::Comm)
571534
Allgather(sendbuf, length(sendbuf), comm)
572535
end
573536

@@ -613,7 +576,7 @@ in rank order.
613576
function Gatherv(sendbuf, counts::Vector{Cint},
614577
root::Integer, comm::Comm)
615578
isroot = Comm_rank(comm) == root
616-
recvbuf = Array{T}(undef, isroot ? sum(counts) : 0)
579+
recvbuf = similar(sendbuf, isroot ? sum(counts) : 0)
617580
Gatherv!(sendbuf, recvbuf, counts, root, comm)
618581
end
619582

@@ -797,24 +760,61 @@ function Scan(sendbuf, count::Integer,
797760
recvbuf
798761
end
799762

800-
function Scan(object::T, op::Union{Op,MPI_Op}, comm::Comm) where T
763+
764+
function Scan!(sendbuf, recvbuf, count::Integer,
765+
op::Union{Op,MPI_Op}, comm::Comm)
766+
T = eltype(sendbuf)
767+
# int MPI_Scan(const void* sendbuf, void* recvbuf, int count,
768+
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
769+
@mpichk ccall((:MPI_Scan, libmpi), Cint,
770+
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm),
771+
sendbuf, recvbuf, count, mpitype(T), op, comm)
772+
recvbuf
773+
end
774+
function Scan!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm)
775+
Scan!(sendbuf, recvbuf, count, Op(opfunc, eltype(sendbuf)), comm)
776+
end
777+
function Scan!(sendbuf::AbstractArray, recvbuf, op, comm::Comm)
778+
Scan!(sendbuf, recvbuf, length(sendbuf), op, comm)
779+
end
780+
781+
function Scan(sendbuf, count::Integer, op, comm::Comm)
782+
Scan!(sendbuf, similar(sendbuf, count), count, op, comm)
783+
end
784+
785+
function Scan(sendbuf::AbstractArray, op, comm::Comm)
786+
Scan(sendbuf, length(sendbuf), op, comm)
787+
end
788+
function Scan(object::T, op, comm::Comm) where T
801789
sendbuf = T[object]
802790
Scan(sendbuf,1,op,comm)
803791
end
804792

805-
function Exscan(sendbuf, count::Integer,
793+
function Exscan!(sendbuf, recvbuf, count::Integer,
806794
op::Union{Op,MPI_Op}, comm::Comm)
807795
T = eltype(sendbuf)
808-
recvbuf = similar(sendbuf, count)
809796
# int MPI_Exscan(const void* sendbuf, void* recvbuf, int count,
810797
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
811798
@mpichk ccall((:MPI_Exscan, libmpi), Cint,
812799
(MPIPtr, MPIPtr, Cint, MPI_Datatype, MPI_Op, MPI_Comm),
813800
sendbuf, recvbuf, count, mpitype(T), op, comm)
814801
recvbuf
815802
end
803+
function Exscan!(sendbuf, recvbuf, count::Integer, opfunc, comm::Comm)
804+
Exscan!(sendbuf, recvbuf, count, Op(opfunc, eltype(sendbuf)), comm)
805+
end
806+
function Exscan!(sendbuf::AbstractArray, recvbuf, op, comm::Comm)
807+
Exscan!(sendbuf, recvbuf, length(sendbuf), op, comm)
808+
end
809+
810+
function Exscan(sendbuf, count::Integer, op, comm::Comm)
811+
Exscan!(sendbuf, similar(sendbuf, count), count, op, comm)
812+
end
816813

817-
function Exscan(object::T, op::Union{Op,MPI_Op}, comm::Comm) where T
814+
function Exscan(sendbuf::AbstractArray, op, comm::Comm)
815+
Exscan(sendbuf, length(sendbuf), op, comm)
816+
end
817+
function Exscan(object::T, op, comm::Comm) where T
818818
sendbuf = T[object]
819819
Exscan(sendbuf,1,op,comm)
820820
end

src/datatypes.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,16 @@ MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}}
1515

1616
MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr}
1717

18-
if sizeof(Ptr{Cvoid}) == 8
19-
primitive type MPIPtr 64 end
20-
else
21-
primitive type MPIPtr 32 end
18+
Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x)
19+
function Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T
20+
@assert Base.iscontiguous(x)
21+
Base.cconvert(Ptr{T}, x)
2222
end
23-
24-
MPIPtr(x::Cint) where T = reinterpret(MPIPtr, x)
25-
26-
Base.cconvert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T = Base.cconvert(Ptr{T}, x)
27-
2823
function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T
2924
ptr = Base.unsafe_convert(Ptr{T}, x)
3025
reinterpret(MPIPtr, ptr)
3126
end
3227

33-
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
34-
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)
35-
3628
fieldoffsets(::Type{T}) where {T} = Int[fieldoffset(T, i) for i in 1:length(fieldnames(T))]
3729

3830
# Define a function mpitype(T) that returns the MPI datatype code for

src/pointtopoint.jl

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,6 @@ function Send(buf::AbstractArray{T}, dest::Integer, tag::Integer, comm::Comm) wh
126126
Send(buf, length(buf), dest, tag, comm)
127127
end
128128

129-
"""
130-
Send(buf::SubArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
131-
132-
Complete a blocking send of `SubArray` `buf` to MPI rank `dest` of communicator
133-
`comm` using with the message tag `tag`. Note that the `buf` must be contiguous.
134-
"""
135-
function Send(buf::SubArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
136-
@assert Base.iscontiguous(buf)
137-
Send(buf, length(buf), dest, tag, comm)
138-
end
139-
140129
"""
141130
Send(obj::T, dest::Integer, tag::Integer, comm::Comm) where T
142131
@@ -206,20 +195,6 @@ function Isend(buf::AbstractArray{T}, dest::Integer, tag::Integer, comm::Comm) w
206195
Isend(buf, length(buf), mpitype(T), dest, tag, comm)
207196
end
208197

209-
"""
210-
Isend(buf::SubArray{T}, dest::Integer, tag::Integer, comm::Comm) where T
211-
212-
Starts a nonblocking send of `SubArray` `buf` to MPI rank `dest` of communicator
213-
`comm` using with the message tag `tag`. Note that the `buf` must be contiguous.
214-
215-
Returns the commication `Request` for the nonblocking send.
216-
"""
217-
function Isend(buf::SubArray{T}, dest::Integer, tag::Integer,
218-
comm::Comm) where T
219-
@assert Base.iscontiguous(buf)
220-
Isend(buf, length(buf), mpitype(T), dest, tag, comm)
221-
end
222-
223198
"""
224199
Isend(obj::T, dest::Integer, tag::Integer, comm::Comm) where T
225200
@@ -293,20 +268,6 @@ function Recv!(buf::AbstractArray{T}, src::Integer, tag::Integer, comm::Comm) wh
293268
Recv!(buf, length(buf), src, tag, comm)
294269
end
295270

296-
"""
297-
Recv!(buf::SubArray{T}, src::Integer, tag::Integer, comm::Comm) where T
298-
299-
Completes a blocking receive into `SubArray` `buf` from MPI rank `src` of
300-
communicator `comm` using with the message tag `tag`. Note that `buf` must be
301-
contiguous.
302-
303-
Returns the `Status` of the receive
304-
"""
305-
function Recv!(buf::SubArray{T}, src::Integer, tag::Integer, comm::Comm) where T
306-
@assert Base.iscontiguous(buf)
307-
Recv!(buf, length(buf), src, tag, comm)
308-
end
309-
310271
function Recv(::Type{T}, src::Integer, tag::Integer, comm::Comm) where T
311272
buf = Ref{T}()
312273
stat = Recv!(buf, 1, src, tag, comm)
@@ -369,21 +330,6 @@ function Irecv!(buf::AbstractArray{T}, src::Integer, tag::Integer,
369330
Irecv!(buf, length(buf), mpitype(T), src, tag, comm)
370331
end
371332

372-
"""
373-
Irecv!(buf::SubArray{T}, src::Integer, tag::Integer, comm::Comm) where T
374-
375-
Starts a nonblocking receive into `SubArray` `buf` from MPI rank `src` of
376-
communicator `comm` using with the message tag `tag`. Note that `buf` must be
377-
contiguous.
378-
379-
Returns the communication `Request` for the nonblocking receive.
380-
"""
381-
function Irecv!(buf::SubArray{T}, src::Integer, tag::Integer,
382-
comm::Comm) where T
383-
@assert Base.iscontiguous(buf)
384-
Irecv!(buf, length(buf), mpitype(T), src, tag, comm)
385-
end
386-
387333
function irecv(src::Integer, tag::Integer, comm::Comm)
388334
(flag, stat) = Iprobe(src, tag, comm)
389335
if !flag

test/cudaenv/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[deps]
2+
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
3+
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
4+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
5+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)