Skip to content

Commit 050283c

Browse files
committed
bugfix for MPI controller
1 parent 57d524e commit 050283c

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ def __init__(self, controller_params, description, comm):
3131

3232
# pass communicator for future use
3333
self.comm = comm
34-
# add request handle container for isend
35-
self.req_send = []
3634
# add request handler for status send
3735
self.req_status = None
3836

@@ -44,6 +42,9 @@ def __init__(self, controller_params, description, comm):
4442

4543
num_levels = len(self.S.levels)
4644

45+
# add request handle container for isend
46+
self.req_send = [None] * num_levels
47+
4748
if num_procs > 1 and num_levels > 1:
4849
for L in self.S.levels:
4950
if not L.sweep.coll.right_is_node or L.sweep.params.do_coll_update:
@@ -161,7 +162,7 @@ def restart_block(self, size, time, u0):
161162
for l in self.S.levels:
162163
l.tag = None
163164
self.req_status = None
164-
self.req_send = []
165+
self.req_send = [None] * len(self.S.levels)
165166
self.S.status.prev_done = False
166167

167168
for lvl in self.S.levels:
@@ -328,22 +329,21 @@ def pfasst(self, comm, num_procs):
328329
# check whether to stop iterating (parallel)
329330

330331
self.hooks.pre_comm(step=self.S, level_number=0)
331-
req_send = None
332332
self.S.levels[0].sweep.compute_end_point()
333333
if not self.S.status.last and self.params.fine_comm:
334+
if self.req_send[0] is not None:
335+
self.req_send[0].wait()
334336
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
335337
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
336338
0, self.S.status.iter))
337-
req_send = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
339+
self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
338340

339341
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
340342
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
341343
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
342344
0, self.S.status.iter))
343345
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
344346

345-
if not self.S.status.last and self.params.fine_comm:
346-
req_send.wait()
347347
self.hooks.post_comm(step=self.S, level_number=0)
348348

349349
self.S.levels[0].sweep.compute_residual()
@@ -414,22 +414,21 @@ def pfasst(self, comm, num_procs):
414414
self.S.levels[0].status.sweep += 1
415415

416416
self.hooks.pre_comm(step=self.S, level_number=0)
417-
req_send = None
418417
self.S.levels[0].sweep.compute_end_point()
419418
if not self.S.status.last and self.params.fine_comm:
419+
if self.req_send[0] is not None:
420+
self.req_send[0].wait()
420421
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
421422
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
422423
0, self.S.status.iter))
423-
req_send = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
424+
self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
424425

425426
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
426427
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
427428
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
428429
0, self.S.status.iter))
429430
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
430431

431-
if not self.S.status.last and self.params.fine_comm:
432-
req_send.wait()
433432
self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=(k == nsweeps - 1))
434433

435434
self.hooks.pre_sweep(step=self.S, level_number=0)
@@ -454,22 +453,21 @@ def pfasst(self, comm, num_procs):
454453
for k in range(nsweeps):
455454

456455
self.hooks.pre_comm(step=self.S, level_number=l)
457-
req_send = None
458456
self.S.levels[l].sweep.compute_end_point()
459457
if not self.S.status.last and self.params.fine_comm:
458+
if self.req_send[l] is not None:
459+
self.req_send[l].wait()
460460
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
461461
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
462462
l, self.S.status.iter))
463-
req_send = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
463+
self.req_send[l] = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
464464

465465
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
466466
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
467467
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
468468
l, self.S.status.iter))
469469
self.recv(target=self.S.levels[l], source=self.S.prev, tag=self.S.status.iter, comm=comm)
470470

471-
if not self.S.status.last and self.params.fine_comm:
472-
req_send.wait()
473471
self.hooks.post_comm(step=self.S, level_number=l)
474472

475473
self.hooks.pre_sweep(step=self.S, level_number=l)
@@ -539,13 +537,14 @@ def pfasst(self, comm, num_procs):
539537
for k in range(nsweeps):
540538

541539
self.hooks.pre_comm(step=self.S, level_number=l - 1)
542-
req_send = None
543540
self.S.levels[l - 1].sweep.compute_end_point()
544541
if not self.S.status.last and self.params.fine_comm:
542+
if self.req_send[l - 1] is not None:
543+
self.req_send[l - 1].wait()
545544
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
546545
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
547546
l - 1, self.S.status.iter))
548-
req_send = self.S.levels[l - 1].uend.isend(dest=self.S.next, tag=self.S.status.iter,
547+
self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(dest=self.S.next, tag=self.S.status.iter,
549548
comm=comm)
550549

551550
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
@@ -555,8 +554,6 @@ def pfasst(self, comm, num_procs):
555554
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter,
556555
comm=comm)
557556

558-
if not self.S.status.last and self.params.fine_comm:
559-
req_send.wait()
560557
self.hooks.post_comm(step=self.S, level_number=l - 1, add_to_stats=(k == nsweeps - 1))
561558

562559
self.hooks.pre_sweep(step=self.S, level_number=l - 1)

pySDC/playgrounds/parallel/AllenCahn_contracting_circle_FFT.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pySDC.helpers.stats_helper import filter_stats, sort_stats
77
from pySDC.implementations.collocation_classes.gauss_radau_right import CollGaussRadau_Right
88
from pySDC.implementations.controller_classes.controller_MPI import controller_MPI
9-
# from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex, allencahn2d_imex_stab
10-
from pySDC.implementations.problem_classes.AllenCahn_2D_parFFT import allencahn2d_imex, allencahn2d_imex_stab
9+
from pySDC.implementations.problem_classes.AllenCahn_2D_FFT import allencahn2d_imex, allencahn2d_imex_stab
10+
# from pySDC.implementations.problem_classes.AllenCahn_2D_parFFT import allencahn2d_imex, allencahn2d_imex_stab
1111
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
1212
from pySDC.implementations.transfer_classes.TransferMesh_FFT2D import mesh_to_mesh_fft2d
1313
from pySDC.playgrounds.parallel.AllenCahn_parallel_monitor import monitor
@@ -33,7 +33,7 @@ def setup_parameters():
3333
level_params = dict()
3434
level_params['restol'] = 1E-08
3535
level_params['dt'] = 1E-03
36-
level_params['nsweeps'] = [1]#, 1]
36+
level_params['nsweeps'] = [3, 1]
3737

3838
# initialize sweeper parameters
3939
sweeper_params = dict()
@@ -47,8 +47,8 @@ def setup_parameters():
4747
problem_params = dict()
4848
problem_params['nu'] = 2
4949
problem_params['L'] = 1.0
50-
problem_params['nvars'] = [(256, 256)]#, (64, 64)]
51-
problem_params['eps'] = [0.04]#, 0.16]
50+
problem_params['nvars'] = [(256, 256), (64, 64)]
51+
problem_params['eps'] = [0.04, 0.16]
5252
problem_params['radius'] = 0.25
5353

5454
# initialize step parameters
@@ -148,9 +148,9 @@ def run_SDC_variant(variant=None):
148148
# call main function to get things done...
149149
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
150150

151-
if time_rank == 0:
152-
plt_helper.plt.imshow(uend.values)
153-
plt_helper.savefig(f'uend_{space_rank}', save_pdf=False, save_pgf=False, save_png=True)
151+
# if time_rank == 0:
152+
# plt_helper.plt.imshow(uend.values)
153+
# plt_helper.savefig(f'uend_{space_rank}', save_pdf=False, save_pgf=False, save_png=True)
154154
# exit()
155155

156156
rank = comm.Get_rank()

0 commit comments

Comments
 (0)