Skip to content

Commit b23437f

Browse files
committed
use chunk size for length in AsyncMapResults
len(AMR) and AMR.progress represent the number of inputs, _not_ the number of messages Mainly produces more intuitive reporting of progress for DirectView.map Reporting granularity is unchanged, as each chunk will still only be updated once, but at least it will be the expected number
1 parent d206c91 commit b23437f

File tree

5 files changed

+55
-13
lines changed

5 files changed

+55
-13
lines changed

ipyparallel/client/asyncresult.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class AsyncResult(Future):
7272
owner = False
7373
_last_display_prefix = ""
7474
_stream_trailing_newline = True
75+
_chunk_sizes = None
7576

7677
def __init__(
7778
self,
@@ -81,6 +82,7 @@ def __init__(
8182
targets=None,
8283
owner=False,
8384
return_exceptions=False,
85+
chunk_sizes=None,
8486
):
8587
super().__init__()
8688
if not isinstance(children, list):
@@ -90,6 +92,7 @@ def __init__(
9092
self._single_result = False
9193

9294
self._return_exceptions = return_exceptions
95+
self._chunk_sizes = chunk_sizes or {}
9396

9497
if isinstance(children[0], str):
9598
self.msg_ids = children
@@ -748,8 +751,14 @@ def __iter__(self):
748751
# already done
749752
yield from rlist
750753

754+
@lru_cache()
751755
def __len__(self):
752-
return len(self.msg_ids)
756+
return self._count_chunks(*self.msg_ids)
757+
758+
@lru_cache()
759+
def _count_chunks(self, *msg_ids):
760+
"""Count the granular tasks"""
761+
return sum(self._chunk_sizes.setdefault(msg_id, 1) for msg_id in msg_ids)
753762

754763
# -------------------------------------
755764
# Sugar methods and attributes
@@ -795,7 +804,9 @@ def progress(self):
795804
Fractional progress would be given by 1.0 * ar.progress / len(ar)
796805
"""
797806
self.wait(0)
798-
return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
807+
finished_msg_ids = set(self.msg_ids).intersection(self._client.outstanding)
808+
finished_count = self._count_chunks(*finished_msg_ids)
809+
return len(self) - finished_count
799810

800811
@property
801812
def elapsed(self):
@@ -1069,6 +1080,7 @@ def __init__(
10691080
fname='',
10701081
ordered=True,
10711082
return_exceptions=False,
1083+
chunk_sizes=None,
10721084
):
10731085
self._mapObject = mapObject
10741086
self.ordered = ordered
@@ -1078,6 +1090,7 @@ def __init__(
10781090
children,
10791091
fname=fname,
10801092
return_exceptions=return_exceptions,
1093+
chunk_sizes=chunk_sizes,
10811094
)
10821095
self._single_result = False
10831096

ipyparallel/client/remotefunction.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ def __call__(self, *sequences, **kwargs):
271271

272272
pf = PrePickled(self.func)
273273

274+
chunk_sizes = {}
275+
chunk_size = 1
276+
274277
for index, t in enumerate(targets):
275278
args = []
276279
for seq in sequences:
@@ -279,6 +282,10 @@ def __call__(self, *sequences, **kwargs):
279282

280283
if sum(len(arg) for arg in args) == 0:
281284
continue
285+
286+
if _mapping:
287+
chunk_size = min(len(arg) for arg in args)
288+
282289
args = [PrePickled(arg) for arg in args]
283290

284291
if _mapping:
@@ -292,6 +299,8 @@ def __call__(self, *sequences, **kwargs):
292299
ar = view.apply(f, *args)
293300
ar.owner = False
294301

302+
msg_id = ar.msg_ids[0]
303+
chunk_sizes[msg_id] = chunk_size
295304
futures.extend(ar._children)
296305

297306
r = AsyncMapResult(
@@ -301,6 +310,7 @@ def __call__(self, *sequences, **kwargs):
301310
fname=getname(self.func),
302311
ordered=self.ordered,
303312
return_exceptions=self.return_exceptions,
313+
chunk_sizes=chunk_sizes,
304314
)
305315

306316
if self.block:

ipyparallel/tests/test_asyncresult.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,3 +480,11 @@ def fail(i):
480480
done, pending = amr.wait(timeout=0, return_when=ipp.FIRST_EXCEPTION)
481481
assert pending == set()
482482
assert len(done) == len(amr)
483+
484+
def test_progress(self):
485+
dv = self.client[:]
486+
amr = dv.map_async(time.sleep, [0.2] * 2 * len(dv))
487+
assert len(amr) == len(dv) * 2
488+
assert amr.progress == 0
489+
amr.wait_interactive()
490+
assert amr.progress == len(amr)

ipyparallel/tests/test_lbview.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@ def f(x):
2929
return x ** 2
3030

3131
data = list(range(16))
32-
r = self.view.map_sync(f, data)
33-
self.assertEqual(r, list(map(f, data)))
32+
ar = self.view.map_async(f, data)
33+
assert len(ar) == len(data)
34+
r = ar.get()
35+
assert r == list(map(f, data))
3436

3537
def test_map_generator(self):
3638
def f(x):
3739
return x ** 2
3840

3941
data = list(range(16))
40-
r = self.view.map_sync(f, iter(data))
41-
self.assertEqual(r, list(map(f, iter(data))))
42+
ar = self.view.map_async(f, iter(data))
43+
r = ar.get()
44+
assert r == list(map(f, iter(data)))
4245

4346
def test_map_short_first(self):
4447
def f(x, y):
@@ -51,8 +54,10 @@ def f(x, y):
5154
data = list(range(10))
5255
data2 = list(range(4))
5356

54-
r = self.view.map_sync(f, data, data2)
55-
self.assertEqual(r, list(map(f, data, data2)))
57+
ar = self.view.map_async(f, data, data2)
58+
assert len(ar) == len(data2)
59+
r = ar.get()
60+
assert r == list(map(f, data, data2))
5661

5762
def test_map_short_last(self):
5863
def f(x, y):
@@ -65,8 +70,10 @@ def f(x, y):
6570
data = list(range(4))
6671
data2 = list(range(10))
6772

68-
r = self.view.map_sync(f, data, data2)
69-
self.assertEqual(r, list(map(f, data, data2)))
73+
ar = self.view.map_async(f, data, data2)
74+
assert len(ar) == len(data)
75+
r = ar.get()
76+
assert r == list(map(f, data, data2))
7077

7178
def test_map_unordered(self):
7279
def f(x):

ipyparallel/tests/test_view.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,9 @@ def f(x):
334334
return x ** 2
335335

336336
data = list(range(16))
337-
r = view.map_sync(f, data)
337+
ar = view.map_async(f, data)
338+
assert len(ar) == len(data)
339+
r = ar.get()
338340
self.assertEqual(r, list(map(f, data)))
339341

340342
def test_map_empty_sequence(self):
@@ -361,15 +363,17 @@ def test_map_numpy(self):
361363
view = self.client[:]
362364
# 101 is prime, so it won't be evenly distributed
363365
arr = numpy.arange(101)
364-
r = view.map_sync(lambda x: x, arr)
366+
ar = view.map_async(lambda x: x, arr)
367+
assert len(ar) == len(arr)
368+
r = ar.get()
365369
assert_array_equal(r, arr)
366370

367371
def test_scatter_gather_nonblocking(self):
368372
data = list(range(16))
369373
view = self.client[:]
370374
view.scatter('a', data, block=False)
371375
ar = view.gather('a', block=False)
372-
self.assertEqual(ar.get(), data)
376+
assert ar.get() == data
373377

374378
@skip_without('numpy')
375379
def test_scatter_gather_numpy_nonblocking(self):

0 commit comments

Comments
 (0)