Skip to content

Commit 7587391

Browse files
committed
fixed MPI-MG?
1 parent db48749 commit 7587391

File tree

2 files changed

+65
-67
lines changed

2 files changed

+65
-67
lines changed

pySDC/core/Hooks.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ def post_sweep(self, step, level_number):
123123

124124
L = step.levels[level_number]
125125

126-
self.logger.info('Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Residual: %12.8e',
126+
self.logger.info('Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- lagged residual: %12.8e',
127127
step.status.slot, L.time, step.status.stage, L.level_index, step.status.iter,
128128
L.status.residual)
129129

pySDC/implementations/controller_classes/allinclusive_multigrid_MPI.py

Lines changed: 64 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ def restart_block(self, size, time, u0):
139139
self.S.init_step(u0)
140140
# reset some values
141141
self.S.status.done = False
142-
self.S.status.iter = 1
142+
self.S.status.iter = 0
143143
self.S.status.stage = 'SPREAD'
144144
for l in self.S.levels:
145145
l.tag = None
@@ -239,17 +239,9 @@ def pfasst(self, comm, num_procs):
239239
# update stage
240240
if len(self.S.levels) > 1 and self.params.predict: # MLSDC or PFASST with predict
241241
self.S.status.stage = 'PREDICT'
242-
elif len(self.S.levels) > 1: # MLSDC or PFASST without predict
243-
self.hooks.pre_iteration(step=self.S, level_number=0)
244-
self.S.status.stage = 'IT_FINE'
245-
elif num_procs > 1 and len(self.S.levels) == 1: # MSSDC
246-
self.hooks.pre_iteration(step=self.S, level_number=0)
247-
self.S.status.stage = 'IT_COARSE'
248-
elif num_procs == 1: # SDC
249-
self.hooks.pre_iteration(step=self.S, level_number=0)
250-
self.S.status.stage = 'IT_FINE'
251242
else:
252-
raise ControllerError("Don't know what to do after spread, aborting")
243+
self.S.status.stage = 'IT_CHECK'
244+
253245

254246
elif stage == 'PREDICT':
255247

@@ -259,51 +251,66 @@ def pfasst(self, comm, num_procs):
259251

260252
# update stage
261253
self.hooks.pre_iteration(step=self.S, level_number=0)
262-
self.S.status.stage = 'IT_FINE'
263-
264-
elif stage == 'IT_FINE':
265-
266-
# do fine sweep
267-
268-
# standard sweep workflow: update nodes, compute residual, log progress
269-
self.hooks.pre_sweep(step=self.S, level_number=0)
270-
for k in range(self.S.levels[0].params.nsweeps):
271-
self.S.levels[0].sweep.update_nodes()
272-
self.S.levels[0].sweep.compute_residual()
273-
self.hooks.post_sweep(step=self.S, level_number=0)
274-
275-
# update stage
276254
self.S.status.stage = 'IT_CHECK'
277255

278256
elif stage == 'IT_CHECK':
279257

280258
# check whether to stop iterating (parallel)
281259

282-
self.hooks.post_iteration(step=self.S, level_number=0)
260+
req_send = None
261+
self.S.levels[0].sweep.compute_end_point()
262+
if not self.S.status.last and self.params.fine_comm:
263+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
264+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
265+
0, self.S.status.iter))
266+
req_send = comm.isend(self.S.levels[0].uend, dest=self.S.next, tag=self.S.status.iter)
267+
268+
if not self.S.status.first and self.params.fine_comm:
269+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
270+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
271+
0, self.S.status.iter))
272+
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
273+
274+
if not self.S.status.last and self.params.fine_comm:
275+
req_send.wait()
276+
277+
self.S.levels[0].sweep.compute_residual()
283278
self.S.status.done = self.check_convergence(self.S)
284279
all_done = comm.allgather(self.S.status.done)
285280

281+
if self.S.status.iter > 0:
282+
self.hooks.post_iteration(step=self.S, level_number=0)
283+
286284
# if not everyone is ready yet, keep doing stuff
287285
if not all(all_done):
288286
self.S.status.done = False
289287
# increment iteration count here (and only here)
290288
self.S.status.iter += 1
291289
self.hooks.pre_iteration(step=self.S, level_number=0)
292-
# multi-level or single-level?
293-
if len(self.S.levels) > 1: # MLSDC or PFASST
294-
self.S.status.stage = 'IT_UP'
295-
elif num_procs > 1: # MSSDC
296-
self.S.status.stage = 'IT_COARSE_RECV'
297-
elif num_procs == 1: # SDC
298-
self.S.status.stage = 'IT_FINE'
299-
else:
300-
raise ControllerError("Weird stage in IT_CHECK")
290+
self.S.status.stage = 'IT_FINE'
301291

302292
else:
303293
self.S.levels[0].sweep.compute_end_point()
304294
self.hooks.post_step(step=self.S, level_number=0)
305295
self.S.status.stage = 'DONE'
306296

297+
elif stage == 'IT_FINE':
298+
299+
# do fine sweep
300+
301+
# standard sweep workflow: update nodes, compute residual, log progress
302+
self.hooks.pre_sweep(step=self.S, level_number=0)
303+
for k in range(self.S.levels[0].params.nsweeps):
304+
self.S.levels[0].sweep.update_nodes()
305+
self.S.levels[0].sweep.compute_residual()
306+
self.hooks.post_sweep(step=self.S, level_number=0)
307+
308+
# update stage
309+
if len(self.S.levels) > 1: # MLSDC or PFASST
310+
self.S.status.stage = 'IT_UP'
311+
else: # SDC
312+
self.S.status.stage = 'IT_CHECK'
313+
307314
elif stage == 'IT_UP':
308315

309316
# go up the hierarchy from finest to coarsest level (parallel)
@@ -322,25 +329,19 @@ def pfasst(self, comm, num_procs):
322329
self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])
323330

324331
# update stage
325-
self.S.status.stage = 'IT_COARSE_RECV'
332+
self.S.status.stage = 'IT_COARSE'
326333

327-
elif stage == 'IT_COARSE_RECV':
334+
elif stage == 'IT_COARSE':
328335

329-
# receive from previous step (if not first)
336+
# sweeps on coarsest level (serial/blocking)
330337

338+
# receive from previous step (if not first)
331339
if not self.S.status.first:
332340
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
333341
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
334342
len(self.S.levels) - 1, self.S.status.iter))
335343
self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
336344

337-
# update stage
338-
self.S.status.stage = 'IT_COARSE'
339-
340-
elif stage == 'IT_COARSE':
341-
342-
# sweeps on coarsest level (serial/blocking)
343-
344345
# do the sweep
345346
self.hooks.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
346347
for k in range(self.S.levels[-1].params.nsweeps):
@@ -357,10 +358,7 @@ def pfasst(self, comm, num_procs):
357358
self.send(source=self.S.levels[-1], target=self.S.next, tag=self.S.status.iter, comm=comm)
358359

359360
# update stage
360-
if len(self.S.levels) > 1: # MLSDC or PFASST
361-
self.S.status.stage = 'IT_DOWN'
362-
else: # MSSDC
363-
self.S.status.stage = 'IT_CHECK'
361+
self.S.status.stage = 'IT_DOWN'
364362

365363
elif stage == 'IT_DOWN':
366364

@@ -373,32 +371,32 @@ def pfasst(self, comm, num_procs):
373371
self.S.transfer(source=self.S.levels[l], target=self.S.levels[l - 1])
374372
self.S.levels[l - 1].sweep.compute_end_point()
375373

376-
req_send = None
377-
if not self.S.status.last and self.params.fine_comm:
378-
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
379-
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
380-
l - 1, self.S.status.iter))
381-
req_send = comm.isend(self.S.levels[l - 1].uend, dest=self.S.next, tag=self.S.status.iter)
382-
383-
if not self.S.status.first and self.params.fine_comm:
384-
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
385-
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
386-
l - 1, self.S.status.iter))
387-
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
388-
389-
if not self.S.status.last and self.params.fine_comm:
390-
req_send.wait()
391-
392374
# on middle levels: do sweep as usual
393375
if l - 1 > 0:
376+
req_send = None
377+
if not self.S.status.last and self.params.fine_comm:
378+
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
379+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
380+
l - 1, self.S.status.iter))
381+
req_send = comm.isend(self.S.levels[l - 1].uend, dest=self.S.next, tag=self.S.status.iter)
382+
383+
if not self.S.status.first and self.params.fine_comm:
384+
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
385+
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
386+
l - 1, self.S.status.iter))
387+
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
388+
389+
if not self.S.status.last and self.params.fine_comm:
390+
req_send.wait()
391+
394392
self.hooks.pre_sweep(step=self.S, level_number=l - 1)
395393
for k in range(self.S.levels[l - 1].params.nsweeps):
396394
self.S.levels[l - 1].sweep.update_nodes()
397395
self.S.levels[l - 1].sweep.compute_residual()
398396
self.hooks.post_sweep(step=self.S, level_number=l - 1)
399397

400398
# update stage
401-
self.S.status.stage = 'IT_FINE'
399+
self.S.status.stage = 'IT_CHECK'
402400

403401
else:
404402

0 commit comments

Comments
 (0)