@@ -43,6 +43,8 @@ def __init__(self, controller_params, description, comm):
4343
4444 num_levels = len (self .S .levels )
4545
46+ assert not (num_procs > 1 and num_levels == 1 ), "ERROR: classic cannot do MSSDC"
47+
4648 if num_procs > 1 and num_levels > 1 :
4749 for L in self .S .levels :
4850 if not L .sweep .coll .right_is_node or L .sweep .params .do_coll_update :
@@ -136,7 +138,7 @@ def restart_block(self, size, time, u0):
136138 self .S .init_step (u0 )
137139 # reset some values
138140 self .S .status .done = False
139- self .S .status .iter = 1
141+ self .S .status .iter = 0
140142 self .S .status .stage = 'SPREAD'
141143 for l in self .S .levels :
142144 l .tag = None
@@ -237,67 +239,33 @@ def pfasst(self, comm, num_procs):
237239 # update stage
238240 if len (self .S .levels ) > 1 and self .params .predict : # MLSDC or PFASST with predict
239241 self .S .status .stage = 'PREDICT'
240- elif len (self .S .levels ) > 1 : # MLSDC or PFASST without predict
241- self .hooks .pre_iteration (step = self .S , level_number = 0 )
242- self .S .status .stage = 'IT_FINE'
243- elif num_procs > 1 : # MSSDC
244- self .hooks .pre_iteration (step = self .S , level_number = 0 )
245- self .S .status .stage = 'IT_COARSE'
246- elif num_procs == 1 : # SDC
247- self .hooks .pre_iteration (step = self .S , level_number = 0 )
248- self .S .status .stage = 'IT_FINE'
249242 else :
250- raise ControllerError ( "Don't know what to do after spread, aborting" )
243+ self . S . status . stage = 'IT_CHECK'
251244
252245 elif stage == 'PREDICT' :
253246
254247 # call predictor (serial)
255248
256249 self .predictor (comm )
257250
258- # update stage
259- self .hooks .pre_iteration (step = self .S , level_number = 0 )
260- self .S .status .stage = 'IT_FINE'
261-
262- elif stage == 'IT_FINE' :
263-
264- # do fine sweep
265-
266- # standard sweep workflow: update nodes, compute residual, log progress
267- self .hooks .pre_sweep (step = self .S , level_number = 0 )
268- for k in range (self .S .levels [0 ].params .nsweeps ):
269- self .S .levels [0 ].sweep .update_nodes ()
270- self .S .levels [0 ].sweep .compute_residual ()
271- self .hooks .post_sweep (step = self .S , level_number = 0 )
272-
273- # wait for pending sends before computing uend, if any
274- if len (self .req_send ) > 0 and not self .S .status .last and self .params .fine_comm :
275- self .req_send [0 ].wait ()
276-
277- self .S .levels [0 ].sweep .compute_end_point ()
278-
279- if not self .S .status .last and self .params .fine_comm :
280- self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
281- (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
282- 0 , self .S .status .iter ))
283- self .req_send .append (comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = 0 ))
284-
285251 # update stage
286252 self .S .status .stage = 'IT_CHECK'
287253
288254 elif stage == 'IT_CHECK' :
289255
290256 # check whether to stop iterating (parallel)
291257
292- self .hooks .post_iteration (step = self .S , level_number = 0 )
293-
294258 # check if an open request of the status send is pending
295259 if self .req_status is not None :
296260 self .req_status .wait ()
297261
298262 # check for convergence or abort
263+ self .S .levels [0 ].sweep .compute_residual ()
299264 self .S .status .done = self .check_convergence (self .S )
300265
266+ if self .S .status .iter > 0 :
267+ self .hooks .post_iteration (step = self .S , level_number = 0 )
268+
301269 # send status forward
302270 if not self .S .status .last :
303271 self .logger .debug ('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
@@ -318,21 +286,43 @@ def pfasst(self, comm, num_procs):
318286 # increment iteration count here (and only here)
319287 self .S .status .iter += 1
320288 self .hooks .pre_iteration (step = self .S , level_number = 0 )
321- # multi-level or single-level?
322- if len (self .S .levels ) > 1 : # MLSDC or PFASST
323- self .S .status .stage = 'IT_UP'
324- elif num_procs > 1 : # MSSDC
325- self .S .status .stage = 'IT_COARSE_RECV'
326- elif num_procs == 1 : # SDC
327- self .S .status .stage = 'IT_FINE'
328- else :
329- raise ControllerError ("Weird stage in IT_CHECK" )
289+ self .S .status .stage = 'IT_FINE'
330290
331291 else :
332292 self .S .levels [0 ].sweep .compute_end_point ()
333293 self .hooks .post_step (step = self .S , level_number = 0 )
334294 self .S .status .stage = 'DONE'
335295
296+ elif stage == 'IT_FINE' :
297+
298+ # do fine sweep
299+
300+ # standard sweep workflow: update nodes, compute residual, log progress
301+ self .hooks .pre_sweep (step = self .S , level_number = 0 )
302+ for k in range (self .S .levels [0 ].params .nsweeps ):
303+ self .S .levels [0 ].sweep .update_nodes ()
304+ self .S .levels [0 ].sweep .compute_residual ()
305+ self .hooks .post_sweep (step = self .S , level_number = 0 )
306+
307+ # wait for pending sends before computing uend, if any
308+ if len (self .req_send ) > 0 and not self .S .status .last and self .params .fine_comm :
309+ self .req_send [0 ].wait ()
310+
311+ self .S .levels [0 ].sweep .compute_end_point ()
312+
313+ if not self .S .status .last and self .params .fine_comm :
314+ self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
315+ (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
316+ 0 , self .S .status .iter ))
317+ self .req_send .append (comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = 0 ))
318+
319+ # update stage
320+ # multi-level or single-level?
321+ if len (self .S .levels ) > 1 : # MLSDC or PFASST
322+ self .S .status .stage = 'IT_UP'
323+ else : # SDC
324+ self .S .status .stage = 'IT_CHECK'
325+
336326 elif stage == 'IT_UP' :
337327
338328 # go up the hierarchy from finest to coarsest level (parallel)
@@ -397,10 +387,7 @@ def pfasst(self, comm, num_procs):
397387 self .send (source = self .S .levels [- 1 ], target = self .S .next , tag = len (self .S .levels ) - 1 , comm = comm )
398388
399389 # update stage
400- if len (self .S .levels ) > 1 : # MLSDC or PFASST
401- self .S .status .stage = 'IT_DOWN'
402- else : # MSSDC
403- self .S .status .stage = 'IT_CHECK'
390+ self .S .status .stage = 'IT_DOWN'
404391
405392 elif stage == 'IT_DOWN' :
406393
@@ -427,7 +414,7 @@ def pfasst(self, comm, num_procs):
427414 self .hooks .post_sweep (step = self .S , level_number = l - 1 )
428415
429416 # update stage
430- self .S .status .stage = 'IT_FINE '
417+ self .S .status .stage = 'IT_CHECK '
431418
432419 else :
433420
0 commit comments