@@ -139,7 +139,7 @@ def restart_block(self, size, time, u0):
139139 self .S .init_step (u0 )
140140 # reset some values
141141 self .S .status .done = False
142- self .S .status .iter = 1
142+ self .S .status .iter = 0
143143 self .S .status .stage = 'SPREAD'
144144 for l in self .S .levels :
145145 l .tag = None
@@ -239,17 +239,9 @@ def pfasst(self, comm, num_procs):
239239 # update stage
240240 if len (self .S .levels ) > 1 and self .params .predict : # MLSDC or PFASST with predict
241241 self .S .status .stage = 'PREDICT'
242- elif len (self .S .levels ) > 1 : # MLSDC or PFASST without predict
243- self .hooks .pre_iteration (step = self .S , level_number = 0 )
244- self .S .status .stage = 'IT_FINE'
245- elif num_procs > 1 and len (self .S .levels ) == 1 : # MSSDC
246- self .hooks .pre_iteration (step = self .S , level_number = 0 )
247- self .S .status .stage = 'IT_COARSE'
248- elif num_procs == 1 : # SDC
249- self .hooks .pre_iteration (step = self .S , level_number = 0 )
250- self .S .status .stage = 'IT_FINE'
251242 else :
252- raise ControllerError ("Don't know what to do after spread, aborting" )
243+ self .S .status .stage = 'IT_CHECK'
244+
253245
254246 elif stage == 'PREDICT' :
255247
@@ -259,51 +251,66 @@ def pfasst(self, comm, num_procs):
259251
260252 # update stage
261253 self .hooks .pre_iteration (step = self .S , level_number = 0 )
262- self .S .status .stage = 'IT_FINE'
263-
264- elif stage == 'IT_FINE' :
265-
266- # do fine sweep
267-
268- # standard sweep workflow: update nodes, compute residual, log progress
269- self .hooks .pre_sweep (step = self .S , level_number = 0 )
270- for k in range (self .S .levels [0 ].params .nsweeps ):
271- self .S .levels [0 ].sweep .update_nodes ()
272- self .S .levels [0 ].sweep .compute_residual ()
273- self .hooks .post_sweep (step = self .S , level_number = 0 )
274-
275- # update stage
276254 self .S .status .stage = 'IT_CHECK'
277255
278256 elif stage == 'IT_CHECK' :
279257
280258 # check whether to stop iterating (parallel)
281259
282- self .hooks .post_iteration (step = self .S , level_number = 0 )
260+ req_send = None
261+ self .S .levels [0 ].sweep .compute_end_point ()
262+ if not self .S .status .last and self .params .fine_comm :
263+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
264+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
265+ 0 , self .S .status .iter ))
266+ req_send = comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = self .S .status .iter )
267+
268+ if not self .S .status .first and self .params .fine_comm :
269+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
270+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
271+ 0 , self .S .status .iter ))
272+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
273+
274+ if not self .S .status .last and self .params .fine_comm :
275+ req_send .wait ()
276+
277+ self .S .levels [0 ].sweep .compute_residual ()
283278 self .S .status .done = self .check_convergence (self .S )
284279 all_done = comm .allgather (self .S .status .done )
285280
281+ if self .S .status .iter > 0 :
282+ self .hooks .post_iteration (step = self .S , level_number = 0 )
283+
286284 # if not everyone is ready yet, keep doing stuff
287285 if not all (all_done ):
288286 self .S .status .done = False
289287 # increment iteration count here (and only here)
290288 self .S .status .iter += 1
291289 self .hooks .pre_iteration (step = self .S , level_number = 0 )
292- # multi-level or single-level?
293- if len (self .S .levels ) > 1 : # MLSDC or PFASST
294- self .S .status .stage = 'IT_UP'
295- elif num_procs > 1 : # MSSDC
296- self .S .status .stage = 'IT_COARSE_RECV'
297- elif num_procs == 1 : # SDC
298- self .S .status .stage = 'IT_FINE'
299- else :
300- raise ControllerError ("Weird stage in IT_CHECK" )
290+ self .S .status .stage = 'IT_FINE'
301291
302292 else :
303293 self .S .levels [0 ].sweep .compute_end_point ()
304294 self .hooks .post_step (step = self .S , level_number = 0 )
305295 self .S .status .stage = 'DONE'
306296
297+ elif stage == 'IT_FINE' :
298+
299+ # do fine sweep
300+
301+ # standard sweep workflow: update nodes, compute residual, log progress
302+ self .hooks .pre_sweep (step = self .S , level_number = 0 )
303+ for k in range (self .S .levels [0 ].params .nsweeps ):
304+ self .S .levels [0 ].sweep .update_nodes ()
305+ self .S .levels [0 ].sweep .compute_residual ()
306+ self .hooks .post_sweep (step = self .S , level_number = 0 )
307+
308+ # update stage
309+ if len (self .S .levels ) > 1 : # MLSDC or PFASST
310+ self .S .status .stage = 'IT_UP'
311+ else : # SDC
312+ self .S .status .stage = 'IT_CHECK'
313+
307314 elif stage == 'IT_UP' :
308315
309316 # go up the hierarchy from finest to coarsest level (parallel)
@@ -322,25 +329,19 @@ def pfasst(self, comm, num_procs):
322329 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l + 1 ])
323330
324331 # update stage
325- self .S .status .stage = 'IT_COARSE_RECV '
332+ self .S .status .stage = 'IT_COARSE '
326333
327- elif stage == 'IT_COARSE_RECV ' :
334+ elif stage == 'IT_COARSE ' :
328335
329- # receive from previous step (if not first )
336+ # sweeps on coarsest level (serial/blocking )
330337
338+ # receive from previous step (if not first)
331339 if not self .S .status .first :
332340 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
333341 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
334342 len (self .S .levels ) - 1 , self .S .status .iter ))
335343 self .recv (target = self .S .levels [- 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
336344
337- # update stage
338- self .S .status .stage = 'IT_COARSE'
339-
340- elif stage == 'IT_COARSE' :
341-
342- # sweeps on coarsest level (serial/blocking)
343-
344345 # do the sweep
345346 self .hooks .pre_sweep (step = self .S , level_number = len (self .S .levels ) - 1 )
346347 for k in range (self .S .levels [- 1 ].params .nsweeps ):
@@ -357,10 +358,7 @@ def pfasst(self, comm, num_procs):
357358 self .send (source = self .S .levels [- 1 ], target = self .S .next , tag = self .S .status .iter , comm = comm )
358359
359360 # update stage
360- if len (self .S .levels ) > 1 : # MLSDC or PFASST
361- self .S .status .stage = 'IT_DOWN'
362- else : # MSSDC
363- self .S .status .stage = 'IT_CHECK'
361+ self .S .status .stage = 'IT_DOWN'
364362
365363 elif stage == 'IT_DOWN' :
366364
@@ -373,32 +371,32 @@ def pfasst(self, comm, num_procs):
373371 self .S .transfer (source = self .S .levels [l ], target = self .S .levels [l - 1 ])
374372 self .S .levels [l - 1 ].sweep .compute_end_point ()
375373
376- req_send = None
377- if not self .S .status .last and self .params .fine_comm :
378- self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
379- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
380- l - 1 , self .S .status .iter ))
381- req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
382-
383- if not self .S .status .first and self .params .fine_comm :
384- self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
385- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
386- l - 1 , self .S .status .iter ))
387- self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
388-
389- if not self .S .status .last and self .params .fine_comm :
390- req_send .wait ()
391-
392374 # on middle levels: do sweep as usual
393375 if l - 1 > 0 :
376+ req_send = None
377+ if not self .S .status .last and self .params .fine_comm :
378+ self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
379+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
380+ l - 1 , self .S .status .iter ))
381+ req_send = comm .isend (self .S .levels [l - 1 ].uend , dest = self .S .next , tag = self .S .status .iter )
382+
383+ if not self .S .status .first and self .params .fine_comm :
384+ self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
385+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
386+ l - 1 , self .S .status .iter ))
387+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
388+
389+ if not self .S .status .last and self .params .fine_comm :
390+ req_send .wait ()
391+
394392 self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
395393 for k in range (self .S .levels [l - 1 ].params .nsweeps ):
396394 self .S .levels [l - 1 ].sweep .update_nodes ()
397395 self .S .levels [l - 1 ].sweep .compute_residual ()
398396 self .hooks .post_sweep (step = self .S , level_number = l - 1 )
399397
400398 # update stage
401- self .S .status .stage = 'IT_FINE '
399+ self .S .status .stage = 'IT_CHECK '
402400
403401 else :
404402
0 commit comments