@@ -240,7 +240,7 @@ def pfasst(self, comm, num_procs):
240240 if len (self .S .levels ) > 1 and self .params .predict : # MLSDC or PFASST with predict
241241 self .S .status .stage = 'PREDICT'
242242 else :
243- self .S .status .stage = 'IT_CHECK '
243+ self .S .status .stage = 'IT_FINE '
244244
245245 elif stage == 'PREDICT' :
246246
@@ -250,7 +250,7 @@ def pfasst(self, comm, num_procs):
250250
251251 # update stage
252252 self .hooks .pre_iteration (step = self .S , level_number = 0 )
253- self .S .status .stage = 'IT_CHECK '
253+ self .S .status .stage = 'IT_FINE '
254254
255255 elif stage == 'IT_CHECK' :
256256
@@ -286,7 +286,10 @@ def pfasst(self, comm, num_procs):
286286 # increment iteration count here (and only here)
287287 self .S .status .iter += 1
288288 self .hooks .pre_iteration (step = self .S , level_number = 0 )
289- self .S .status .stage = 'IT_FINE'
289+ if len (self .S .levels ) > 1 : # MLSDC or PFASST
290+ self .S .status .stage = 'IT_UP'
291+ else : # SDC
292+ self .S .status .stage = 'IT_FINE'
290293
291294 else :
292295 self .S .levels [0 ].sweep .compute_end_point ()
@@ -295,20 +298,36 @@ def pfasst(self, comm, num_procs):
295298
296299 elif stage == 'IT_FINE' :
297300
301+ nsweeps = self .S .levels [0 ].params .nsweeps
302+
298303 # do fine sweep
304+ for k in range (nsweeps ):
299305
300- # standard sweep workflow: update nodes, compute residual, log progress
301- self .hooks .pre_sweep (step = self .S , level_number = 0 )
302- for k in range (self .S .levels [0 ].params .nsweeps ):
306+ self .hooks .pre_sweep (step = self .S , level_number = 0 )
303307 self .S .levels [0 ].sweep .update_nodes ()
304- self .S .levels [0 ].sweep .compute_residual ()
305- self .hooks .post_sweep (step = self .S , level_number = 0 )
308+
309+ req_send = None
310+ self .S .levels [0 ].sweep .compute_end_point ()
311+ if not self .S .status .last and self .params .fine_comm :
312+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
313+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
314+ 0 , self .S .status .iter ))
315+ req_send = comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = self .S .status .iter )
316+
317+ if not self .S .status .first and self .params .fine_comm :
318+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
319+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
320+ 0 , self .S .status .iter ))
321+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
322+
323+ if not self .S .status .last and self .params .fine_comm :
324+ req_send .wait ()
325+
326+ self .S .levels [0 ].sweep .compute_residual ()
327+ self .hooks .post_sweep (step = self .S , level_number = 0 )
306328
307329 # update stage
308- if len (self .S .levels ) > 1 : # MLSDC or PFASST
309- self .S .status .stage = 'IT_UP'
310- else : # SDC
311- self .S .status .stage = 'IT_CHECK'
330+ self .S .status .stage = 'IT_CHECK'
312331
313332 elif stage == 'IT_UP' :
314333
@@ -318,11 +337,33 @@ def pfasst(self, comm, num_procs):
318337
319338 # sweep and send on middle levels (not on finest, not on coarsest, though)
320339 for l in range (1 , len (self .S .levels ) - 1 ):
321- self .hooks .pre_sweep (step = self .S , level_number = l )
322- for k in range (self .S .levels [l ].params .nsweeps ):
340+
341+ nsweeps = self .S .levels [l ].params .nsweeps
342+
343+ for k in range (nsweeps ):
344+
345+ self .hooks .pre_sweep (step = self .S , level_number = l )
323346 self .S .levels [l ].sweep .update_nodes ()
324- self .S .levels [l ].sweep .compute_residual ()
325- self .hooks .post_sweep (step = self .S , level_number = l )
347+
348+ req_send = None
349+ self .S .levels [l ].sweep .compute_end_point ()
350+ if not self .S .status .last and self .params .fine_comm :
351+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
352+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
353+ l , self .S .status .iter ))
354+ req_send = comm .isend (self .S .levels [l ].uend , dest = self .S .next , tag = self .S .status .iter )
355+
356+ if not self .S .status .first and self .params .fine_comm :
357+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
358+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
359+ l , self .S .status .iter ))
360+ self .recv (target = self .S .levels [l ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
361+
362+ if not self .S .status .last and self .params .fine_comm :
363+ req_send .wait ()
364+
365+ self .S .levels [l ].sweep .compute_residual ()
366+ self .hooks .post_sweep (step = self .S , level_number = l )
326367
327368 # transfer further up the hierarchy
328369 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l + 1 ])
@@ -343,8 +384,10 @@ def pfasst(self, comm, num_procs):
343384
344385 # do the sweep
345386 self .hooks .pre_sweep (step = self .S , level_number = len (self .S .levels ) - 1 )
346- for k in range (self .S .levels [- 1 ].params .nsweeps ):
347- self .S .levels [- 1 ].sweep .update_nodes ()
387+ assert self .S .levels [- 1 ].params .nsweeps == 1 , \
388+ 'ERROR: this controller can only work with one sweep on the coarse level, got %s' % \
389+ self .S .levels [- 1 ].params .nsweeps
390+ self .S .levels [- 1 ].sweep .update_nodes ()
348391 self .S .levels [- 1 ].sweep .compute_residual ()
349392 self .hooks .post_sweep (step = self .S , level_number = len (self .S .levels ) - 1 )
350393 self .S .levels [- 1 ].sweep .compute_end_point ()
@@ -370,32 +413,52 @@ def pfasst(self, comm, num_procs):
370413 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l - 1 ])
371414 self .S .levels [l - 1 ].sweep .compute_end_point ()
372415
373- # on middle levels: do sweep as usual
374- if l - 1 > 0 :
375- req_send = None
376- if not self .S .status .last and self .params .fine_comm :
377- self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
378- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
379- l - 1 , self .S .status .iter ))
380- req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
416+ req_send = None
417+ if not self .S .status .last and self .params .fine_comm :
418+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
419+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
420+ l - 1 , self .S .status .iter ))
421+ req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
381422
382- if not self .S .status .first and self .params .fine_comm :
383- self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
384- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
385- l - 1 , self .S .status .iter ))
386- self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
423+ if not self .S .status .first and self .params .fine_comm :
424+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
425+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
426+ l - 1 , self .S .status .iter ))
427+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
387428
388- if not self .S .status .last and self .params .fine_comm :
389- req_send .wait ()
429+ if not self .S .status .last and self .params .fine_comm :
430+ req_send .wait ()
431+
432+ # on middle levels: do sweep as usual
433+ if l - 1 > 0 :
390434
391- self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
392435 for k in range (self .S .levels [l - 1 ].params .nsweeps ):
436+
437+ self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
393438 self .S .levels [l - 1 ].sweep .update_nodes ()
394- self .S .levels [l - 1 ].sweep .compute_residual ()
395- self .hooks .post_sweep (step = self .S , level_number = l - 1 )
439+
440+ req_send = None
441+ if not self .S .status .last and self .params .fine_comm :
442+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
443+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
444+ l - 1 , self .S .status .iter ))
445+ req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
446+
447+ if not self .S .status .first and self .params .fine_comm :
448+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
449+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
450+ l - 1 , self .S .status .iter ))
451+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter ,
452+ comm = comm )
453+
454+ if not self .S .status .last and self .params .fine_comm :
455+ req_send .wait ()
456+
457+ self .S .levels [l - 1 ].sweep .compute_residual ()
458+ self .hooks .post_sweep (step = self .S , level_number = l - 1 )
396459
397460 # update stage
398- self .S .status .stage = 'IT_CHECK '
461+ self .S .status .stage = 'IT_FINE '
399462
400463 else :
401464
0 commit comments