@@ -43,6 +43,8 @@ def __init__(self, controller_params, description, comm):
4343
4444 num_levels = len (self .S .levels )
4545
46+ assert not (num_procs > 1 and num_levels == 1 ), "ERROR: classic cannot do MSSDC"
47+
4648 if num_procs > 1 and num_levels > 1 :
4749 for L in self .S .levels :
4850 if not L .sweep .coll .right_is_node or L .sweep .params .do_coll_update :
@@ -136,7 +138,7 @@ def restart_block(self, size, time, u0):
136138 self .S .init_step (u0 )
137139 # reset some values
138140 self .S .status .done = False
139- self .S .status .iter = 1
141+ self .S .status .iter = 0
140142 self .S .status .stage = 'SPREAD'
141143 for l in self .S .levels :
142144 l .tag = None
@@ -257,7 +259,7 @@ def pfasst(self, comm, num_procs):
257259
258260 # update stage
259261 self .hooks .pre_iteration (step = self .S , level_number = 0 )
260- self .S .status .stage = 'IT_FINE '
262+ self .S .status .stage = 'IT_CHECK '
261263
262264 elif stage == 'IT_FINE' :
263265
@@ -283,21 +285,27 @@ def pfasst(self, comm, num_procs):
283285 self .req_send .append (comm .isend (self .S .levels [0 ].uend , dest = self .S .next , tag = 0 ))
284286
285287 # update stage
286- self .S .status .stage = 'IT_CHECK'
288+ # multi-level or single-level?
289+ if len (self .S .levels ) > 1 : # MLSDC or PFASST
290+ self .S .status .stage = 'IT_UP'
291+ else : # SDC
292+ self .S .status .stage = 'IT_CHECK'
287293
288294 elif stage == 'IT_CHECK' :
289295
290296 # check whether to stop iterating (parallel)
291297
292- self .hooks .post_iteration (step = self .S , level_number = 0 )
293-
294298 # check if an open request of the status send is pending
295299 if self .req_status is not None :
296300 self .req_status .wait ()
297301
298302 # check for convergence or abort
303+ self .S .levels [0 ].sweep .compute_residual ()
299304 self .S .status .done = self .check_convergence (self .S )
300305
306+ if self .S .status .iter > 0 :
307+ self .hooks .post_iteration (step = self .S , level_number = 0 )
308+
301309 # send status forward
302310 if not self .S .status .last :
303311 self .logger .debug ('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
@@ -318,15 +326,7 @@ def pfasst(self, comm, num_procs):
318326 # increment iteration count here (and only here)
319327 self .S .status .iter += 1
320328 self .hooks .pre_iteration (step = self .S , level_number = 0 )
321- # multi-level or single-level?
322- if len (self .S .levels ) > 1 : # MLSDC or PFASST
323- self .S .status .stage = 'IT_UP'
324- elif num_procs > 1 : # MSSDC
325- self .S .status .stage = 'IT_COARSE_RECV'
326- elif num_procs == 1 : # SDC
327- self .S .status .stage = 'IT_FINE'
328- else :
329- raise ControllerError ("Weird stage in IT_CHECK" )
329+ self .S .status .stage = 'IT_FINE'
330330
331331 else :
332332 self .S .levels [0 ].sweep .compute_end_point ()
@@ -397,10 +397,7 @@ def pfasst(self, comm, num_procs):
397397 self .send (source = self .S .levels [- 1 ], target = self .S .next , tag = len (self .S .levels ) - 1 , comm = comm )
398398
399399 # update stage
400- if len (self .S .levels ) > 1 : # MLSDC or PFASST
401- self .S .status .stage = 'IT_DOWN'
402- else : # MSSDC
403- self .S .status .stage = 'IT_CHECK'
400+ self .S .status .stage = 'IT_DOWN'
404401
405402 elif stage == 'IT_DOWN' :
406403
@@ -427,7 +424,7 @@ def pfasst(self, comm, num_procs):
427424 self .hooks .post_sweep (step = self .S , level_number = l - 1 )
428425
429426 # update stage
430- self .S .status .stage = 'IT_FINE '
427+ self .S .status .stage = 'IT_CHECK '
431428
432429 else :
433430
0 commit comments