Skip to content

Commit e1988f2

Browse files
committed
playing around, nothing works..
1 parent bfc8cbe commit e1988f2

File tree

5 files changed

+767
-66
lines changed

5 files changed

+767
-66
lines changed

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 85 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from mpi4py import MPI
3+
import threading
34

45
from pySDC.core.Controller import controller
56
from pySDC.core.Errors import ControllerError
@@ -31,8 +32,10 @@ def __init__(self, controller_params, description, comm):
3132

3233
# pass communicator for future use
3334
self.comm = comm
34-
# add request handler for status send
35+
# add request and thread handler for status send
3536
self.req_status = None
37+
self.wait_thread_stat = None
38+
self.send_thread_stat = None
3639

3740
num_procs = self.comm.Get_size()
3841
rank = self.comm.Get_rank()
@@ -45,8 +48,10 @@ def __init__(self, controller_params, description, comm):
4548

4649
num_levels = len(self.S.levels)
4750

48-
# add request handle container for isend
51+
# add request handle and thread container for isend
4952
self.req_send = [None] * num_levels
53+
self.wait_thread = [None] * num_levels
54+
self.send_thread = [None] * num_levels
5055

5156
if num_procs > 1 and num_levels > 1:
5257
for L in self.S.levels:
@@ -170,7 +175,11 @@ def restart_block(self, size, time, u0):
170175
for l in self.S.levels:
171176
l.tag = None
172177
self.req_status = None
178+
self.wait_thread_stat = None
179+
self.send_thread_stat = None
173180
self.req_send = [None] * len(self.S.levels)
181+
self.wait_thread = [None] * len(self.S.levels)
182+
self.send_thread = [None] * len(self.S.levels)
174183
self.S.status.prev_done = False
175184

176185
self.S.status.time_size = size
@@ -194,6 +203,19 @@ def recv(target, source, tag=None, comm=None):
194203
# re-evaluate f on left interval boundary
195204
target.f[0] = target.prob.eval_f(target.u[0], target.time)
196205

206+
@staticmethod
207+
def wait(req):
208+
req.Wait()
209+
210+
@staticmethod
211+
def wait_stat(req):
212+
req.wait()
213+
214+
@staticmethod
215+
def send(S, lvl, tag, comm):
216+
req = S.levels[lvl].uend.isend(dest=S.next, tag=tag, comm=comm)
217+
req.Wait()
218+
197219
def predictor(self, comm):
198220
"""
199221
Predictor function, extracted from the stepwise implementation (will be also used by matrix sweppers)
@@ -340,21 +362,26 @@ def pfasst(self, comm, num_procs):
340362

341363
self.hooks.pre_comm(step=self.S, level_number=0)
342364

343-
if self.req_send[0] is not None:
344-
self.req_send[0].wait()
365+
if self.send_thread[0] is not None:
366+
self.send_thread[0].join()
345367
self.S.levels[0].sweep.compute_end_point()
346368

347369
if not self.S.status.last and self.params.fine_comm:
348370
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
349371
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
350372
0, self.S.status.iter))
351-
self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
373+
self.send_thread[0] = threading.Thread(target=self.send, args=(self.S, 0, 121212, comm, ))
374+
self.send_thread[0].start()
375+
self.send_thread[0].join()
376+
# self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
377+
# self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0], ))
378+
# self.wait_thread[0].start()
352379

353380
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
354381
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
355382
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
356383
0, self.S.status.iter))
357-
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
384+
self.recv(target=self.S.levels[0], source=self.S.prev, tag=121212, comm=comm)
358385

359386
self.hooks.post_comm(step=self.S, level_number=0)
360387

@@ -371,24 +398,26 @@ def pfasst(self, comm, num_procs):
371398

372399
self.hooks.pre_comm(step=self.S, level_number=0)
373400

374-
# check if an open request of the status send is pending
375-
if self.req_status is not None:
376-
self.req_status.wait()
377-
378-
# recv status
379-
if not self.S.status.first and not self.S.status.prev_done:
380-
self.S.status.prev_done = comm.recv(source=self.S.prev, tag=99)
381-
self.logger.debug('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
382-
(self.S.status.prev_done, self.S.status.slot, self.S.time, self.S.next,
383-
99, self.S.status.iter))
384-
self.S.status.done = self.S.status.done and self.S.status.prev_done
385-
386-
# send status forward
387-
if not self.S.status.last:
388-
self.logger.debug('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
389-
(self.S.status.done, self.S.status.slot, self.S.time, self.S.next,
390-
99, self.S.status.iter))
391-
self.req_status = comm.isend(self.S.status.done, dest=self.S.next, tag=99)
401+
# # check if an open request of the status send is pending
402+
# if self.wait_thread_stat is not None:
403+
# self.wait_thread_stat.join()
404+
#
405+
# # recv status
406+
# if not self.S.status.first and not self.S.status.prev_done:
407+
# self.S.status.prev_done = comm.recv(source=self.S.prev, tag=99)
408+
# self.logger.debug('recv status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
409+
# (self.S.status.prev_done, self.S.status.slot, self.S.time, self.S.next,
410+
# 99, self.S.status.iter))
411+
# self.S.status.done = self.S.status.done and self.S.status.prev_done
412+
#
413+
# # send status forward
414+
# if not self.S.status.last:
415+
# self.logger.debug('isend status: status %s, process %s, time %s, target %s, tag %s, iter %s' %
416+
# (self.S.status.done, self.S.status.slot, self.S.time, self.S.next,
417+
# 99, self.S.status.iter))
418+
# self.req_status = comm.isend(self.S.status.done, dest=self.S.next, tag=99)
419+
# self.wait_thread_stat = threading.Thread(target=self.wait_stat, args=(self.req_status, ))
420+
# self.wait_thread_stat.start()
392421

393422
self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True)
394423

@@ -411,7 +440,7 @@ def pfasst(self, comm, num_procs):
411440

412441
else:
413442

414-
# Need to finish alll pending isend requests. These will occur for the first active process, since
443+
# Need to finish all pending isend requests. These will occur for the first active process, since
415444
# in the last iteration the wait statement will not be called ("send and forget")
416445
for req in self.req_send:
417446
if req is not None:
@@ -435,22 +464,27 @@ def pfasst(self, comm, num_procs):
435464

436465
self.hooks.pre_comm(step=self.S, level_number=0)
437466

438-
if self.req_send[0] is not None:
439-
self.req_send[0].wait()
467+
if self.send_thread[0] is not None:
468+
self.send_thread[0].join()
440469
self.S.levels[0].sweep.compute_end_point()
441470

442471
if not self.S.status.last and self.params.fine_comm:
443472
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
444473
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
445474
0, self.S.status.iter))
446-
self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
475+
self.send_thread[0] = threading.Thread(target=self.send, args=(self.S, 0, 232323, comm,))
476+
self.send_thread[0].start()
477+
self.send_thread[0].join()
478+
479+
# self.req_send[0] = self.S.levels[0].uend.isend(dest=self.S.next, tag=self.S.status.iter, comm=comm)
480+
# self.wait_thread[0] = threading.Thread(target=self.wait, args=(self.req_send[0],))
481+
# self.wait_thread[0].start()
447482

448483
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
449484
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
450485
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
451486
0, self.S.status.iter))
452-
self.recv(target=self.S.levels[0], source=self.S.prev, tag=self.S.status.iter, comm=comm)
453-
487+
self.recv(target=self.S.levels[0], source=self.S.prev, tag=232323, comm=comm)
454488
self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=(k == nsweeps - 1))
455489

456490
self.hooks.pre_sweep(step=self.S, level_number=0)
@@ -476,22 +510,26 @@ def pfasst(self, comm, num_procs):
476510

477511
self.hooks.pre_comm(step=self.S, level_number=l)
478512

479-
if self.req_send[l] is not None:
480-
self.req_send[l].wait()
513+
if self.send_thread[l] is not None:
514+
self.send_thread[l].join()
481515
self.S.levels[l].sweep.compute_end_point()
482516

483517
if not self.S.status.last and self.params.fine_comm:
484518
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
485519
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
486520
l, self.S.status.iter))
487-
self.req_send[l] = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter,
488-
comm=comm)
521+
self.send_thread[l] = threading.Thread(target=self.send, args=(self.S, l, 343434, comm,))
522+
self.send_thread[l].start()
523+
# self.req_send[l] = self.S.levels[l].uend.isend(dest=self.S.next, tag=self.S.status.iter,
524+
# comm=comm)
525+
# self.wait_thread[l] = threading.Thread(target=self.wait, args=(self.req_send[l],))
526+
# self.wait_thread[l].start()
489527

490528
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
491529
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
492530
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
493531
l, self.S.status.iter))
494-
self.recv(target=self.S.levels[l], source=self.S.prev, tag=self.S.status.iter, comm=comm)
532+
self.recv(target=self.S.levels[l], source=self.S.prev, tag=343434, comm=comm)
495533

496534
self.hooks.post_comm(step=self.S, level_number=l)
497535

@@ -516,7 +554,7 @@ def pfasst(self, comm, num_procs):
516554
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
517555
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
518556
len(self.S.levels) - 1, self.S.status.iter))
519-
self.recv(target=self.S.levels[-1], source=self.S.prev, tag=self.S.status.iter, comm=comm)
557+
self.recv(target=self.S.levels[-1], source=self.S.prev, tag=454545, comm=comm)
520558
self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1)
521559

522560
# do the sweep
@@ -535,7 +573,7 @@ def pfasst(self, comm, num_procs):
535573
self.logger.debug('send data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
536574
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
537575
len(self.S.levels) - 1, self.S.status.iter))
538-
self.S.levels[-1].uend.send(dest=self.S.next, tag=self.S.status.iter, comm=comm)
576+
self.S.levels[-1].uend.send(dest=self.S.next, tag=454545, comm=comm)
539577
self.hooks.post_comm(step=self.S, level_number=len(self.S.levels) - 1, add_to_stats=True)
540578

541579
# update stage
@@ -563,23 +601,27 @@ def pfasst(self, comm, num_procs):
563601

564602
self.hooks.pre_comm(step=self.S, level_number=l - 1)
565603

566-
if self.req_send[l - 1] is not None:
567-
self.req_send[l - 1].wait()
604+
if self.wait_thread[l - 1] is not None:
605+
self.wait_thread[l - 1].join()
568606
self.S.levels[l - 1].sweep.compute_end_point()
569607

570608
if not self.S.status.last and self.params.fine_comm:
571609
self.logger.debug('isend data: process %s, stage %s, time %s, target %s, tag %s, iter %s' %
572610
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.next,
573611
l - 1, self.S.status.iter))
574-
self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(dest=self.S.next,
575-
tag=self.S.status.iter,
576-
comm=comm)
612+
self.send_thread[l - 1] = threading.Thread(target=self.send, args=(self.S, l - 1, 565656, comm,))
613+
self.send_thread[l - 1].start()
614+
# self.req_send[l - 1] = self.S.levels[l - 1].uend.isend(dest=self.S.next,
615+
# tag=self.S.status.iter,
616+
# comm=comm)
617+
# self.wait_thread[l - 1] = threading.Thread(target=self.wait, args=(self.req_send[l - 1],))
618+
# self.wait_thread[l - 1].start()
577619

578620
if not self.S.status.first and not self.S.status.prev_done and self.params.fine_comm:
579621
self.logger.debug('recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' %
580622
(self.S.status.slot, self.S.status.stage, self.S.time, self.S.prev,
581623
l - 1, self.S.status.iter))
582-
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=self.S.status.iter,
624+
self.recv(target=self.S.levels[l - 1], source=self.S.prev, tag=565656,
583625
comm=comm)
584626

585627
self.hooks.post_comm(step=self.S, level_number=l - 1, add_to_stats=(k == nsweeps - 1))

0 commit comments

Comments
 (0)