@@ -240,7 +240,7 @@ def pfasst(self, comm, num_procs):
240240 if len (self .S .levels ) > 1 and self .params .predict : # MLSDC or PFASST with predict
241241 self .S .status .stage = 'PREDICT'
242242 else :
243- self .S .status .stage = 'IT_CHECK '
243+ self .S .status .stage = 'IT_FINE '
244244
245245 elif stage == 'PREDICT' :
246246
@@ -250,7 +250,7 @@ def pfasst(self, comm, num_procs):
250250
251251 # update stage
252252 self .hooks .pre_iteration (step = self .S , level_number = 0 )
253- self .S .status .stage = 'IT_CHECK '
253+ self .S .status .stage = 'IT_FINE '
254254
255255 elif stage == 'IT_CHECK' :
256256
@@ -286,7 +286,11 @@ def pfasst(self, comm, num_procs):
286286 # increment iteration count here (and only here)
287287 self .S .status .iter += 1
288288 self .hooks .pre_iteration (step = self .S , level_number = 0 )
289- self .S .status .stage = 'IT_FINE'
289+ if len (self .S .levels ) > 1 : # MLSDC or PFASST
290+ self .S .status .stage = 'IT_UP'
291+ else : # SDC
292+ self .S .status .stage = 'IT_FINE'
293+
290294
291295 else :
292296 self .S .levels [0 ].sweep .compute_end_point ()
@@ -295,34 +299,74 @@ def pfasst(self, comm, num_procs):
295299
296300 elif stage == 'IT_FINE' :
297301
302+ nsweeps = self .S .levels [0 ].params .nsweeps
303+
298304 # do fine sweep
305+ for k in range (nsweeps ):
299306
300- # standard sweep workflow: update nodes, compute residual, log progress
301- self .hooks .pre_sweep (step = self .S , level_number = 0 )
302- for k in range (self .S .levels [0 ].params .nsweeps ):
307+ self .hooks .pre_sweep (step = self .S , level_number = 0 )
303308 self .S .levels [0 ].sweep .update_nodes ()
304- self .S .levels [0 ].sweep .compute_residual ()
305- self .hooks .post_sweep (step = self .S , level_number = 0 )
309+
310+ req_send = None
311+ self .S .levels [0 ].sweep .compute_end_point ()
312+ if not self .S .status .last and self .params .fine_comm :
313+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
314+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
315+ 0 , self .S .status .iter ))
316+ req_send = comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = self .S .status .iter )
317+
318+ if not self .S .status .first and self .params .fine_comm :
319+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
320+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
321+ 0 , self .S .status .iter ))
322+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
323+
324+ if not self .S .status .last and self .params .fine_comm :
325+ req_send .wait ()
326+
327+ self .S .levels [0 ].sweep .compute_residual ()
328+ self .hooks .post_sweep (step = self .S , level_number = 0 )
306329
307330 # update stage
308- if len (self .S .levels ) > 1 : # MLSDC or PFASST
309- self .S .status .stage = 'IT_UP'
310- else : # SDC
311- self .S .status .stage = 'IT_CHECK'
331+ self .S .status .stage = 'IT_CHECK'
312332
313333 elif stage == 'IT_UP' :
314334
315335 # go up the hierarchy from finest to coarsest level (parallel)
316336
337+
338+
317339 self .S .transfer (source = self .S .levels [0 ], target = self .S .levels [1 ])
318340
319341 # sweep and send on middle levels (not on finest, not on coarsest, though)
320342 for l in range (1 , len (self .S .levels ) - 1 ):
321- self .hooks .pre_sweep (step = self .S , level_number = l )
322- for k in range (self .S .levels [l ].params .nsweeps ):
343+
344+ nsweeps = self .S .levels [l ].params .nsweeps
345+
346+ for k in range (nsweeps ):
347+
348+ self .hooks .pre_sweep (step = self .S , level_number = l )
323349 self .S .levels [l ].sweep .update_nodes ()
324- self .S .levels [l ].sweep .compute_residual ()
325- self .hooks .post_sweep (step = self .S , level_number = l )
350+
351+ req_send = None
352+ self .S .levels [l ].sweep .compute_end_point ()
353+ if not self .S .status .last and self .params .fine_comm :
354+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
355+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
356+ l , self .S .status .iter ))
357+ req_send = comm .isend (self .S .levels [l ].uend , dest = self .S .next , tag = self .S .status .iter )
358+
359+ if not self .S .status .first and self .params .fine_comm :
360+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
361+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
362+ l , self .S .status .iter ))
363+ self .recv (target = self .S .levels [l ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
364+
365+ if not self .S .status .last and self .params .fine_comm :
366+ req_send .wait ()
367+
368+ self .S .levels [l ].sweep .compute_residual ()
369+ self .hooks .post_sweep (step = self .S , level_number = l )
326370
327371 # transfer further up the hierarchy
328372 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l + 1 ])
@@ -343,8 +387,10 @@ def pfasst(self, comm, num_procs):
343387
344388 # do the sweep
345389 self .hooks .pre_sweep (step = self .S , level_number = len (self .S .levels ) - 1 )
346- for k in range (self .S .levels [- 1 ].params .nsweeps ):
347- self .S .levels [- 1 ].sweep .update_nodes ()
390+ assert self .S .levels [- 1 ].params .nsweeps == 1 , \
391+ 'ERROR: this controller can only work with one sweep on the coarse level, got %s' % \
392+ self .S .levels [- 1 ].params .nsweeps
393+ self .S .levels [- 1 ].sweep .update_nodes ()
348394 self .S .levels [- 1 ].sweep .compute_residual ()
349395 self .hooks .post_sweep (step = self .S , level_number = len (self .S .levels ) - 1 )
350396 self .S .levels [- 1 ].sweep .compute_end_point ()
@@ -370,32 +416,51 @@ def pfasst(self, comm, num_procs):
370416 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l - 1 ])
371417 self .S .levels [l - 1 ].sweep .compute_end_point ()
372418
373- # on middle levels: do sweep as usual
374- if l - 1 > 0 :
375- req_send = None
376- if not self .S .status .last and self .params .fine_comm :
377- self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
378- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
379- l - 1 , self .S .status .iter ))
380- req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
419+ req_send = None
420+ if not self .S .status .last and self .params .fine_comm :
421+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
422+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
423+ l - 1 , self .S .status .iter ))
424+ req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
381425
382- if not self .S .status .first and self .params .fine_comm :
383- self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
384- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
385- l - 1 , self .S .status .iter ))
386- self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
426+ if not self .S .status .first and self .params .fine_comm :
427+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
428+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
429+ l - 1 , self .S .status .iter ))
430+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
387431
388- if not self .S .status .last and self .params .fine_comm :
389- req_send .wait ()
432+ if not self .S .status .last and self .params .fine_comm :
433+ req_send .wait ()
434+
435+ # on middle levels: do sweep as usual
436+ if l - 1 > 0 :
390437
391- self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
392438 for k in range (self .S .levels [l - 1 ].params .nsweeps ):
439+
440+ self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
393441 self .S .levels [l - 1 ].sweep .update_nodes ()
394- self .S .levels [l - 1 ].sweep .compute_residual ()
395- self .hooks .post_sweep (step = self .S , level_number = l - 1 )
442+
443+ req_send = None
444+ if not self .S .status .last and self .params .fine_comm :
445+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
446+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
447+ l - 1 , self .S .status .iter ))
448+ req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
449+
450+ if not self .S .status .first and self .params .fine_comm :
451+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
452+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
453+ l - 1 , self .S .status .iter ))
454+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
455+
456+ if not self .S .status .last and self .params .fine_comm :
457+ req_send .wait ()
458+
459+ self .S .levels [l - 1 ].sweep .compute_residual ()
460+ self .hooks .post_sweep (step = self .S , level_number = l - 1 )
396461
397462 # update stage
398- self .S .status .stage = 'IT_CHECK '
463+ self .S .status .stage = 'IT_FINE '
399464
400465 else :
401466
0 commit comments