@@ -31,8 +31,6 @@ def __init__(self, controller_params, description, comm):
3131
3232 # pass communicator for future use
3333 self .comm = comm
34- # add request handle container for isend
35- self .req_send = []
3634 # add request handler for status send
3735 self .req_status = None
3836
@@ -44,6 +42,9 @@ def __init__(self, controller_params, description, comm):
4442
4543 num_levels = len (self .S .levels )
4644
45+ # add request handle container for isend
46+ self .req_send = [None ] * num_levels
47+
4748 if num_procs > 1 and num_levels > 1 :
4849 for L in self .S .levels :
4950 if not L .sweep .coll .right_is_node or L .sweep .params .do_coll_update :
@@ -161,7 +162,7 @@ def restart_block(self, size, time, u0):
161162 for l in self .S .levels :
162163 l .tag = None
163164 self .req_status = None
164- self .req_send = []
165+ self .req_send = [None ] * len ( self . S . levels )
165166 self .S .status .prev_done = False
166167
167168 for lvl in self .S .levels :
@@ -328,22 +329,21 @@ def pfasst(self, comm, num_procs):
328329 # check whether to stop iterating (parallel)
329330
330331 self .hooks .pre_comm (step = self .S , level_number = 0 )
331- req_send = None
332332 self .S .levels [0 ].sweep .compute_end_point ()
333333 if not self .S .status .last and self .params .fine_comm :
334+ if self .req_send [0 ] is not None :
335+ self .req_send [0 ].wait ()
334336 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
335337 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
336338 0 , self .S .status .iter ))
337- req_send = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
339+ self . req_send [ 0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
338340
339341 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
340342 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
341343 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
342344 0 , self .S .status .iter ))
343345 self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
344346
345- if not self .S .status .last and self .params .fine_comm :
346- req_send .wait ()
347347 self .hooks .post_comm (step = self .S , level_number = 0 )
348348
349349 self .S .levels [0 ].sweep .compute_residual ()
@@ -414,22 +414,21 @@ def pfasst(self, comm, num_procs):
414414 self .S .levels [0 ].status .sweep += 1
415415
416416 self .hooks .pre_comm (step = self .S , level_number = 0 )
417- req_send = None
418417 self .S .levels [0 ].sweep .compute_end_point ()
419418 if not self .S .status .last and self .params .fine_comm :
419+ if self .req_send [0 ] is not None :
420+ self .req_send [0 ].wait ()
420421 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
421422 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
422423 0 , self .S .status .iter ))
423- req_send = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
424+ self . req_send [ 0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
424425
425426 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
426427 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
427428 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
428429 0 , self .S .status .iter ))
429430 self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
430431
431- if not self .S .status .last and self .params .fine_comm :
432- req_send .wait ()
433432 self .hooks .post_comm (step = self .S , level_number = 0 , add_to_stats = (k == nsweeps - 1 ))
434433
435434 self .hooks .pre_sweep (step = self .S , level_number = 0 )
@@ -454,22 +453,21 @@ def pfasst(self, comm, num_procs):
454453 for k in range (nsweeps ):
455454
456455 self .hooks .pre_comm (step = self .S , level_number = l )
457- req_send = None
458456 self .S .levels [l ].sweep .compute_end_point ()
459457 if not self .S .status .last and self .params .fine_comm :
458+ if self .req_send [l ] is not None :
459+ self .req_send [l ].wait ()
460460 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
461461 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
462462 l , self .S .status .iter ))
463- req_send = self .S .levels [l ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
463+ self . req_send [ l ] = self .S .levels [l ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
464464
465465 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
466466 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
467467 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
468468 l , self .S .status .iter ))
469469 self .recv (target = self .S .levels [l ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
470470
471- if not self .S .status .last and self .params .fine_comm :
472- req_send .wait ()
473471 self .hooks .post_comm (step = self .S , level_number = l )
474472
475473 self .hooks .pre_sweep (step = self .S , level_number = l )
@@ -539,13 +537,14 @@ def pfasst(self, comm, num_procs):
539537 for k in range (nsweeps ):
540538
541539 self .hooks .pre_comm (step = self .S , level_number = l - 1 )
542- req_send = None
543540 self .S .levels [l - 1 ].sweep .compute_end_point ()
544541 if not self .S .status .last and self .params .fine_comm :
542+ if self .req_send [l - 1 ] is not None :
543+ self .req_send [l - 1 ].wait ()
545544 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
546545 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
547546 l - 1 , self .S .status .iter ))
548- req_send = self .S .levels [l - 1 ].uend .isend (dest = self .S .next , tag = self .S .status .iter ,
547+ self . req_send [ l - 1 ] = self .S .levels [l - 1 ].uend .isend (dest = self .S .next , tag = self .S .status .iter ,
549548 comm = comm )
550549
551550 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
@@ -555,8 +554,6 @@ def pfasst(self, comm, num_procs):
555554 self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self .S .status .iter ,
556555 comm = comm )
557556
558- if not self .S .status .last and self .params .fine_comm :
559- req_send .wait ()
560557 self .hooks .post_comm (step = self .S , level_number = l - 1 , add_to_stats = (k == nsweeps - 1 ))
561558
562559 self .hooks .pre_sweep (step = self .S , level_number = l - 1 )
0 commit comments