Skip to content

Commit 45c2f82

Browse files
authored
Merge pull request #8 from gjbex/development
Add MPI test application
2 parents 32b1215 + cff6cff commit 45c2f82

File tree

2 files changed

+225
-0
lines changed

2 files changed

+225
-0
lines changed

source-code/mpi4py/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ C/C++ or Fortran implementations.
2121
1. `mpi_count.py`: count amino acids in a long sequence, distributing
2222
the work over processes.
2323
1. `large_dna.txt`: example data file to use with `mpi_count.py`.
24+
1. `mpifitness.py`: application to tmie various MPI communications.

source-code/mpi4py/mpifitness.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#!/usr/bin/env python
2+
3+
from argparse import ArgumentParser
4+
from mpi4py import MPI
5+
import sys
6+
import time
7+
8+
9+
def make_msg(value, nr_bytes):
10+
return f'{value:08d}'*(nr_bytes//8)
11+
12+
13+
def make_float_msg(value, nr_bytes):
14+
return float(value)*(nr_bytes//8)
15+
16+
17+
def acknowledge(comm):
18+
comm.barrier()
19+
rank = comm.Get_rank()
20+
size = comm.Get_size()
21+
print(f'process {rank} out of {size}')
22+
comm.barrier()
23+
24+
25+
def pingpong(comm, nr_iters, msg_size):
26+
comm.barrier()
27+
rank = comm.Get_rank()
28+
size = comm.Get_size()
29+
for _ in range(nr_iters):
30+
for source in range(size):
31+
for dest in range(size):
32+
if (source != dest):
33+
if rank == source:
34+
start_time = time.time()
35+
comm.send(make_msg(source, msg_size), dest=dest)
36+
msg = comm.recv(source=dest)
37+
end_time = time.time()
38+
if msg != make_msg(dest, msg_size):
39+
print(f'{rank} received {msg}, expected {dest}',
40+
file=sys.stderr)
41+
comm.Abort(1)
42+
print(f'{rank} -> {dest} pig-pong: {end_time - start_time}')
43+
if rank == dest:
44+
start_time = time.time()
45+
msg = comm.recv(source=source)
46+
comm.send(make_msg(dest, msg_size), dest=source)
47+
end_time = time.time()
48+
if msg != make_msg(source, msg_size):
49+
print(f'{rank} received {msg}, expected {source}',
50+
file=sys.stderr)
51+
comm.Abort(1)
52+
print(f'{rank} -> {source} ping-pong: {end_time - start_time}')
53+
comm.barrier()
54+
55+
56+
def broadcast(comm, nr_iters, msg_size):
57+
comm.barrier()
58+
rank = comm.Get_rank()
59+
size = comm.Get_size()
60+
for _ in range(nr_iters):
61+
for root in range(size):
62+
msg = None
63+
if (rank == root):
64+
msg = make_msg(root, msg_size)
65+
start_time = time.time()
66+
msg = comm.bcast(msg, root=root)
67+
end_time = time.time()
68+
print(f'{root} -> {rank} bcast: {end_time - start_time}')
69+
if msg != make_msg(root, msg_size):
70+
print(f'{rank} received unexpected bcast message')
71+
comm.Abort(2)
72+
comm.barrier()
73+
74+
75+
def scatter(comm, nr_iters, msg_size):
76+
comm.barrier()
77+
rank = comm.Get_rank()
78+
size = comm.Get_size()
79+
for _ in range(nr_iters):
80+
for root in range(size):
81+
msg = None
82+
if (rank == root):
83+
msg = [make_msg(dest, msg_size) for dest in range(size)]
84+
start_time = time.time()
85+
msg = comm.scatter(msg, root=root)
86+
end_time = time.time()
87+
print(f'{root} -> {rank} scatter: {end_time - start_time}')
88+
if msg != make_msg(rank, msg_size):
89+
print(f'{rank} received unexpected scatter message')
90+
comm.Abort(2)
91+
comm.barrier()
92+
93+
94+
def gather(comm, nr_iters, msg_size):
95+
comm.barrier()
96+
rank = comm.Get_rank()
97+
size = comm.Get_size()
98+
for _ in range(nr_iters):
99+
for root in range(size):
100+
msg = make_msg(rank, msg_size)
101+
start_time = time.time()
102+
msg = comm.gather(msg, root=root)
103+
end_time = time.time()
104+
print(f'{root} -> {rank} gather: {end_time - start_time}')
105+
if (rank == root):
106+
if len(msg) != size:
107+
print(f'{rank} received unexpected gather message')
108+
comm.Abort(2)
109+
for i, msg in enumerate(msg):
110+
if msg != make_msg(i, msg_size):
111+
print(f'{rank} received unexpected gather message')
112+
comm.Abort(2)
113+
comm.barrier()
114+
115+
116+
def alltoall(comm, nr_iters, msg_size):
117+
comm.barrier()
118+
rank = comm.Get_rank()
119+
size = comm.Get_size()
120+
for _ in range(nr_iters):
121+
msg = [make_msg(rank, msg_size) for _ in range(size)]
122+
start_time = time.time()
123+
msg = comm.alltoall(msg)
124+
end_time = time.time()
125+
print(f'{rank} alltoall: {end_time - start_time}')
126+
if len(msg) != size:
127+
print(f'{rank} received unexpected alltoall message')
128+
comm.Abort(2)
129+
for i, msg in enumerate(msg):
130+
if msg != make_msg(i, msg_size):
131+
print(f'{rank} received unexpected alltoall message')
132+
comm.Abort(2)
133+
comm.barrier()
134+
135+
136+
def reduce(comm, nr_iters, msg_size):
137+
comm.barrier()
138+
rank = comm.Get_rank()
139+
size = comm.Get_size()
140+
for _ in range(nr_iters):
141+
for root in range(size):
142+
msg = make_float_msg(rank, msg_size)
143+
start_time = time.time()
144+
msg = comm.reduce(msg, op=MPI.SUM, root=root)
145+
end_time = time.time()
146+
print(f'{root} -> {rank} reduce: {end_time - start_time}')
147+
comm.barrier()
148+
149+
150+
def main():
151+
root = 0
152+
comm = MPI.COMM_WORLD
153+
rank = comm.Get_rank()
154+
if (rank == root):
155+
print(f'# acknowledgment')
156+
acknowledge(comm)
157+
arg_parser = ArgumentParser(description='MPI performance benchmark')
158+
arg_parser.add_argument('--nr_pingpongs', type=int, default=10,
159+
help='number of ping-pong iterations to perform')
160+
arg_parser.add_argument('--pingpong_size', type=int, default=8,
161+
help='number of bytes for ping-pong message')
162+
arg_parser.add_argument('--nr_bcasts', type=int, default=10,
163+
help='number of broadcast iteration to perform')
164+
arg_parser.add_argument('--bcast_size', type=int, default=8,
165+
help='number of bytes for broadcast message')
166+
arg_parser.add_argument('--nr_scatters', type=int, default=10,
167+
help='number of scatter iteration to perform')
168+
arg_parser.add_argument('--scatter_size', type=int, default=8,
169+
help='number of bytes for scatter message')
170+
arg_parser.add_argument('--nr_gathers', type=int, default=10,
171+
help='number of gather iteration to perform')
172+
arg_parser.add_argument('--gather_size', type=int, default=8,
173+
help='number of bytes for gather message')
174+
arg_parser.add_argument('--nr_alltoalls', type=int, default=10,
175+
help='number of alltoall iteration to perform')
176+
arg_parser.add_argument('--alltoall_size', type=int, default=8,
177+
help='number of bytes for alltoall message')
178+
arg_parser.add_argument('--nr_reduces', type=int, default=10,
179+
help='number of reduce iteration to perform')
180+
arg_parser.add_argument('--reduce_size', type=int, default=8,
181+
help='number of bytes for reduce message')
182+
options = arg_parser.parse_args()
183+
comm.barrier()
184+
if (rank == root):
185+
print(f'# {options.nr_pingpongs} ping-pong iterations, '
186+
f'size {options.pingpong_size}')
187+
comm.barrier()
188+
pingpong(comm, options.nr_pingpongs, options.pingpong_size)
189+
comm.barrier()
190+
if (rank == root):
191+
print(f'# {options.nr_bcasts} broadcast iterations, '
192+
f'size {options.bcast_size}')
193+
comm.barrier()
194+
broadcast(comm, options.nr_bcasts, options.bcast_size)
195+
comm.barrier()
196+
if (rank == root):
197+
print(f'# {options.nr_scatters} scatter iterations, '
198+
f'size {options.scatter_size}')
199+
comm.barrier()
200+
scatter(comm, options.nr_scatters, options.scatter_size)
201+
comm.barrier()
202+
if (rank == root):
203+
print(f'# {options.nr_gathers} gather iterations, '
204+
f'size {options.gather_size}')
205+
comm.barrier()
206+
gather(comm, options.nr_gathers, options.gather_size)
207+
comm.barrier()
208+
if (rank == root):
209+
print(f'# {options.nr_alltoalls} alltoall iterations, '
210+
f'size {options.alltoall_size}')
211+
comm.barrier()
212+
alltoall(comm, options.nr_alltoalls, options.alltoall_size)
213+
comm.barrier()
214+
if (rank == root):
215+
print(f'# {options.nr_reduces} reduce iterations, '
216+
f'size {options.reduce_size}')
217+
comm.barrier()
218+
reduce(comm, options.nr_reduces, options.reduce_size)
219+
comm.barrier()
220+
return 0
221+
222+
223+
if __name__ == '__main__':
224+
sys.exit(main())

0 commit comments

Comments
 (0)