Skip to content

Commit d1a55c9

Browse files
committed
fixed MPI multigrid
1 parent 8981101 commit d1a55c9

File tree

2 files changed: 103 additions (+103) and 38 deletions (-38).

2 files changed: 103 additions (+103) and 38 deletions (-38).

pySDC/implementations/controller_classes/allinclusive_multigrid_MPI.py

Lines changed: 102 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,7 @@ def pfasst(self, comm, num_procs):
240240
if len(self.S.levels) > 1 and self.params.predict: # MLSDC or PFASST with predict
241241
self.S.status.stage = 'PREDICT'
242242
else:
243-
self.S.status.stage = 'IT_CHECK'
243+
self.S.status.stage = 'IT_FINE'
244244

245245
elif stage == 'PREDICT':
246246

@@ -250,7 +250,7 @@ def pfasst(self, comm, num_procs):
250250

251251
# update stage
252252
self.hooks.pre_iteration(step=self.S, level_number=0)
253-
self.S.status.stage = 'IT_CHECK'
253+
self.S.status.stage = 'IT_FINE'
254254

255255
elif stage == 'IT_CHECK':
256256

@@ -286,7 +286,11 @@ def pfasst(self, comm, num_procs):
286286
# increment iteration count here (and only here)
287287
self.S.status.iter += 1
288288
self.hooks.pre_iteration(step=self.S, level_number=0)
289-
self.S.status.stage = 'IT_FINE'
289+
if len(self.S.levels) > 1: # MLSDC or PFASST
290+
self.S.status.stage = 'IT_UP'
291+
else: # SDC
292+
self.S.status.stage = 'IT_FINE'
293+
290294

291295
else:
292296
self.S.levels[0].sweep.compute_end_point()
@@ -295,34 +299,74 @@ def pfasst(self, comm, num_procs):
295299

296300
elif stage == 'IT_FINE':
297301

302+
nsweeps = self.S.levels[0].params.nsweeps
303+
298304
# do fine sweep
305+
for k in range(nsweeps):
299306

300-
# standard sweep workflow: update nodes, compute residual, log progress
301-
self.hooks.pre_sweep(step=self.S, level_number=0)
302-
for k in range(self.S.levels[0].params.nsweeps):
307+
self.hooks.pre_sweep(step=self.S, level_number=0)
303308
self.S.levels[0].sweep.update_nodes()
304-
self.S.levels[0].sweep.compute_residual()
305-
self.hooks.post_sweep(step=self.S, level_number=0)
309+
310+
req_send = None
311+
self.S.levels[0].sweep.compute_end_point()
312+
if not self.S.status.last and self.params.fine_comm:
313+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
314+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
315+
0, self.S.status.iter))
316+
req_send = comm.isend(self.S.levels[0].uend, dest=self.S.next, tag=self.S.status.iter)
317+
318+
if not self.S.status.first and self.params.fine_comm:
319+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
320+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
321+
0, self.S.status.iter))
322+
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
323+
324+
if not self.S.status.last and self.params.fine_comm:
325+
req_send.wait()
326+
327+
self.S.levels[0].sweep.compute_residual()
328+
self.hooks.post_sweep(step=self.S, level_number=0)
306329

307330
# update stage
308-
if len(self.S.levels) > 1: # MLSDC or PFASST
309-
self.S.status.stage = 'IT_UP'
310-
else: # SDC
311-
self.S.status.stage = 'IT_CHECK'
331+
self.S.status.stage = 'IT_CHECK'
312332

313333
elif stage == 'IT_UP':
314334

315335
# go up the hierarchy from finest to coarsest level (parallel)
316336

337+
338+
317339
self.S.transfer(source=self.S.levels[0], target=self.S.levels[1])
318340

319341
# sweep and send on middle levels (not on finest, not on coarsest, though)
320342
for l in range(1, len(self.S.levels) - 1):
321-
self.hooks.pre_sweep(step=self.S, level_number=l)
322-
for k in range(self.S.levels[l].params.nsweeps):
343+
344+
nsweeps = self.S.levels[l].params.nsweeps
345+
346+
for k in range(nsweeps):
347+
348+
self.hooks.pre_sweep(step=self.S, level_number=l)
323349
self.S.levels[l].sweep.update_nodes()
324-
self.S.levels[l].sweep.compute_residual()
325-
self.hooks.post_sweep(step=self.S, level_number=l)
350+
351+
req_send = None
352+
self.S.levels[l].sweep.compute_end_point()
353+
if not self.S.status.last and self.params.fine_comm:
354+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
355+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
356+
l, self.S.status.iter))
357+
req_send = comm.isend(self.S.levels[l].uend, dest=self.S.next, tag=self.S.status.iter)
358+
359+
if not self.S.status.first and self.params.fine_comm:
360+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
361+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
362+
l, self.S.status.iter))
363+
self.recv(target=self.S.levels[l], source=self.S.prev, tag=self.S.status.iter, comm=comm)
364+
365+
if not self.S.status.last and self.params.fine_comm:
366+
req_send.wait()
367+
368+
self.S.levels[l].sweep.compute_residual()
369+
self.hooks.post_sweep(step=self.S, level_number=l)
326370

327371
# transfer further up the hierarchy
328372
self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])
@@ -343,8 +387,10 @@ def pfasst(self, comm, num_procs):
343387

344388
# do the sweep
345389
self.hooks.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
346-
for k in range(self.S.levels[-1].params.nsweeps):
347-
self.S.levels[-1].sweep.update_nodes()
390+
assert self.S.levels[-1].params.nsweeps == 1, \
391+
'ERROR: this controller can only work with one sweep on the coarse level, got %s' % \
392+
self.S.levels[-1].params.nsweeps
393+
self.S.levels[-1].sweep.update_nodes()
348394
self.S.levels[-1].sweep.compute_residual()
349395
self.hooks.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
350396
self.S.levels[-1].sweep.compute_end_point()
@@ -370,32 +416,51 @@ def pfasst(self, comm, num_procs):
370416
self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
371417
self.S.levels[l - 1].sweep.compute_end_point()
372418

373-
# on middle levels: do sweep as usual
374-
if l - 1 > 0:
375-
req_send = None
376-
if not self.S.status.last and self.params.fine_comm:
377-
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
378-
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
379-
l - 1, self.S.status.iter))
380-
req_send = comm.isend(self.S.levels[l - 1].uend, dest=self.S.next, tag=self.S.status.iter)
419+
req_send = None
420+
if not self.S.status.last and self.params.fine_comm:
421+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
422+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
423+
l - 1, self.S.status.iter))
424+
req_send = comm.isend(self.S.levels[l - 1].uend, dest=self.S.next, tag=self.S.status.iter)
381425

382-
if not self.S.status.first and self.params.fine_comm:
383-
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
384-
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
385-
l - 1, self.S.status.iter))
386-
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
426+
if not self.S.status.first and self.params.fine_comm:
427+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
428+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
429+
l - 1, self.S.status.iter))
430+
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
387431

388-
if not self.S.status.last and self.params.fine_comm:
389-
req_send.wait()
432+
if not self.S.status.last and self.params.fine_comm:
433+
req_send.wait()
434+
435+
# on middle levels: do sweep as usual
436+
if l - 1 > 0:
390437

391-
self.hooks.pre_sweep(step=self.S, level_number=l - 1)
392438
for k in range(self.S.levels[l - 1].params.nsweeps):
439+
440+
self.hooks.pre_sweep(step=self.S, level_number=l - 1)
393441
self.S.levels[l - 1].sweep.update_nodes()
394-
self.S.levels[l - 1].sweep.compute_residual()
395-
self.hooks.post_sweep(step=self.S, level_number=l - 1)
442+
443+
req_send = None
444+
if not self.S.status.last and self.params.fine_comm:
445+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
446+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
447+
l - 1, self.S.status.iter))
448+
req_send = comm.isend(self.S.levels[l - 1].uend, dest=self.S.next, tag=self.S.status.iter)
449+
450+
if not self.S.status.first and self.params.fine_comm:
451+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
452+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
453+
l - 1, self.S.status.iter))
454+
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
455+
456+
if not self.S.status.last and self.params.fine_comm:
457+
req_send.wait()
458+
459+
self.S.levels[l - 1].sweep.compute_residual()
460+
self.hooks.post_sweep(step=self.S, level_number=l - 1)
396461

397462
# update stage
398-
self.S.status.stage = 'IT_CHECK'
463+
self.S.status.stage = 'IT_FINE'
399464

400465
else:
401466

pySDC/tutorial/step_6/A_classic_vs_multigrid_controller.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def main(num_proc_list=None, fname=None):
7878

7979
assert all([item[1] <= 7 for item in iter_counts_multigrid]), \
8080
"ERROR: weird iteration counts for multigrid, got %s" % iter_counts_multigrid
81-
assert diff < 2E-10, "ERROR: difference between classic and multigrid controller is too large, got %s" % diff
81+
assert diff < 2.2E-10, "ERROR: difference between classic and multigrid controller is too large, got %s" % diff
8282

8383
f.close()
8484

0 commit comments

Comments
 (0)