|
26 | 26 | except ModuleNotFoundError:
|
27 | 27 | with_distributed = False
|
28 | 28 |
|
| 29 | +try: |
| 30 | + import mpi4py.futures |
| 31 | + with_mpi4py = True |
| 32 | +except ModuleNotFoundError: |
| 33 | + with_mpi4py = False |
| 34 | + |
29 | 35 | with suppress(ModuleNotFoundError):
|
30 | 36 | import uvloop
|
31 | 37 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
@@ -66,7 +72,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
|
66 | 72 | the learner as its sole argument, and return True when we should
|
67 | 73 | stop requesting more points.
|
68 | 74 | executor : `concurrent.futures.Executor`, `distributed.Client`,\
|
69 |
| - or `ipyparallel.Client`, optional |
| 75 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
70 | 76 | The executor in which to evaluate the function to be learned.
|
71 | 77 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
|
72 | 78 | is used on Unix systems while on Windows a `distributed.Client`
|
@@ -281,7 +287,7 @@ class BlockingRunner(BaseRunner):
|
281 | 287 | the learner as its sole argument, and return True when we should
|
282 | 288 | stop requesting more points.
|
283 | 289 | executor : `concurrent.futures.Executor`, `distributed.Client`,\
|
284 |
| - or `ipyparallel.Client`, optional |
| 290 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
285 | 291 | The executor in which to evaluate the function to be learned.
|
286 | 292 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
|
287 | 293 | is used on Unix systems while on Windows a `distributed.Client`
|
@@ -386,7 +392,7 @@ class AsyncRunner(BaseRunner):
|
386 | 392 | stop requesting more points. If not provided, the runner will run
|
387 | 393 | forever, or until ``self.task.cancel()`` is called.
|
388 | 394 | executor : `concurrent.futures.Executor`, `distributed.Client`,\
|
389 |
| - or `ipyparallel.Client`, optional |
| 395 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
390 | 396 | The executor in which to evaluate the function to be learned.
|
391 | 397 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
|
392 | 398 | is used on Unix systems while on Windows a `distributed.Client`
|
@@ -693,6 +699,9 @@ def _get_ncores(ex):
|
693 | 699 | return 1
|
694 | 700 | elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
|
695 | 701 | return sum(n for n in ex._client.ncores().values())
|
| 702 | + elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor): |
| 703 | + ex.bootup() # wait until all workers are up and running |
| 704 | + return ex._pool.size # not public API! |
696 | 705 | else:
|
697 | 706 | raise TypeError('Cannot get number of cores for {}'
|
698 | 707 | .format(ex.__class__))
|
0 commit comments