@@ -35,8 +35,6 @@ def __init__(self, num_procs, controller_params, description):
3535 if self .params .dump_setup :
3636 self .dump_setup (step = self .MS [0 ], controller_params = controller_params , description = description )
3737
38- assert not (len (self .MS ) > 1 and len (self .MS [0 ].levels ) == 1 ), "ERROR: multigrid cannot do MSSDC"
39-
4038 if num_procs > 1 and len (self .MS [0 ].levels ) > 1 :
4139 for S in self .MS :
4240 for L in S .levels :
@@ -142,7 +140,7 @@ def restart_block(self, active_slots, time, u0):
142140 self .MS [p ].init_step (u0 )
143141 # reset some values
144142 self .MS [p ].status .done = False
145- self .MS [p ].status .iter = 1
143+ self .MS [p ].status .iter = 0
146144 self .MS [p ].status .stage = 'SPREAD'
147145 for l in self .MS [p ].levels :
148146 l .tag = None
@@ -267,8 +265,7 @@ def pfasst(self, MS):
267265 if len (S .levels ) > 1 and self .params .predict : # MLSDC or PFASST with predict
268266 S .status .stage = 'PREDICT'
269267 else :
270- self .hooks .pre_iteration (step = S , level_number = 0 )
271- S .status .stage = 'IT_FINE'
268+ S .status .stage = 'IT_CHECK'
272269
273270 return MS
274271
@@ -278,23 +275,6 @@ def pfasst(self, MS):
278275 MS = self .predictor (MS )
279276
280277 for S in MS :
281- # update stage
282- self .hooks .pre_iteration (step = S , level_number = 0 )
283- S .status .stage = 'IT_FINE'
284-
285- return MS
286-
287- elif stage == 'IT_FINE' :
288- # do fine sweep for all steps (virtually parallel)
289-
290- for S in MS :
291- # standard sweep workflow: update nodes, compute residual, log progress
292- self .hooks .pre_sweep (step = S , level_number = 0 )
293- for k in range (S .levels [0 ].params .nsweeps ):
294- S .levels [0 ].sweep .update_nodes ()
295- S .levels [0 ].sweep .compute_residual ()
296- self .hooks .post_sweep (step = S , level_number = 0 )
297-
298278 # update stage
299279 S .status .stage = 'IT_CHECK'
300280
@@ -305,9 +285,25 @@ def pfasst(self, MS):
305285 # check whether to stop iterating (parallel)
306286
307287 for S in MS :
308- self .hooks .post_iteration (step = S , level_number = 0 )
288+
289+ # send updated values forward
290+ if self .params .fine_comm and not S .status .last :
291+ self .logger .debug ('Process %2i provides data on level %2i with tag %s'
292+ % (S .status .slot , 0 , S .status .iter ))
293+ self .send (S .levels [0 ], tag = (0 , S .status .iter , S .status .slot ))
294+
295+ # # receive values
296+ if self .params .fine_comm and not S .status .first :
297+ self .logger .debug ('Process %2i receives from %2i on level %2i with tag %s' %
298+ (S .status .slot , S .prev .status .slot , 0 , S .status .iter ))
299+ self .recv (S .levels [0 ], S .prev .levels [0 ], tag = (0 , S .status .iter , S .prev .status .slot ))
300+
301+ S .levels [0 ].sweep .compute_residual ()
309302 S .status .done = self .check_convergence (S )
310303
304+ if S .status .iter > 0 :
305+ self .hooks .post_iteration (step = S , level_number = 0 )
306+
311307 # if not everyone is ready yet, keep doing stuff
312308 if not all (S .status .done for S in MS ):
313309
@@ -316,11 +312,7 @@ def pfasst(self, MS):
316312 # increment iteration count here (and only here)
317313 S .status .iter += 1
318314 self .hooks .pre_iteration (step = S , level_number = 0 )
319- # multi-level or single-level?
320- if len (S .levels ) > 1 : # MLSDC or PFASST
321- S .status .stage = 'IT_UP'
322- else : # SDC
323- S .status .stage = 'IT_FINE'
315+ S .status .stage = 'IT_FINE'
324316
325317 else :
326318 # if everyone is ready, end
@@ -331,6 +323,25 @@ def pfasst(self, MS):
331323
332324 return MS
333325
326+ elif stage == 'IT_FINE' :
327+ # do fine sweep for all steps (virtually parallel)
328+
329+ for S in MS :
330+ # standard sweep workflow: update nodes, compute residual, log progress
331+ self .hooks .pre_sweep (step = S , level_number = 0 )
332+ for k in range (S .levels [0 ].params .nsweeps ):
333+ S .levels [0 ].sweep .update_nodes ()
334+ S .levels [0 ].sweep .compute_residual ()
335+ self .hooks .post_sweep (step = S , level_number = 0 )
336+
337+ # update stage
338+ if len (S .levels ) > 1 : # MLSDC or PFASST
339+ S .status .stage = 'IT_UP'
340+ else : # SDC
341+ S .status .stage = 'IT_CHECK'
342+
343+ return MS
344+
334345 elif stage == 'IT_UP' :
335346 # go up the hierarchy from finest to coarsest level (parallel)
336347
@@ -396,28 +407,31 @@ def pfasst(self, MS):
396407 # prolong values
397408 S .transfer (source = S .levels [l ], target = S .levels [l - 1 ])
398409
399- # send updated values forward
400- if self .params .fine_comm and not S .status .last :
401- self .logger .debug ('Process %2i provides data on level %2i with tag %s'
402- % (S .status .slot , l - 1 , S .status .iter ))
403- self .send (S .levels [l - 1 ], tag = (l - 1 , S .status .iter , S .status .slot ))
410+ # on middle levels: do communication and sweep as usual
411+ if l - 1 > 0 :
412+
413+ # send updated values forward
414+ if self .params .fine_comm and not S .status .last :
415+ self .logger .debug ('Process %2i provides data on level %2i with tag %s'
416+ % (S .status .slot , l - 1 , S .status .iter ))
417+ self .send (S .levels [l - 1 ], tag = (l - 1 , S .status .iter , S .status .slot ))
418+
419+ # # receive values
420+ if self .params .fine_comm and not S .status .first :
421+ self .logger .debug ('Process %2i receives from %2i on level %2i with tag %s' %
422+ (S .status .slot , S .prev .status .slot , l - 1 , S .status .iter ))
423+ self .recv (S .levels [l - 1 ], S .prev .levels [l - 1 ], tag = (l - 1 , S .status .iter , S .prev .status .slot ))
404424
405- # # receive values
406- if self .params .fine_comm and not S .status .first :
407- self .logger .debug ('Process %2i receives from %2i on level %2i with tag %s' %
408- (S .status .slot , S .prev .status .slot , l - 1 , S .status .iter ))
409- self .recv (S .levels [l - 1 ], S .prev .levels [l - 1 ], tag = (l - 1 , S .status .iter , S .prev .status .slot ))
410425
411- # on middle levels: do sweep as usual
412- if l - 1 > 0 :
413426 self .hooks .pre_sweep (step = S , level_number = l - 1 )
414427 for k in range (S .levels [l - 1 ].params .nsweeps ):
415428 S .levels [l - 1 ].sweep .update_nodes ()
416429 S .levels [l - 1 ].sweep .compute_residual ()
417430 self .hooks .post_sweep (step = S , level_number = l - 1 )
431+ # on finest level, first check for convergence (where we will communication, too)
418432
419433 # update stage
420- S .status .stage = 'IT_FINE '
434+ S .status .stage = 'IT_CHECK '
421435
422436 return MS
423437
0 commit comments