Skip to content

Commit f093b39

Browse files
committed
more test coverage for Cluster API
reasonably complete now fix various little things along the way
1 parent d0505f4 commit f093b39

File tree

4 files changed

+104
-31
lines changed

4 files changed

+104
-31
lines changed

ipyparallel/cluster/cluster.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import IPython
2020
import traitlets.log
2121
from IPython.core.profiledir import ProfileDir
22+
from IPython.core.profiledir import ProfileDirError
2223
from traitlets import Any
2324
from traitlets import default
2425
from traitlets import Dict
@@ -68,9 +69,13 @@ def _default_profile_dir(self):
6869
ip = IPython.get_ipython()
6970
if ip is not None:
7071
return ip.profile_dir.location
71-
return ProfileDir.find_profile_dir_by_name(
72-
IPython.paths.get_ipython_dir(), name=self.profile or 'default'
73-
).location
72+
ipython_dir = IPython.paths.get_ipython_dir()
73+
profile_name = self.profile or 'default'
74+
try:
75+
pd = ProfileDir.find_profile_dir_by_name(ipython_dir, name=profile_name)
76+
except ProfileDirError:
77+
pd = ProfileDir.create_profile_dir_by_name(ipython_dir, name=profile_name)
78+
return pd.location
7479

7580
profile = Unicode(
7681
"",
@@ -215,16 +220,22 @@ def _default_log(self):
215220
_engine_sets = Dict()
216221

217222
def __repr__(self):
218-
profile_dir = self.profile_dir
219-
home_dir = os.path.expanduser("~")
220-
if profile_dir.startswith(home_dir + os.path.sep):
221-
# truncate $HOME/. -> ~/...
222-
profile_dir = "~" + profile_dir[len(home_dir) :]
223223

224224
fields = {
225225
"cluster_id": repr(self.cluster_id),
226-
"profile_dir": repr(profile_dir),
227226
}
227+
profile_dir = self.profile_dir
228+
profile_prefix = os.path.join(IPython.paths.get_ipython_dir(), "profile_")
229+
if profile_dir.startswith(profile_prefix):
230+
fields["profile"] = repr(profile_dir[len(profile_prefix) :])
231+
else:
232+
home_dir = os.path.expanduser("~")
233+
234+
if profile_dir.startswith(home_dir + os.path.sep):
235+
# truncate $HOME/. -> ~/...
236+
profile_dir = "~" + profile_dir[len(home_dir) :]
237+
fields["profile_dir"] = repr(profile_dir)
238+
228239
if self._controller:
229240
fields["controller"] = "<running>"
230241
if self._engine_sets:
@@ -473,7 +484,7 @@ class ClusterManager(LoggingConfigurable):
473484

474485
_clusters = Dict(help="My cluster objects")
475486

476-
def load_clusters(self, serialized_state):
487+
def from_dict(self, serialized_state):
477488
"""Load serialized cluster state"""
478489
raise NotImplementedError("Serializing clusters not implemented")
479490

@@ -484,12 +495,13 @@ def list_clusters(self):
484495
# just cluster ids for now
485496
return sorted(self._clusters)
486497

487-
def new_cluster(self, cluster_cls, **kwargs):
498+
def new_cluster(self, **kwargs):
488499
"""Create a new cluster"""
489-
cluster = Cluster(parent=self)
500+
cluster = Cluster(parent=self, **kwargs)
490501
if cluster.cluster_id in self._clusters:
491502
raise KeyError(f"Cluster {cluster.cluster_id} already exists!")
492-
self._clusters[cluster]
503+
self._clusters[cluster.cluster_id] = cluster
504+
return cluster
493505

494506
def get_cluster(self, cluster_id):
495507
"""Get a Cluster object by id"""
@@ -499,19 +511,3 @@ def remove_cluster(self, cluster_id):
499511
"""Delete a cluster by id"""
500512
# TODO: check running?
501513
del self._clusters[cluster_id]
502-
503-
def _cluster_method(self, method_name, cluster_id, *args, **kwargs):
504-
"""Wrapper around single-cluster methods
505-
506-
Defines ClusterManager.method(cluster_id, ...)
507-
508-
which returns ClusterManager.clusters[cluster_id].method(...)
509-
"""
510-
cluster = self._clusters[cluster_id]
511-
method = getattr(cluster, method_name)
512-
return method(*args, **kwargs)
513-
514-
def __getattr__(self, key):
515-
if key in Cluster.__dict__:
516-
return partial(self._cluster_method, key)
517-
return super().__getattr__(self, key)

ipyparallel/cluster/launcher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def __init__(self, work_dir=u'.', config=None, **kwargs):
412412

413413
def start(self, n):
414414
"""Start n engines by profile or profile_dir."""
415+
self.n = n
415416
dlist = []
416417
for i in range(n):
417418
if i > 0:

ipyparallel/tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""pytest fixtures"""
22
import inspect
3+
import os
4+
from tempfile import TemporaryDirectory
5+
from unittest import mock
36

47
import pytest
58
from IPython.terminal.interactiveshell import TerminalInteractiveShell
@@ -9,6 +12,13 @@
912
from . import teardown
1013

1114

15+
@pytest.fixture(autouse=True, scope="session")
16+
def temp_ipython():
17+
with TemporaryDirectory(suffix="dotipython") as td:
18+
with mock.patch.dict(os.environ, {"IPYTHONDIR": td}):
19+
yield
20+
21+
1222
def pytest_collection_modifyitems(items):
1323
"""This function is automatically run by pytest passing all collected test
1424
functions.

ipyparallel/tests/test_cluster.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import os
34
import shutil
45
import signal
56
import sys
@@ -12,6 +13,10 @@
1213
from ipyparallel import cluster
1314
from ipyparallel.cluster.launcher import find_launcher_class
1415

16+
_engine_launcher_classes = ["Local"]
17+
if shutil.which("mpiexec"):
18+
_engine_launcher_classes.append("MPI")
19+
1520

1621
@pytest.fixture
1722
def Cluster(request):
@@ -52,6 +57,13 @@ async def test_cluster_id(Cluster):
5257
assert cluster.cluster_id == 'abc'
5358

5459

60+
async def test_ipython_log(ipython):
61+
c = cluster.Cluster(parent=ipython)
62+
assert c.log.name == f"{cluster.Cluster.__module__}.{c.cluster_id}"
63+
assert len(c.log.handlers) == 1
64+
assert c.log.handlers[0].stream is sys.stdout
65+
66+
5567
async def test_start_stop_controller(Cluster):
5668
cluster = Cluster()
5769
await cluster.start_controller()
@@ -72,7 +84,7 @@ async def test_start_stop_controller(Cluster):
7284
# TODO: test file cleanup
7385

7486

75-
@pytest.mark.parametrize("engine_launcher_class", ["Local", "MPI"])
87+
@pytest.mark.parametrize("engine_launcher_class", _engine_launcher_classes)
7688
async def test_start_stop_engines(Cluster, engine_launcher_class):
7789
cluster = Cluster(engine_launcher_class=engine_launcher_class)
7890
await cluster.start_controller()
@@ -89,7 +101,23 @@ async def test_start_stop_engines(Cluster, engine_launcher_class):
89101
await cluster.stop_controller()
90102

91103

92-
@pytest.mark.parametrize("engine_launcher_class", ["Local", "MPI"])
104+
@pytest.mark.parametrize("engine_launcher_class", _engine_launcher_classes)
105+
async def test_start_stop_cluster(Cluster, engine_launcher_class):
106+
n = 2
107+
cluster = Cluster(engine_launcher_class=engine_launcher_class, n=n)
108+
await cluster.start_cluster()
109+
controller = cluster._controller
110+
assert controller is not None
111+
assert len(cluster._engine_sets) == 1
112+
113+
rc = cluster.connect_client()
114+
rc.wait_for_engines(n, timeout=10)
115+
await cluster.stop_cluster()
116+
assert cluster._controller is None
117+
assert cluster._engine_sets == {}
118+
119+
120+
@pytest.mark.parametrize("engine_launcher_class", _engine_launcher_classes)
93121
async def test_signal_engines(Cluster, engine_launcher_class):
94122
cluster = Cluster(engine_launcher_class=engine_launcher_class)
95123
await cluster.start_controller()
@@ -117,6 +145,25 @@ async def test_signal_engines(Cluster, engine_launcher_class):
117145
await cluster.stop_controller()
118146

119147

148+
@pytest.mark.parametrize("engine_launcher_class", _engine_launcher_classes)
149+
async def test_restart_engines(Cluster, engine_launcher_class):
150+
n = 3
151+
async with Cluster(engine_launcher_class=engine_launcher_class, n=n) as rc:
152+
cluster = rc.cluster
153+
engine_set_id = next(iter(cluster._engine_sets))
154+
engine_set = cluster._engine_sets[engine_set_id]
155+
assert rc.ids == list(range(n))
156+
before_pids = rc[:].apply_sync(os.getpid)
157+
await cluster.restart_engines()
158+
# wait for unregister
159+
while any(eid in rc.ids for eid in range(n)):
160+
await asyncio.sleep(0.1)
161+
# wait for register
162+
rc.wait_for_engines(n, timeout=10)
163+
after_pids = rc[:].apply_sync(os.getpid)
164+
assert set(after_pids).intersection(before_pids) == set()
165+
166+
120167
async def test_async_with(Cluster):
121168
async with Cluster(n=5) as rc:
122169
assert sorted(rc.ids) == list(range(5))
@@ -160,3 +207,22 @@ async def test_cluster_repr(Cluster):
160207
repr(c)
161208
== "<Cluster(cluster_id='test', profile_dir='/tmp', controller=<running>, engine_sets=['engineid'])>"
162209
)
210+
211+
212+
async def test_cluster_manager():
213+
m = cluster.ClusterManager()
214+
assert m.list_clusters() == []
215+
c = m.new_cluster(profile_dir="/tmp")
216+
assert c.profile_dir == "/tmp"
217+
assert m.get_cluster(c.cluster_id) is c
218+
with pytest.raises(KeyError):
219+
m.get_cluster("nosuchcluster")
220+
221+
with pytest.raises(KeyError):
222+
m.new_cluster(cluster_id=c.cluster_id)
223+
224+
assert m.list_clusters() == [c.cluster_id]
225+
m.remove_cluster(c.cluster_id)
226+
assert m.list_clusters() == []
227+
with pytest.raises(KeyError):
228+
m.remove_cluster("nosuchcluster")

0 commit comments

Comments
 (0)