Skip to content

Commit 0615199

Browse files
committed
Support return_exceptions=True in map, get
Allows exceptions to be returned instead of raised so that partial errors can be handled and successful values from partial success extracted
1 parent 37da0fd commit 0615199

File tree

6 files changed

+181
-38
lines changed

6 files changed

+181
-38
lines changed

ipyparallel/client/asyncresult.py

Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
fname='unknown',
7474
targets=None,
7575
owner=False,
76+
return_exceptions=False,
7677
):
7778
super(AsyncResult, self).__init__()
7879
if not isinstance(children, list):
@@ -81,6 +82,8 @@ def __init__(
8182
else:
8283
self._single_result = False
8384

85+
self._return_exceptions = return_exceptions
86+
8487
if isinstance(children[0], string_types):
8588
self.msg_ids = children
8689
self._children = []
@@ -204,22 +207,35 @@ def _reconstruct_result(self, res):
204207
else:
205208
return res
206209

207-
def get(self, timeout=-1):
210+
def get(self, timeout=None, return_exceptions=None):
208211
"""Return the result when it arrives.
209212
210-
If `timeout` is not ``None`` and the result does not arrive within
211-
`timeout` seconds then ``TimeoutError`` is raised. If the
212-
remote call raised an exception then that exception will be reraised
213-
by get() inside a `RemoteError`.
213+
Arguments:
214+
215+
timeout : int [default None]
216+
If `timeout` is not ``None`` and the result does not arrive within
217+
`timeout` seconds then ``TimeoutError`` is raised. If the
218+
remote call raised an exception then that exception will be reraised
219+
by get() inside a `RemoteError`.
220+
return_exceptions : bool [default False]
221+
If True, return Exceptions instead of raising them.
214222
"""
215223
if not self.ready():
216224
self.wait(timeout)
217225

226+
if return_exceptions is None:
227+
# default to attribute, if AsyncResult was created with return_exceptions=True
228+
return_exceptions = self._return_exceptions
229+
218230
if self._ready:
219231
if self._success:
220232
return self.result()
221233
else:
222-
raise self.exception()
234+
e = self.exception()
235+
if return_exceptions:
236+
return self._reconstruct_result(self._raw_results)
237+
else:
238+
raise e
223239
else:
224240
raise error.TimeoutError("Result not ready.")
225241

@@ -270,22 +286,37 @@ def wait(self, timeout=-1):
270286
def _resolve_result(self, f=None):
271287
if self.done():
272288
return
289+
if f:
290+
results = f.result()
291+
else:
292+
results = list(map(self._client.results.get, self.msg_ids))
293+
294+
# store raw results
295+
self._raw_results = results
296+
273297
try:
274-
if f:
275-
results = f.result()
276-
else:
277-
results = list(map(self._client.results.get, self.msg_ids))
278298
if self._single_result:
279299
r = results[0]
280300
if isinstance(r, Exception):
281301
raise r
282302
else:
283-
results = error.collect_exceptions(results, self._fname)
284-
self._success = True
285-
self.set_result(self._reconstruct_result(results))
303+
results = self._collect_exceptions(results)
286304
except Exception as e:
287305
self._success = False
288306
self.set_exception(e)
307+
else:
308+
self._success = True
309+
self.set_result(self._reconstruct_result(results))
310+
311+
def _collect_exceptions(self, results):
312+
"""Wrap Exceptions in a CompositeError
313+
314+
if self._return_exceptions is True, this is a no-op
315+
"""
316+
if self._return_exceptions:
317+
return results
318+
else:
319+
return error.collect_exceptions(results, self._fname)
289320

290321
def _finalize_result(self, f):
291322
if self.owner:
@@ -424,10 +455,10 @@ def __getitem__(self, key):
424455
"""getitem returns result value(s) if keyed by int/slice, or metadata if key is str."""
425456
if isinstance(key, int):
426457
self._check_ready()
427-
return error.collect_exceptions([self.result()[key]], self._fname)[0]
458+
return self._collect_exceptions([self.result()[key]])[0]
428459
elif isinstance(key, slice):
429460
self._check_ready()
430-
return error.collect_exceptions(self.result()[key], self._fname)
461+
return self._collect_exceptions(self.result()[key])
431462
elif isinstance(key, string_types):
432463
# metadata proxy *does not* require that results are done
433464
self.wait(0)
@@ -473,7 +504,7 @@ def __iter__(self):
473504
for child in self._children:
474505
self._wait_for_child(child, evt=evt)
475506
result = child.result()
476-
error.collect_exceptions([result], self._fname)
507+
self._collect_exceptions([result])
477508
yield result
478509
else:
479510
# already done
@@ -583,15 +614,15 @@ def wait_interactive(self, interval=0.1, timeout=-1, widget=None):
583614
Override default context-detection behavior for whether a widget-based progress bar
584615
should be used.
585616
"""
586-
if timeout is None:
587-
timeout = -1
617+
if timeout and timeout < 0:
618+
timeout = None
588619
N = len(self)
589620
tic = time.perf_counter()
590621
progress_bar = progress(widget=widget, total=N, unit='tasks', desc=self._fname)
591622

592623
n_prev = 0
593624
while not self.ready() and (
594-
timeout < 0 or time.perf_counter() - tic <= timeout
625+
timeout is None or time.perf_counter() - tic <= timeout
595626
):
596627
self.wait(interval)
597628
progress_bar.update(self.progress - n_prev)
@@ -751,25 +782,50 @@ def display_outputs(self, groupby="type", result_only=False):
751782

752783

753784
class AsyncMapResult(AsyncResult):
754-
"""Class for representing results of non-blocking gathers.
785+
"""Class for representing results of non-blocking maps.
755786
756-
This will properly reconstruct the gather.
787+
AsyncMapResult.get() will properly reconstruct gathers into single object.
757788
758-
This class is iterable at any time, and will wait on results as they come.
789+
AsyncMapResult is iterable at any time, and will wait on results as they come.
759790
760791
If ordered=False, then the first results to arrive will come first, otherwise
761792
results will be yielded in the order they were submitted.
762-
763793
"""
764794

765-
def __init__(self, client, children, mapObject, fname='', ordered=True):
795+
def __init__(
796+
self,
797+
client,
798+
children,
799+
mapObject,
800+
fname='',
801+
ordered=True,
802+
return_exceptions=False,
803+
):
766804
self._mapObject = mapObject
767805
self.ordered = ordered
768-
AsyncResult.__init__(self, client, children, fname=fname)
806+
AsyncResult.__init__(
807+
self,
808+
client,
809+
children,
810+
fname=fname,
811+
return_exceptions=return_exceptions,
812+
)
769813
self._single_result = False
770814

771815
def _reconstruct_result(self, res):
772816
"""Perform the gather on the actual results."""
817+
if self._return_exceptions:
818+
if any(isinstance(r, Exception) for r in res):
819+
# running with _return_exceptions,
820+
# cannot reconstruct original
821+
# use simple chain iterable
822+
flattened = []
823+
for r in res:
824+
if isinstance(r, Exception):
825+
flattened.append(r)
826+
else:
827+
flattened.extend(r)
828+
return flattened
773829
return self._mapObject.joinPartitions(res)
774830

775831
# asynchronous iterator:
@@ -786,7 +842,7 @@ def _yield_child_results(self, child):
786842
rlist = child.result()
787843
if not isinstance(rlist, list):
788844
rlist = [rlist]
789-
error.collect_exceptions(rlist, self._fname)
845+
self._collect_exceptions(rlist)
790846
for r in rlist:
791847
yield r
792848

@@ -841,6 +897,8 @@ def _init_futures(self):
841897
def wait(self, timeout=-1):
842898
"""wait for result to complete."""
843899
start = time.time()
900+
if timeout and timeout < 0:
901+
timeout = None
844902
if self._ready:
845903
return
846904
local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
@@ -852,7 +910,7 @@ def wait(self, timeout=-1):
852910
else:
853911
rdict = self._client.result_status(remote_ids, status_only=False)
854912
pending = rdict['pending']
855-
while pending and (timeout < 0 or time.time() < start + timeout):
913+
while pending and (timeout is None or time.time() < start + timeout):
856914
rdict = self._client.result_status(remote_ids, status_only=False)
857915
pending = rdict['pending']
858916
if pending:
@@ -865,11 +923,10 @@ def wait(self, timeout=-1):
865923
results = list(map(self._client.results.get, self.msg_ids))
866924
if self._single_result:
867925
r = results[0]
868-
if isinstance(r, Exception):
926+
if isinstance(r, Exception) and not self._return_exceptions:
869927
raise r
870-
self.set_result(r)
871928
else:
872-
results = error.collect_exceptions(results, self._fname)
929+
results = self._collect_exceptions(results)
873930
self._success = True
874931
self.set_result(self._reconstruct_result(results))
875932
except Exception as e:

ipyparallel/client/remotefunction.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class ParallelFunction(RemoteFunction):
186186
ordered : bool [default: True]
187187
Whether the result should be kept in order. If False,
188188
results become available as they arrive, regardless of submission order.
189+
return_exceptions : bool [default: False]
189190
**flags
190191
remaining kwargs are passed to View.temp_flags
191192
"""
@@ -195,11 +196,20 @@ class ParallelFunction(RemoteFunction):
195196
mapObject = None
196197

197198
def __init__(
198-
self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags
199+
self,
200+
view,
201+
f,
202+
dist='b',
203+
block=None,
204+
chunksize=None,
205+
ordered=True,
206+
return_exceptions=False,
207+
**flags,
199208
):
200209
super(ParallelFunction, self).__init__(view, f, block=block, **flags)
201210
self.chunksize = chunksize
202211
self.ordered = ordered
212+
self.return_exceptions = return_exceptions
203213

204214
mapClass = Map.dists[dist]
205215
self.mapObject = mapClass()
@@ -296,6 +306,7 @@ def __call__(self, *sequences, **kwargs):
296306
self.mapObject,
297307
fname=getname(self.func),
298308
ordered=self.ordered,
309+
return_exceptions=self.return_exceptions,
299310
)
300311

301312
if self.block:

ipyparallel/client/view.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,15 @@ def _really_apply(
12161216

12171217
@sync_results
12181218
@save_ids
1219-
def map(self, f, *sequences, block=None, chunksize=1, ordered=True):
1219+
def map(
1220+
self,
1221+
f,
1222+
*sequences,
1223+
block=None,
1224+
chunksize=1,
1225+
ordered=True,
1226+
return_exceptions=False,
1227+
):
12201228
"""Parallel version of builtin `map`, load-balanced by this View.
12211229
12221230
`block`, and `chunksize` can be specified by keyword only.
@@ -1242,6 +1250,9 @@ def map(self, f, *sequences, block=None, chunksize=1, ordered=True):
12421250
Only applies when iterating through AsyncMapResult as results arrive.
12431251
Has no effect when block=True.
12441252
1253+
return_exceptions: bool [default False]
1254+
Return Exceptions instead of raising on the first exception.
1255+
12451256
Returns
12461257
-------
12471258
if block=False
@@ -1260,7 +1271,12 @@ def map(self, f, *sequences, block=None, chunksize=1, ordered=True):
12601271
assert len(sequences) > 0, "must have some sequences to map onto!"
12611272

12621273
pf = ParallelFunction(
1263-
self, f, block=block, chunksize=chunksize, ordered=ordered
1274+
self,
1275+
f,
1276+
block=block,
1277+
chunksize=chunksize,
1278+
ordered=ordered,
1279+
return_exceptions=return_exceptions,
12641280
)
12651281
return pf.map(*sequences)
12661282

@@ -1270,6 +1286,7 @@ def imap(
12701286
*sequences,
12711287
ordered=True,
12721288
max_outstanding='auto',
1289+
return_exceptions=False,
12731290
):
12741291
"""Parallel version of lazily-evaluated `imap`, load-balanced by this View.
12751292
@@ -1308,6 +1325,9 @@ def imap(
13081325
13091326
Use this to tune how greedily input generator should be consumed.
13101327
1328+
return_exceptions : bool [default False]
1329+
Return Exceptions instead of raising them.
1330+
13111331
Returns
13121332
-------
13131333
@@ -1380,22 +1400,22 @@ def should_yield():
13801400
# yielding immediately means
13811401
if should_yield():
13821402
for ready_ar in wait_for_ready():
1383-
yield ready_ar.get()
1403+
yield ready_ar.get(return_exceptions=return_exceptions)
13841404

13851405
# we've filled the buffer, wait for at least one result before continuing
13861406
if len(outstanding) == max_outstanding:
13871407
for ready_ar in wait_for_ready():
1388-
yield ready_ar.get()
1408+
yield ready_ar.get(return_exceptions=return_exceptions)
13891409

13901410
# yield any remaining results
13911411
if ordered:
13921412
for ar in outstanding:
1393-
yield ar.get()
1413+
yield ar.get(return_exceptions=return_exceptions)
13941414
else:
13951415
while outstanding:
13961416
done, outstanding = concurrent.futures.wait(outstanding)
13971417
for ar in done:
1398-
yield ar.get()
1418+
yield ar.get(return_exceptions=return_exceptions)
13991419

14001420
def register_joblib_backend(self, name='ipyparallel', make_default=False):
14011421
"""Register this View as a joblib parallel backend

ipyparallel/tests/clienttest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def wait_on_engines(self, timeout=5):
137137

138138
def client_wait(self, client, jobs=None, timeout=-1):
139139
"""my wait wrapper, sets a default finite timeout to avoid hangs"""
140-
if timeout < 0:
140+
if timeout is None or timeout < 0:
141141
timeout = self.timeout
142142
return Client.wait(client, jobs, timeout)
143143

0 commit comments

Comments
 (0)