Skip to content

Commit 615ca46

Browse files
kose-ysimonbyrne
authored andcommitted
size check only for AbstractArrays on collective.jl (#285)
1 parent 1bd063b commit 615ca46

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/collective.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function Reduce!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
8181
count::Integer, op::Union{Op,MPI_Op}, root::Integer,
8282
comm::Comm) where T
8383
isroot = Comm_rank(comm) == root
84-
isroot && @assert length(recvbuf) >= count
84+
isroot && typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= count
8585
# int MPI_Reduce(const void* sendbuf, void* recvbuf, int count,
8686
# MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
8787
@mpichk ccall((:MPI_Reduce, libmpi), Cint,
@@ -161,7 +161,7 @@ function Reduce_in_place!(buf::MPIBuffertype{T}, count::Integer,
161161
op::Union{Op,MPI_Op}, root::Integer,
162162
comm::Comm) where T
163163
if Comm_rank(comm) == root
164-
@assert length(buf) >= count
164+
typeof(buf) <: AbstractArray && @assert length(buf) >= count
165165
@mpichk ccall((:MPI_Reduce, libmpi), Cint,
166166
(Ptr{T}, Ptr{T}, Cint, MPI_Datatype, MPI_Op, Cint, MPI_Comm),
167167
MPI_IN_PLACE, buf, count, mpitype(T), op, root, comm)
@@ -195,7 +195,7 @@ To handle allocation of the output buffer, see [`Allreduce`](@ref).
195195
"""
196196
function Allreduce!(sendbuf::MPIBuffertypeOrConst{T}, recvbuf::MPIBuffertype{T},
197197
count::Integer, op::Union{Op,MPI_Op}, comm::Comm) where T
198-
@assert length(recvbuf) >= count
198+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= count
199199
# int MPI_Allreduce(const void* sendbuf, void* recvbuf, int count,
200200
# MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
201201
@mpichk ccall((:MPI_Allreduce, libmpi), Cint,
@@ -279,9 +279,9 @@ To handle allocation of the output buffer, see [`Scatter`](@ref).
279279
function Scatter!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
280280
count::Integer, root::Integer,
281281
comm::Comm) where T
282-
@assert length(recvbuf) >= count
282+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= count
283283
isroot = Comm_rank(comm) == root
284-
isroot && @assert length(sendbuf) >= count*Comm_size(comm)
284+
isroot && typeof(sendbuf) <: AbstractArray && @assert length(sendbuf) >= count*Comm_size(comm)
285285

286286
# int MPI_Scatter(const void* sendbuf, int sendcount, MPI_Datatype sendtype,
287287
# void* recvbuf, int recvcount, MPI_Datatype recvtype, int root,
@@ -352,7 +352,7 @@ function Scatterv!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
352352
counts::Vector{Cint}, root::Integer, comm::Comm) where T
353353
recvcnt = counts[Comm_rank(comm) + 1]
354354
disps = accumulate(+, counts) - counts
355-
@assert length(recvbuf) >= recvcnt
355+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= recvcnt
356356
# int MPI_Scatterv(const void* sendbuf, const int sendcounts[],
357357
# const int displs[], MPI_Datatype sendtype, void* recvbuf,
358358
# int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
@@ -396,7 +396,7 @@ function Scatterv_in_place!(buf::MPIBuffertype{T}, counts::Vector{Cint},
396396
root::Integer, comm::Comm) where T
397397
recvcnt = counts[Comm_rank(comm) + 1]
398398
disps = accumulate(+, counts) - counts
399-
@assert length(buf) >= recvcnt
399+
typeof(buf) <: AbstractArray && @assert length(buf) >= recvcnt
400400

401401
if Comm_rank(comm) == root
402402
@mpichk ccall((:MPI_Scatterv, libmpi), Cint,
@@ -424,9 +424,9 @@ To perform the reduction in place refer to [`Gather_in_place!`](@ref).
424424
"""
425425
function Gather!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
426426
count::Integer, root::Integer, comm::Comm) where T
427-
@assert length(sendbuf) >= count
427+
typeof(sendbuf) <: AbstractArray && @assert length(sendbuf) >= count
428428
isroot = Comm_rank(comm) == root
429-
isroot && @assert length(recvbuf) >= count*Comm_size(comm)
429+
isroot && typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= count*Comm_size(comm)
430430

431431
# int MPI_Gather(const void* sendbuf, int sendcount, MPI_Datatype sendtype,
432432
# void* recvbuf, int recvcount, MPI_Datatype recvtype, int root,
@@ -486,7 +486,7 @@ end
486486
function Gather_in_place!(buf::MPIBuffertype{T}, count::Integer, root::Integer,
487487
comm::Comm) where T
488488
if Comm_rank(comm) == root
489-
@assert length(buf) >= count*Comm_size(comm)
489+
typeof(buf) <: AbstractArray && @assert length(buf) >= count*Comm_size(comm)
490490
@mpichk ccall((:MPI_Gather, libmpi), Cint,
491491
(Ptr{T}, Cint, MPI_Datatype, Ptr{T}, Cint, MPI_Datatype, Cint, MPI_Comm),
492492
MPI_IN_PLACE, count, mpitype(T), buf, count, mpitype(T), root, comm)
@@ -511,7 +511,7 @@ contribution.
511511
"""
512512
function Allgather!(sendbuf::MPIBuffertypeOrConst{T}, recvbuf::MPIBuffertype{T},
513513
count::Integer, comm::Comm) where T
514-
@assert length(recvbuf) >= Comm_size(comm)*count
514+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= Comm_size(comm)*count
515515
# int MPI_Allgather(const void* sendbuf, int sendcount,
516516
# MPI_Datatype sendtype, void* recvbuf, int recvcount,
517517
# MPI_Datatype recvtype, MPI_Comm comm)
@@ -573,7 +573,7 @@ function Gatherv!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
573573
isroot = Comm_rank(comm) == root
574574
displs = accumulate(+, counts) - counts
575575
sendcnt = counts[Comm_rank(comm) + 1]
576-
isroot && @assert length(recvbuf) >= sum(counts)
576+
isroot && typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= sum(counts)
577577
# int MPI_Gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype,
578578
# void* recvbuf, const int recvcounts[], const int displs[],
579579
# MPI_Datatype recvtype, int root, MPI_Comm comm)
@@ -620,7 +620,7 @@ function Gatherv_in_place!(buf::MPIBuffertype{T}, counts::Vector{Cint},
620620
sendcnt = counts[Comm_rank(comm) + 1]
621621

622622
if isroot
623-
@assert length(buf) >= sum(counts)
623+
typeof(buf) <: AbstractArray && @assert length(buf) >= sum(counts)
624624
@mpichk ccall((:MPI_Gatherv, libmpi), Cint,
625625
(Ptr{T}, Cint, MPI_Datatype, Ptr{T}, Ptr{Cint}, Ptr{Cint}, MPI_Datatype, Cint, MPI_Comm),
626626
MPI_IN_PLACE, sendcnt, mpitype(T), buf, counts, displs, mpitype(T), root, comm)
@@ -645,7 +645,7 @@ the interval of `recvbuf` where it would store it's own data.
645645
"""
646646
function Allgatherv!(sendbuf::MPIBuffertypeOrConst{T}, recvbuf::MPIBuffertype{T},
647647
counts::Vector{Cint}, comm::Comm) where T
648-
@assert length(recvbuf) >= sum(counts)
648+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) >= sum(counts)
649649
displs = accumulate(+, counts) - counts
650650
sendcnt = counts[Comm_rank(comm) + 1]
651651
# int MPI_Allgatherv(const void* sendbuf, int sendcount,
@@ -731,7 +731,7 @@ end
731731
function Alltoallv!(sendbuf::MPIBuffertype{T}, recvbuf::MPIBuffertype{T},
732732
scounts::Vector{Cint}, rcounts::Vector{Cint},
733733
comm::Comm) where T
734-
@assert length(recvbuf) == sum(rcounts)
734+
typeof(recvbuf) <: AbstractArray && @assert length(recvbuf) == sum(rcounts)
735735

736736
sdispls = accumulate(+, scounts) - scounts
737737
rdispls = accumulate(+, rcounts) - rcounts

0 commit comments

Comments
 (0)