11import numpy as np
22from mpi4py import MPI
3- import threading
43
54from pySDC .core .Controller import controller
65from pySDC .core .Errors import ControllerError
@@ -32,10 +31,8 @@ def __init__(self, controller_params, description, comm):
3231
3332 # pass communicator for future use
3433 self .comm = comm
35- # add request and thread handler for status send
34+ # add request handler for status send
3635 self .req_status = None
37- self .wait_thread_stat = None
38- self .send_thread_stat = None
3936
4037 num_procs = self .comm .Get_size ()
4138 rank = self .comm .Get_rank ()
@@ -48,10 +45,8 @@ def __init__(self, controller_params, description, comm):
4845
4946 num_levels = len (self .S .levels )
5047
51- # add request handle and thread container for isend
48+ # add request handle container for isend
5249 self .req_send = [None ] * num_levels
53- self .wait_thread = [None ] * num_levels
54- self .send_thread = [None ] * num_levels
5550
5651 if num_procs > 1 and num_levels > 1 :
5752 for L in self .S .levels :
@@ -175,11 +170,7 @@ def restart_block(self, size, time, u0):
175170 for l in self .S .levels :
176171 l .tag = None
177172 self .req_status = None
178- self .wait_thread_stat = None
179- self .send_thread_stat = None
180173 self .req_send = [None ] * len (self .S .levels )
181- self .wait_thread = [None ] * len (self .S .levels )
182- self .send_thread = [None ] * len (self .S .levels )
183174 self .S .status .prev_done = False
184175
185176 self .S .status .time_size = size
@@ -203,19 +194,6 @@ def recv(target, source, tag=None, comm=None):
203194 # re-evaluate f on left interval boundary
204195 target .f [0 ] = target .prob .eval_f (target .u [0 ], target .time )
205196
206- @staticmethod
207- def wait (req ):
208- req .Wait ()
209-
210- @staticmethod
211- def wait_stat (req ):
212- req .wait ()
213-
214- @staticmethod
215- def send (S , lvl , tag , comm ):
216- req = S .levels [lvl ].uend .isend (dest = S .next , tag = tag , comm = comm )
217- req .Wait ()
218-
219197 def predictor (self , comm ):
220198 """
221199 Predictor function, extracted from the stepwise implementation (will be also used by matrix sweppers)
@@ -362,26 +340,21 @@ def pfasst(self, comm, num_procs):
362340
363341 self .hooks .pre_comm (step = self .S , level_number = 0 )
364342
365- if self .send_thread [0 ] is not None :
366- self .send_thread [0 ].join ()
343+ if self .req_send [0 ] is not None :
344+ self .req_send [0 ].wait ()
367345 self .S .levels [0 ].sweep .compute_end_point ()
368346
369347 if not self .S .status .last and self .params .fine_comm :
370348 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
371349 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
372350 0 , self .S .status .iter ))
373- self .send_thread [0 ] = threading .Thread (target = self .send , args = (self .S , 0 , 121212 , comm , ))
374- self .send_thread [0 ].start ()
375- self .send_thread [0 ].join ()
376- # self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
377- # self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0], ))
378- # self.wait_thread[0].start()
351+ self .req_send [0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
379352
380353 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
381354 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
382355 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
383356 0 , self .S .status .iter ))
384- self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = 121212 , comm = comm )
357+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
385358
386359 self .hooks .post_comm (step = self .S , level_number = 0 )
387360
@@ -398,26 +371,24 @@ def pfasst(self, comm, num_procs):
398371
399372 self .hooks .pre_comm (step = self .S , level_number = 0 )
400373
401- # # check if an open request of the status send is pending
402- # if self.wait_thread_stat is not None:
403- # self.wait_thread_stat.join()
404- #
405- # # recv status
406- # if not self.S.status.first and not self.S.status.prev_done:
407- # self.S.status.prev_done = comm.recv(source=self.S.prev, tag=99)
408- # self.logger.debug('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
409- # (self.S.status.prev_done, self.S.status.slot, self.S.time, self.S.next,
410- # 99, self.S.status.iter))
411- # self.S.status.done = self.S.status.done and self.S.status.prev_done
412- #
413- # # send status forward
414- # if not self.S.status.last:
415- # self.logger.debug('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
416- # (self.S.status.done, self.S.status.slot, self.S.time, self.S.next,
417- # 99, self.S.status.iter))
418- # self.req_status = comm.isend(self.S.status.done, dest=self.S.next, tag=99)
419- # self.wait_thread_stat = threading.Thread(target=self.wait_stat, args=(self.req_status, ))
420- # self.wait_thread_stat.start()
374+ # check if an open request of the status send is pending
375+ if self .req_status is not None :
376+ self .req_status .wait ()
377+
378+ # recv status
379+ if not self .S .status .first and not self .S .status .prev_done :
380+ self .S .status .prev_done = comm .recv (source = self .S .prev , tag = 99 )
381+ self .logger .debug ('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
382+ (self .S .status .prev_done , self .S .status .slot , self .S .time , self .S .next ,
383+ 99 , self .S .status .iter ))
384+ self .S .status .done = self .S .status .done and self .S .status .prev_done
385+
386+ # send status forward
387+ if not self .S .status .last :
388+ self .logger .debug ('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
389+ (self .S .status .done , self .S .status .slot , self .S .time , self .S .next ,
390+ 99 , self .S .status .iter ))
391+ self .req_status = comm .isend (self .S .status .done , dest = self .S .next , tag = 99 )
421392
422393 self .hooks .post_comm (step = self .S , level_number = 0 , add_to_stats = True )
423394
@@ -440,7 +411,7 @@ def pfasst(self, comm, num_procs):
440411
441412 else :
442413
443- # Need to finish all pending isend requests. These will occur for the first active process, since
414+ # Need to finish alll pending isend requests. These will occur for the first active process, since
444415 # in the last iteration the wait statement will not be called ("send and forget")
445416 for req in self .req_send :
446417 if req is not None :
@@ -464,27 +435,22 @@ def pfasst(self, comm, num_procs):
464435
465436 self .hooks .pre_comm (step = self .S , level_number = 0 )
466437
467- if self .send_thread [0 ] is not None :
468- self .send_thread [0 ].join ()
438+ if self .req_send [0 ] is not None :
439+ self .req_send [0 ].wait ()
469440 self .S .levels [0 ].sweep .compute_end_point ()
470441
471442 if not self .S .status .last and self .params .fine_comm :
472443 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
473444 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
474445 0 , self .S .status .iter ))
475- self .send_thread [0 ] = threading .Thread (target = self .send , args = (self .S , 0 , 232323 , comm ,))
476- self .send_thread [0 ].start ()
477- self .send_thread [0 ].join ()
478-
479- # self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
480- # self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0],))
481- # self.wait_thread[0].start()
446+ self .req_send [0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
482447
483448 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
484449 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
485450 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
486451 0 , self .S .status .iter ))
487- self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = 232323 , comm = comm )
452+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
453+
488454 self .hooks .post_comm (step = self .S , level_number = 0 , add_to_stats = (k == nsweeps - 1 ))
489455
490456 self .hooks .pre_sweep (step = self .S , level_number = 0 )
@@ -510,26 +476,22 @@ def pfasst(self, comm, num_procs):
510476
511477 self .hooks .pre_comm (step = self .S , level_number = l )
512478
513- if self .send_thread [l ] is not None :
514- self .send_thread [l ].join ()
479+ if self .req_send [l ] is not None :
480+ self .req_send [l ].wait ()
515481 self .S .levels [l ].sweep .compute_end_point ()
516482
517483 if not self .S .status .last and self .params .fine_comm :
518484 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
519485 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
520486 l , self .S .status .iter ))
521- self .send_thread [l ] = threading .Thread (target = self .send , args = (self .S , l , 343434 , comm ,))
522- self .send_thread [l ].start ()
523- # self.req_send[l] = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter,
524- # comm=comm)
525- # self.wait_thread[l] = threading.Thread(target=self.wait, args=(self.req_send[l],))
526- # self.wait_thread[l].start()
487+ self .req_send [l ] = self .S .levels [l ].uend .isend (dest = self .S .next , tag = self .S .status .iter ,
488+ comm = comm )
527489
528490 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
529491 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
530492 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
531493 l , self .S .status .iter ))
532- self .recv (target = self .S .levels [l ], source = self .S .prev , tag = 343434 , comm = comm )
494+ self .recv (target = self .S .levels [l ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
533495
534496 self .hooks .post_comm (step = self .S , level_number = l )
535497
@@ -554,7 +516,7 @@ def pfasst(self, comm, num_procs):
554516 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
555517 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
556518 len (self .S .levels ) - 1 , self .S .status .iter ))
557- self .recv (target = self .S .levels [- 1 ], source = self .S .prev , tag = 454545 , comm = comm )
519+ self .recv (target = self .S .levels [- 1 ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
558520 self .hooks .post_comm (step = self .S , level_number = len (self .S .levels ) - 1 )
559521
560522 # do the sweep
@@ -573,7 +535,7 @@ def pfasst(self, comm, num_procs):
573535 self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
574536 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
575537 len (self .S .levels ) - 1 , self .S .status .iter ))
576- self .S .levels [- 1 ].uend .send (dest = self .S .next , tag = 454545 , comm = comm )
538+ self .S .levels [- 1 ].uend .send (dest = self .S .next , tag = self . S . status . iter , comm = comm )
577539 self .hooks .post_comm (step = self .S , level_number = len (self .S .levels ) - 1 , add_to_stats = True )
578540
579541 # update stage
@@ -601,27 +563,23 @@ def pfasst(self, comm, num_procs):
601563
602564 self .hooks .pre_comm (step = self .S , level_number = l - 1 )
603565
604- if self .wait_thread [l - 1 ] is not None :
605- self .wait_thread [l - 1 ].join ()
566+ if self .req_send [l - 1 ] is not None :
567+ self .req_send [l - 1 ].wait ()
606568 self .S .levels [l - 1 ].sweep .compute_end_point ()
607569
608570 if not self .S .status .last and self .params .fine_comm :
609571 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
610572 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
611573 l - 1 , self .S .status .iter ))
612- self .send_thread [l - 1 ] = threading .Thread (target = self .send , args = (self .S , l - 1 , 565656 , comm ,))
613- self .send_thread [l - 1 ].start ()
614- # self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(dest=self.S.next,
615- # tag=self.S.status.iter,
616- # comm=comm)
617- # self.wait_thread[l - 1] = threading.Thread(target=self.wait, args=(self.req_send[l - 1],))
618- # self.wait_thread[l - 1].start()
574+ self .req_send [l - 1 ] = self .S .levels [l - 1 ].uend .isend (dest = self .S .next ,
575+ tag = self .S .status .iter ,
576+ comm = comm )
619577
620578 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
621579 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
622580 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
623581 l - 1 , self .S .status .iter ))
624- self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = 565656 ,
582+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self . S . status . iter ,
625583 comm = comm )
626584
627585 self .hooks .post_comm (step = self .S , level_number = l - 1 , add_to_stats = (k == nsweeps - 1 ))
0 commit comments