11import numpy as np
22from mpi4py import MPI
3+ import threading
34
45from pySDC .core .Controller import controller
56from pySDC .core .Errors import ControllerError
@@ -31,8 +32,10 @@ def __init__(self, controller_params, description, comm):
3132
3233 # pass communicator for future use
3334 self .comm = comm
34- # add request handler for status send
35+ # add request and thread handler for status send
3536 self .req_status = None
37+ self .wait_thread_stat = None
38+ self .send_thread_stat = None
3639
3740 num_procs = self .comm .Get_size ()
3841 rank = self .comm .Get_rank ()
@@ -45,8 +48,10 @@ def __init__(self, controller_params, description, comm):
4548
4649 num_levels = len (self .S .levels )
4750
48- # add request handle container for isend
51+ # add request handle and thread container for isend
4952 self .req_send = [None ] * num_levels
53+ self .wait_thread = [None ] * num_levels
54+ self .send_thread = [None ] * num_levels
5055
5156 if num_procs > 1 and num_levels > 1 :
5257 for L in self .S .levels :
@@ -170,7 +175,11 @@ def restart_block(self, size, time, u0):
170175 for l in self .S .levels :
171176 l .tag = None
172177 self .req_status = None
178+ self .wait_thread_stat = None
179+ self .send_thread_stat = None
173180 self .req_send = [None ] * len (self .S .levels )
181+ self .wait_thread = [None ] * len (self .S .levels )
182+ self .send_thread = [None ] * len (self .S .levels )
174183 self .S .status .prev_done = False
175184
176185 self .S .status .time_size = size
@@ -194,6 +203,19 @@ def recv(target, source, tag=None, comm=None):
194203 # re-evaluate f on left interval boundary
195204 target .f [0 ] = target .prob .eval_f (target .u [0 ], target .time )
196205
206+ @staticmethod
207+ def wait (req ):
208+ req .Wait ()
209+
210+ @staticmethod
211+ def wait_stat (req ):
212+ req .wait ()
213+
214+ @staticmethod
215+ def send (S , lvl , tag , comm ):
216+ req = S .levels [lvl ].uend .isend (dest = S .next , tag = tag , comm = comm )
217+ req .Wait ()
218+
197219 def predictor (self , comm ):
198220 """
199221 Predictor function, extracted from the stepwise implementation (will be also used by matrix sweppers)
@@ -340,21 +362,26 @@ def pfasst(self, comm, num_procs):
340362
341363 self .hooks .pre_comm (step = self .S , level_number = 0 )
342364
343- if self .req_send [0 ] is not None :
344- self .req_send [0 ].wait ()
365+ if self .send_thread [0 ] is not None :
366+ self .send_thread [0 ].join ()
345367 self .S .levels [0 ].sweep .compute_end_point ()
346368
347369 if not self .S .status .last and self .params .fine_comm :
348370 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
349371 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
350372 0 , self .S .status .iter ))
351- self .req_send [0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
373+ self .send_thread [0 ] = threading .Thread (target = self .send , args = (self .S , 0 , 121212 , comm , ))
374+ self .send_thread [0 ].start ()
375+ self .send_thread [0 ].join ()
376+ # self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
377+ # self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0], ))
378+ # self.wait_thread[0].start()
352379
353380 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
354381 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
355382 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
356383 0 , self .S .status .iter ))
357- self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
384+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = 121212 , comm = comm )
358385
359386 self .hooks .post_comm (step = self .S , level_number = 0 )
360387
@@ -371,24 +398,26 @@ def pfasst(self, comm, num_procs):
371398
372399 self .hooks .pre_comm (step = self .S , level_number = 0 )
373400
374- # check if an open request of the status send is pending
375- if self .req_status is not None :
376- self .req_status .wait ()
377-
378- # recv status
379- if not self .S .status .first and not self .S .status .prev_done :
380- self .S .status .prev_done = comm .recv (source = self .S .prev , tag = 99 )
381- self .logger .debug ('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
382- (self .S .status .prev_done , self .S .status .slot , self .S .time , self .S .next ,
383- 99 , self .S .status .iter ))
384- self .S .status .done = self .S .status .done and self .S .status .prev_done
385-
386- # send status forward
387- if not self .S .status .last :
388- self .logger .debug ('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
389- (self .S .status .done , self .S .status .slot , self .S .time , self .S .next ,
390- 99 , self .S .status .iter ))
391- self .req_status = comm .isend (self .S .status .done , dest = self .S .next , tag = 99 )
401+ # # check if an open request of the status send is pending
402+ # if self.wait_thread_stat is not None:
403+ # self.wait_thread_stat.join()
404+ #
405+ # # recv status
406+ # if not self.S.status.first and not self.S.status.prev_done:
407+ # self.S.status.prev_done = comm.recv(source=self.S.prev, tag=99)
408+ # self.logger.debug('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
409+ # (self.S.status.prev_done, self.S.status.slot, self.S.time, self.S.next,
410+ # 99, self.S.status.iter))
411+ # self.S.status.done = self.S.status.done and self.S.status.prev_done
412+ #
413+ # # send status forward
414+ # if not self.S.status.last:
415+ # self.logger.debug('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
416+ # (self.S.status.done, self.S.status.slot, self.S.time, self.S.next,
417+ # 99, self.S.status.iter))
418+ # self.req_status = comm.isend(self.S.status.done, dest=self.S.next, tag=99)
419+ # self.wait_thread_stat = threading.Thread(target=self.wait_stat, args=(self.req_status, ))
420+ # self.wait_thread_stat.start()
392421
393422 self .hooks .post_comm (step = self .S , level_number = 0 , add_to_stats = True )
394423
@@ -411,7 +440,7 @@ def pfasst(self, comm, num_procs):
411440
412441 else :
413442
414- # Need to finish alll pending isend requests. These will occur for the first active process, since
443+ # Need to finish all pending isend requests. These will occur for the first active process, since
415444 # in the last iteration the wait statement will not be called ("send and forget")
416445 for req in self .req_send :
417446 if req is not None :
@@ -435,22 +464,27 @@ def pfasst(self, comm, num_procs):
435464
436465 self .hooks .pre_comm (step = self .S , level_number = 0 )
437466
438- if self .req_send [0 ] is not None :
439- self .req_send [0 ].wait ()
467+ if self .send_thread [0 ] is not None :
468+ self .send_thread [0 ].join ()
440469 self .S .levels [0 ].sweep .compute_end_point ()
441470
442471 if not self .S .status .last and self .params .fine_comm :
443472 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
444473 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
445474 0 , self .S .status .iter ))
446- self .req_send [0 ] = self .S .levels [0 ].uend .isend (dest = self .S .next , tag = self .S .status .iter , comm = comm )
475+ self .send_thread [0 ] = threading .Thread (target = self .send , args = (self .S , 0 , 232323 , comm ,))
476+ self .send_thread [0 ].start ()
477+ self .send_thread [0 ].join ()
478+
479+ # self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
480+ # self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0],))
481+ # self.wait_thread[0].start()
447482
448483 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
449484 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
450485 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
451486 0 , self .S .status .iter ))
452- self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = self .S .status .iter , comm = comm )
453-
487+ self .recv (target = self .S .levels [0 ], source = self .S .prev , tag = 232323 , comm = comm )
454488 self .hooks .post_comm (step = self .S , level_number = 0 , add_to_stats = (k == nsweeps - 1 ))
455489
456490 self .hooks .pre_sweep (step = self .S , level_number = 0 )
@@ -476,22 +510,26 @@ def pfasst(self, comm, num_procs):
476510
477511 self .hooks .pre_comm (step = self .S , level_number = l )
478512
479- if self .req_send [l ] is not None :
480- self .req_send [l ].wait ()
513+ if self .send_thread [l ] is not None :
514+ self .send_thread [l ].join ()
481515 self .S .levels [l ].sweep .compute_end_point ()
482516
483517 if not self .S .status .last and self .params .fine_comm :
484518 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
485519 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
486520 l , self .S .status .iter ))
487- self .req_send [l ] = self .S .levels [l ].uend .isend (dest = self .S .next , tag = self .S .status .iter ,
488- comm = comm )
521+ self .send_thread [l ] = threading .Thread (target = self .send , args = (self .S , l , 343434 , comm ,))
522+ self .send_thread [l ].start ()
523+ # self.req_send[l] = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter,
524+ # comm=comm)
525+ # self.wait_thread[l] = threading.Thread(target=self.wait, args=(self.req_send[l],))
526+ # self.wait_thread[l].start()
489527
490528 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
491529 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
492530 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
493531 l , self .S .status .iter ))
494- self .recv (target = self .S .levels [l ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
532+ self .recv (target = self .S .levels [l ], source = self .S .prev , tag = 343434 , comm = comm )
495533
496534 self .hooks .post_comm (step = self .S , level_number = l )
497535
@@ -516,7 +554,7 @@ def pfasst(self, comm, num_procs):
516554 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
517555 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
518556 len (self .S .levels ) - 1 , self .S .status .iter ))
519- self .recv (target = self .S .levels [- 1 ], source = self .S .prev , tag = self . S . status . iter , comm = comm )
557+ self .recv (target = self .S .levels [- 1 ], source = self .S .prev , tag = 454545 , comm = comm )
520558 self .hooks .post_comm (step = self .S , level_number = len (self .S .levels ) - 1 )
521559
522560 # do the sweep
@@ -535,7 +573,7 @@ def pfasst(self, comm, num_procs):
535573 self .logger .debug ('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
536574 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
537575 len (self .S .levels ) - 1 , self .S .status .iter ))
538- self .S .levels [- 1 ].uend .send (dest = self .S .next , tag = self . S . status . iter , comm = comm )
576+ self .S .levels [- 1 ].uend .send (dest = self .S .next , tag = 454545 , comm = comm )
539577 self .hooks .post_comm (step = self .S , level_number = len (self .S .levels ) - 1 , add_to_stats = True )
540578
541579 # update stage
@@ -563,23 +601,27 @@ def pfasst(self, comm, num_procs):
563601
564602 self .hooks .pre_comm (step = self .S , level_number = l - 1 )
565603
566- if self .req_send [l - 1 ] is not None :
567- self .req_send [l - 1 ].wait ()
604+ if self .wait_thread [l - 1 ] is not None :
605+ self .wait_thread [l - 1 ].join ()
568606 self .S .levels [l - 1 ].sweep .compute_end_point ()
569607
570608 if not self .S .status .last and self .params .fine_comm :
571609 self .logger .debug ('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
572610 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .next ,
573611 l - 1 , self .S .status .iter ))
574- self .req_send [l - 1 ] = self .S .levels [l - 1 ].uend .isend (dest = self .S .next ,
575- tag = self .S .status .iter ,
576- comm = comm )
612+ self .send_thread [l - 1 ] = threading .Thread (target = self .send , args = (self .S , l - 1 , 565656 , comm ,))
613+ self .send_thread [l - 1 ].start ()
614+ # self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(dest=self.S.next,
615+ # tag=self.S.status.iter,
616+ # comm=comm)
617+ # self.wait_thread[l - 1] = threading.Thread(target=self.wait, args=(self.req_send[l - 1],))
618+ # self.wait_thread[l - 1].start()
577619
578620 if not self .S .status .first and not self .S .status .prev_done and self .params .fine_comm :
579621 self .logger .debug ('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
580622 (self .S .status .slot , self .S .status .stage , self .S .time , self .S .prev ,
581623 l - 1 , self .S .status .iter ))
582- self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = self . S . status . iter ,
624+ self .recv (target = self .S .levels [l - 1 ], source = self .S .prev , tag = 565656 ,
583625 comm = comm )
584626
585627 self .hooks .post_comm (step = self .S , level_number = l - 1 , add_to_stats = (k == nsweeps - 1 ))
0 commit comments