Skip to content

Commit 142583e

Browse files
authored
Merge pull request #874 from minrk/broadcast-map
add BroadcastView.map
2 parents 9ff1800 + 7aa1d07 commit 142583e

File tree

2 files changed

+111
-28
lines changed

2 files changed

+111
-28
lines changed

ipyparallel/client/view.py

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from IPython import get_ipython
1717
from traitlets import Any, Bool, CFloat, Dict, HasTraits, Instance, Integer, List, Set
1818

19+
import ipyparallel as ipp
1920
from ipyparallel import util
2021
from ipyparallel.controller.dependency import Dependency, dependent
2122

@@ -767,7 +768,7 @@ def scatter(
767768
mapObject = Map.dists[dist]()
768769
nparts = len(targets)
769770
futures = []
770-
trackers = []
771+
_lengths = []
771772
for index, engineid in enumerate(targets):
772773
partition = mapObject.getPartition(seq, index, nparts)
773774
if flatten and len(partition) == 1:
@@ -777,10 +778,12 @@ def scatter(
777778
r = self.push(ns, block=False, track=track, targets=engineid)
778779
r.owner = False
779780
futures.extend(r._children)
781+
_lengths.append(len(partition))
780782

781783
r = AsyncResult(
782784
self.client, futures, fname='scatter', targets=targets, owner=True
783785
)
786+
r._scatter_lengths = _lengths
784787
if block:
785788
r.wait()
786789
else:
@@ -930,7 +933,6 @@ def _really_apply(
930933
track = self.track if track is None else track
931934
targets = self.targets if targets is None else targets
932935
idents, _targets = self.client._build_targets(targets)
933-
futures = []
934936

935937
pf = PrePickled(f)
936938
pargs = [PrePickled(arg) for arg in args]
@@ -1014,8 +1016,113 @@ def make_asyncresult(message_future):
10141016
pass
10151017
return ar
10161018

1017-
def map(self, f, *sequences, **kwargs):
1018-
raise NotImplementedError("BroadcastView.map not yet implemented")
1019+
@staticmethod
def _broadcast_map(f, *sequence_names):
    """Engine-side half of `map`, submitted to the engines via apply.

    The input data is delivered by a separate scatter step,
    so this function receives variable *names* rather than the data itself.

    Steps:

    - pop each named sequence out of the interactive user namespace
      (popping also cleans up the temporary variables left by scatter)
    - return ``list(map(f, *sequences))``
    """
    shell = get_ipython()
    # pop (not get) so the temporary scatter variables do not linger
    chunks = [shell.user_ns.pop(name) for name in sequence_names]
    return list(map(f, *chunks))
1036+
1037+
@_not_coalescing
def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
    """Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per engine, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    .. note::

        BroadcastView does not yet have a fully native map implementation.
        In particular, the scatter step is still one message per engine,
        identical to DirectView,
        and typically slower due to the more complex scheduler.

        It is more efficient to partition inputs via other means (e.g. SPMD based on rank & size)
        and use `apply` to submit all tasks in one broadcast.

    .. versionadded:: 8.8

    Parameters
    ----------
    f : callable
        function to be mapped
    *sequences : one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool [default self.block]
        whether to wait for the result or not
    track : bool [default False]
        Track underlying zmq send to indicate when it is safe to modify memory.
        Only for zero-copy sends such as numpy arrays that are going to be modified in-place.
        Passing None explicitly falls back to `self.track`.
    return_exceptions : bool [default False]
        Return remote Exceptions in the result sequence instead of raising them.

    Returns
    -------
    If block=False
        An :class:`~ipyparallel.client.asyncresult.AsyncMapResult` instance.
        An object like AsyncResult, but which reassembles the sequence of results
        into a single list. AsyncMapResults can be iterated through before all
        results are complete.
    else
        A list, the result of ``map(f,*sequences)``

    Raises
    ------
    TypeError
        If no sequences are given (mirrors builtin `map`).
    """
    if not sequences:
        # fail early with a clear message instead of an UnboundLocalError
        # further down, where the scatter chunk sizes would never be set
        raise TypeError("map() must have at least one sequence to map over")

    if block is None:
        block = self.block
    if track is None:
        track = self.track

    # unique identifier, since we're living in the interactive namespace
    map_key = secrets.token_hex(5)
    dist = 'b'
    map_object = Map.dists[dist]()

    seq_names = []
    for i, seq in enumerate(sequences):
        seq_name = f"_seq_{map_key}_{i}"
        seq_names.append(seq_name)
        try:
            len(seq)
        except Exception:
            # cast length-less iterables (e.g. generators) to list
            # so they can be partitioned
            seq = list(seq)

        scatter_ar = self.scatter(seq_name, seq, dist=dist, block=False, track=track)
        # per-engine chunk lengths recorded by scatter;
        # sequences are documented to have matching lengths,
        # so every iteration yields the same sizes
        scatter_chunk_sizes = scatter_ar._scatter_lengths

    # submit the map tasks as an actual broadcast
    apply_ar = self.apply(self._broadcast_map, f, *seq_names)
    # the AsyncMapResult below takes over the underlying futures
    apply_ar.owner = False
    # re-wrap messages in an AsyncMapResult to get map API
    # this is where the 'gather' reconstruction happens
    amr = ipp.AsyncMapResult(
        self.client,
        apply_ar._children,
        map_object,
        fname=getname(f),
        return_exceptions=return_exceptions,
        chunk_sizes={
            future.msg_id: chunk_size
            for future, chunk_size in zip(apply_ar._children, scatter_chunk_sizes)
        },
    )

    if block:
        return amr.get()
    return amr
10191126

10201127
# scatter/gather cannot be coalescing yet
10211128
scatter = _not_coalescing(DirectView.scatter)

ipyparallel/tests/test_view_broadcast.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,6 @@ def teardown_method(self):
3434
if not self._broadcast_view_used:
3535
pytest.skip("No broadcast view used")
3636

37-
@needs_map
38-
def test_map(self):
39-
pass
40-
41-
@needs_map
42-
def test_map_ref(self):
43-
pass
44-
45-
@needs_map
46-
def test_map_reference(self):
47-
pass
48-
49-
@needs_map
50-
def test_map_iterable(self):
51-
pass
52-
53-
@needs_map
54-
def test_map_empty_sequence(self):
55-
pass
56-
57-
@needs_map
58-
def test_map_numpy(self):
59-
pass
60-
6137
@pytest.mark.xfail(reason="Tracking gets disconnected from original message")
6238
def test_scatter_tracked(self):
6339
pass

0 commit comments

Comments
 (0)