Skip to content

Commit 3d832e5

Browse files
committed
fix broadcast-view notification of engine death
and ensure this is included in test coverage previous crash test only did single-engine target, which doesn't trigger broadcast view's logic
1 parent e3401a3 commit 3d832e5

File tree

3 files changed

+43
-22
lines changed

3 files changed

+43
-22
lines changed

ipyparallel/client/view.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,15 +891,29 @@ def _make_async_result(self, message_future, s_idents, **kwargs):
891891
msg_and_target_id, async_result=True, track=True
892892
)
893893
self.client.outstanding.add(msg_and_target_id)
894+
self.client._outstanding_dict[ident].add(msg_and_target_id)
894895
self.outstanding.add(msg_and_target_id)
895896
futures.append(future[0])
896897
if original_msg_id in self.outstanding:
897898
self.outstanding.remove(original_msg_id)
898899
else:
899900
self.client.outstanding.add(original_msg_id)
901+
for ident in s_idents:
902+
self.client._outstanding_dict[ident].add(original_msg_id)
900903
futures = message_future
901904

902-
return AsyncResult(self.client, futures, owner=True, **kwargs)
905+
ar = AsyncResult(self.client, futures, owner=True, **kwargs)
906+
907+
if self.is_coalescing:
908+
# if coalescing, discard outstanding-tracking when we are done
909+
def _rm_outstanding(_):
910+
for ident in s_idents:
911+
if ident in self.client._outstanding_dict:
912+
self.client._outstanding_dict[ident].discard(original_msg_id)
913+
914+
ar.add_done_callback(_rm_outstanding)
915+
916+
return ar
903917

904918
@sync_results
905919
@save_ids

ipyparallel/tests/clienttest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def crash():
3232
os._exit(1)
3333

3434

35+
def conditional_crash(condition):
36+
"""Ungracefully exit the process"""
37+
if condition:
38+
crash()
39+
40+
3541
def wait(n):
3642
"""sleep for a time"""
3743
import time

ipyparallel/tests/test_view.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,15 @@
1515
from IPython.utils.io import capture_output
1616
from ipython_genutils.py3compat import unicode_type
1717

18-
import ipyparallel as pmod
18+
import ipyparallel as ipp
1919
from .clienttest import ClusterTestCase
20-
from .clienttest import crash
20+
from .clienttest import conditional_crash
2121
from .clienttest import skip_without
2222
from .clienttest import wait
2323
from ipyparallel import AsyncHubResult
2424
from ipyparallel import AsyncMapResult
2525
from ipyparallel import AsyncResult
2626
from ipyparallel import error
27-
from ipyparallel.tests import add_engines
2827
from ipyparallel.util import interactive
2928

3029
point = namedtuple("point", "x y")
@@ -42,13 +41,15 @@ def setUp(self):
4241
def test_z_crash_mux(self):
4342
"""test graceful handling of engine death (direct)"""
4443
self.add_engines(1)
44+
self.minimum_engines(2)
4545
eid = self.client.ids[-1]
46-
ar = self.client[eid].apply_async(crash)
46+
view = self.client[-2:]
47+
view.scatter('should_crash', [False, True], flatten=True)
48+
ar = view.apply_async(conditional_crash, ipp.Reference("should_crash"))
4749
self.assertRaisesRemote(error.EngineError, ar.get, 10)
48-
eid = ar.engine_id
49-
tic = time.time()
50-
while eid in self.client.ids and time.time() - tic < 5:
51-
time.sleep(0.01)
50+
tic = time.perf_counter()
51+
while eid in self.client.ids and time.perf_counter() - tic < 5:
52+
time.sleep(0.05)
5253
assert eid not in self.client.ids
5354

5455
def test_push_pull(self):
@@ -135,7 +136,7 @@ def echo(a=10):
135136

136137
def test_get_result(self):
137138
"""test getting results from the Hub."""
138-
c = pmod.Client(profile='iptest')
139+
c = ipp.Client(profile='iptest')
139140
# self.add_engines(1)
140141
t = c.ids[-1]
141142
v = c[t]
@@ -161,7 +162,7 @@ def test_run_newline(self):
161162
)
162163
v = self.client[-1]
163164
v.run(f.name, block=True)
164-
self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
165+
self.assertEqual(v.apply_sync(lambda f: f(), ipp.Reference('g')), 5)
165166

166167
def test_apply_tracked(self):
167168
"""test tracking for apply"""
@@ -215,7 +216,7 @@ def test_scatter_tracked(self):
215216
def test_remote_reference(self):
216217
v = self.client[-1]
217218
v['a'] = 123
218-
ra = pmod.Reference('a')
219+
ra = ipp.Reference('a')
219220
b = v.apply_sync(lambda x: x, ra)
220221
self.assertEqual(b, 123)
221222

@@ -472,7 +473,7 @@ def test_map_reference(self):
472473
v = self.client[:]
473474
v.scatter('n', self.client.ids, flatten=True)
474475
v.execute("f = lambda x,y: x*y")
475-
rf = pmod.Reference('f')
476+
rf = ipp.Reference('f')
476477
nlist = list(range(10))
477478
mlist = nlist[::-1]
478479
expected = [m * n for m, n in zip(mlist, nlist)]
@@ -484,21 +485,21 @@ def test_apply_reference(self):
484485
v = self.client[:]
485486
v.scatter('n', self.client.ids, flatten=True)
486487
v.execute("f = lambda x: n*x")
487-
rf = pmod.Reference('f')
488+
rf = ipp.Reference('f')
488489
result = v.apply_sync(rf, 5)
489490
expected = [5 * id for id in self.client.ids]
490491
self.assertEqual(result, expected)
491492

492493
def test_eval_reference(self):
493494
v = self.client[self.client.ids[0]]
494495
v['g'] = list(range(5))
495-
rg = pmod.Reference('g[0]')
496+
rg = ipp.Reference('g[0]')
496497
echo = lambda x: x
497498
self.assertEqual(v.apply_sync(echo, rg), 0)
498499

499500
def test_reference_nameerror(self):
500501
v = self.client[self.client.ids[0]]
501-
r = pmod.Reference('elvis_has_left')
502+
r = ipp.Reference('elvis_has_left')
502503
echo = lambda x: x
503504
self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
504505

@@ -745,7 +746,7 @@ def test_can_list_arg(self):
745746
"""args in lists are canned"""
746747
view = self.client[-1]
747748
view['a'] = 128
748-
rA = pmod.Reference('a')
749+
rA = ipp.Reference('a')
749750
ar = view.apply_async(lambda x: x, [rA])
750751
r = ar.get(5)
751752
self.assertEqual(r, [128])
@@ -754,7 +755,7 @@ def test_can_dict_arg(self):
754755
"""args in dicts are canned"""
755756
view = self.client[-1]
756757
view['a'] = 128
757-
rA = pmod.Reference('a')
758+
rA = ipp.Reference('a')
758759
ar = view.apply_async(lambda x: x, dict(foo=rA))
759760
r = ar.get(5)
760761
self.assertEqual(r, dict(foo=128))
@@ -763,7 +764,7 @@ def test_can_list_kwarg(self):
763764
"""kwargs in lists are canned"""
764765
view = self.client[-1]
765766
view['a'] = 128
766-
rA = pmod.Reference('a')
767+
rA = ipp.Reference('a')
767768
ar = view.apply_async(lambda x=5: x, x=[rA])
768769
r = ar.get(5)
769770
self.assertEqual(r, [128])
@@ -772,7 +773,7 @@ def test_can_dict_kwarg(self):
772773
"""kwargs in dicts are canned"""
773774
view = self.client[-1]
774775
view['a'] = 128
775-
rA = pmod.Reference('a')
776+
rA = ipp.Reference('a')
776777
ar = view.apply_async(lambda x=5: x, dict(foo=rA))
777778
r = ar.get(5)
778779
self.assertEqual(r, dict(foo=128))
@@ -782,7 +783,7 @@ def test_map_ref(self):
782783
view = self.client[:]
783784
ranks = sorted(self.client.ids)
784785
view.scatter('rank', ranks, flatten=True)
785-
rrank = pmod.Reference('rank')
786+
rrank = ipp.Reference('rank')
786787

787788
amr = view.map_async(lambda x: x * 2, [rrank] * len(view))
788789
drank = amr.get(5)
@@ -801,7 +802,7 @@ def test_nested_getitem_setitem(self):
801802
),
802803
block=True,
803804
)
804-
ra = pmod.Reference('a')
805+
ra = ipp.Reference('a')
805806

806807
r = view.apply_sync(lambda x: x.b, ra)
807808
self.assertEqual(r, 128)

0 commit comments

Comments
 (0)