Skip to content

Commit c307d5d

Browse files
committed
make raises_remote usable outside ClientTest
1 parent 331e1b4 commit c307d5d

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

ipyparallel/tests/clienttest.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import os
55
import signal
6+
import sys
67
import time
8+
from contextlib import contextmanager
79

810
import pytest
911
import zmq
@@ -50,7 +52,6 @@ def generate_output():
5052
a rich displayable object.
5153
"""
5254

53-
import sys
5455
from IPython.core.display import display, HTML, Math
5556

5657
print("stdout")
@@ -83,6 +84,28 @@ def skip_without_names(f, *args, **kwargs):
8384
return skip_without_names
8485

8586

87+
@contextmanager
88+
def raises_remote(etype):
89+
if isinstance(etype, str):
90+
# allow Exception or 'Exception'
91+
expected_ename = etype
92+
else:
93+
expected_ename = etype.__name__
94+
95+
try:
96+
try:
97+
yield
98+
except error.CompositeError as e:
99+
e.raise_exception()
100+
except error.RemoteError as e:
101+
assert (
102+
expected_ename == e.ename
103+
), f"Should have raised {expected_ename}, but raised {e.ename}"
104+
105+
else:
106+
pytest.fail("should have raised a RemoteError")
107+
108+
86109
# -------------------------------------------------------------------------------
87110
# Classes
88111
# -------------------------------------------------------------------------------
@@ -130,19 +153,8 @@ def connect_client(self):
130153
return c
131154

132155
def assertRaisesRemote(self, etype, f, *args, **kwargs):
133-
try:
134-
try:
135-
f(*args, **kwargs)
136-
except error.CompositeError as e:
137-
e.raise_exception()
138-
except error.RemoteError as e:
139-
self.assertEqual(
140-
etype.__name__,
141-
e.ename,
142-
"Should have raised %r, but raised %r" % (etype.__name__, e.ename),
143-
)
144-
else:
145-
self.fail("should have raised a RemoteError")
156+
with raises_remote(etype):
157+
f(*args, **kwargs)
146158

147159
def _wait_for(self, f, timeout=10):
148160
"""wait for a condition"""

0 commit comments

Comments
 (0)