Skip to content

Commit a584b85

Browse files
authored
Add Scatterv example based on #469 (#470)
1 parent 96070b7 commit a584b85

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

docs/examples/06-scatterv.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# This example shows how to use MPI.Scatterv! and MPI.Gatherv!
2+
# roughly based on the example from
3+
# https://stackoverflow.com/a/36082684/392585
4+
5+
using MPI
6+
7+
"""
8+
split_count(N::Integer, n::Integer)
9+
10+
Return a vector of `n` integers which are approximately equally sized and sum to `N`.
11+
"""
12+
function split_count(N::Integer, n::Integer)
13+
q,r = divrem(N, n)
14+
return [i <= r ? q+1 : q for i = 1:n]
15+
end
16+
17+
18+
MPI.Init()
19+
20+
comm = MPI.COMM_WORLD
21+
rank = MPI.Comm_rank(comm)
22+
comm_size = MPI.Comm_size(comm)
23+
24+
root = 0
25+
26+
if rank == root
27+
M, N = 4, 7
28+
29+
test = Float64[i for i = 1:M, j = 1:N]
30+
output = similar(test)
31+
32+
# Julia arrays are stored in column-major order, so we need to split along the last dimension
33+
# dimension
34+
M_counts = [M for i = 1:comm_size]
35+
N_counts = split_count(N, comm_size)
36+
37+
# store sizes in 2 * comm_size Array
38+
sizes = vcat(M_counts', N_counts')
39+
40+
# store number of values to send to each rank in comm_size length Vector
41+
counts = vec(prod(sizes, dims=1))
42+
43+
test_vbuf = VBuffer(test, counts) # VBuffer for scatter
44+
output_vbuf = VBuffer(output, counts) # VBuffer for gather
45+
else
46+
# these variables can be set to `nothing` on non-root processes
47+
sizes = nothing
48+
output_vbuf = test_vbuf = VBuffer(nothing)
49+
end
50+
51+
if rank == root
52+
println("Original matrix")
53+
println("================")
54+
@show test sizes counts
55+
println()
56+
println("Each rank")
57+
println("================")
58+
end
59+
MPI.Barrier(comm)
60+
61+
local_M, local_N = MPI.Scatter!(sizes, zeros(Int, 2), root, comm)
62+
local_test = MPI.Scatterv!(test_vbuf, zeros(Float64, local_M, local_N), root, comm)
63+
64+
for i = 0:comm_size-1
65+
if rank == i
66+
@show rank local_test
67+
end
68+
MPI.Barrier(comm)
69+
end
70+
71+
MPI.Gatherv!(local_test, output_vbuf, root, comm)
72+
73+
if rank == root
74+
println()
75+
println("Final matrix")
76+
println("================")
77+
@show output
78+
end

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ EXAMPLES = [
77
"Broadcast" => "examples/02-broadcast.md",
88
"Reduce" => "examples/03-reduce.md",
99
"Send/receive" => "examples/04-sendrecv.md",
10-
"Job Scheduling" => "examples/05-job_schedule.md"
10+
"Job Scheduling" => "examples/05-job_schedule.md",
11+
"Scatterv and Gatherv" => "examples/06-scatterv.md",
1112
]
1213

1314
examples_md_dir = joinpath(@__DIR__,"src/examples")

0 commit comments

Comments
 (0)