16
16
from IPython import get_ipython
17
17
from traitlets import Any , Bool , CFloat , Dict , HasTraits , Instance , Integer , List , Set
18
18
19
+ import ipyparallel as ipp
19
20
from ipyparallel import util
20
21
from ipyparallel .controller .dependency import Dependency , dependent
21
22
@@ -767,7 +768,7 @@ def scatter(
767
768
mapObject = Map .dists [dist ]()
768
769
nparts = len (targets )
769
770
futures = []
770
- trackers = []
771
+ _lengths = []
771
772
for index , engineid in enumerate (targets ):
772
773
partition = mapObject .getPartition (seq , index , nparts )
773
774
if flatten and len (partition ) == 1 :
@@ -777,10 +778,12 @@ def scatter(
777
778
r = self .push (ns , block = False , track = track , targets = engineid )
778
779
r .owner = False
779
780
futures .extend (r ._children )
781
+ _lengths .append (len (partition ))
780
782
781
783
r = AsyncResult (
782
784
self .client , futures , fname = 'scatter' , targets = targets , owner = True
783
785
)
786
+ r ._scatter_lengths = _lengths
784
787
if block :
785
788
r .wait ()
786
789
else :
@@ -930,7 +933,6 @@ def _really_apply(
930
933
track = self .track if track is None else track
931
934
targets = self .targets if targets is None else targets
932
935
idents , _targets = self .client ._build_targets (targets )
933
- futures = []
934
936
935
937
pf = PrePickled (f )
936
938
pargs = [PrePickled (arg ) for arg in args ]
@@ -1014,8 +1016,113 @@ def make_asyncresult(message_future):
1014
1016
pass
1015
1017
return ar
1016
1018
1017
- def map (self , f , * sequences , ** kwargs ):
1018
- raise NotImplementedError ("BroadcastView.map not yet implemented" )
1019
@staticmethod
def _broadcast_map(f, *sequence_names):
    """Run ``f`` over sequences stored under the given names in the user namespace.

    This is the function that `map` submits to the engines. The inputs are
    not passed as arguments; they are looked up by name in the interactive
    namespace, because the scatter step that delivers them happens
    separately from this call.

    Steps performed:

    - resolve each sequence name to the object stored in the user namespace
    - remove that temporary variable from the namespace while doing so
    - return ``list(map(f, *sequences))``

    Parameters
    ----------
    f : callable
        function applied element-wise to the resolved sequences
    *sequence_names : str
        names of the variables holding the sequences

    Returns
    -------
    list
        the collected results of mapping ``f`` over the sequences
    """
    shell = get_ipython()
    # pop (rather than read) so the temporary variables do not linger
    # in the user namespace after the map has consumed them
    local_chunks = [shell.user_ns.pop(name) for name in sequence_names]
    return list(map(f, *local_chunks))
1036
+
1037
@_not_coalescing
def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
    """Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per engine, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    .. note::

        BroadcastView does not yet have a fully native map implementation.
        In particular, the scatter step is still one message per engine,
        identical to DirectView,
        and typically slower due to the more complex scheduler.

        It is more efficient to partition inputs via other means (e.g. SPMD based on rank & size)
        and use `apply` to submit all tasks in one broadcast.

    .. versionadded:: 8.8

    Parameters
    ----------
    f : callable
        function to be mapped
    *sequences : one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool [default self.block]
        whether to wait for the result or not
    track : bool [default False]
        Track underlying zmq send to indicate when it is safe to modify memory.
        Only for zero-copy sends such as numpy arrays that are going to be modified in-place.
    return_exceptions : bool [default False]
        Return remote Exceptions in the result sequence instead of raising them.

    Returns
    -------
    If block=False
        An :class:`~ipyparallel.client.asyncresult.AsyncMapResult` instance.
        An object like AsyncResult, but which reassembles the sequence of results
        into a single list. AsyncMapResults can be iterated through before all
        results are complete.
    else
        A list, the result of ``map(f,*sequences)``

    Raises
    ------
    TypeError
        If no sequences are given.
    """
    if not sequences:
        # Without at least one sequence nothing is scattered, so there are no
        # chunk sizes to reassemble results with. Fail here, before any
        # message is submitted to the engines.
        raise TypeError("map() requires at least one sequence")

    if block is None:
        block = self.block
    if track is None:
        track = self.track

    # unique identifier, since we're living in the interactive namespace
    map_key = secrets.token_hex(5)
    dist = 'b'
    map_object = Map.dists[dist]()

    seq_names = []
    for i, seq in enumerate(sequences):
        seq_name = f"_seq_{map_key}_{i}"
        seq_names.append(seq_name)
        try:
            len(seq)
        except Exception:
            # cast length-less sequences (e.g. Range) to list
            seq = list(seq)

        ar = self.scatter(seq_name, seq, dist=dist, block=False, track=track)
        # per-engine partition lengths recorded by scatter; the value from the
        # last sequence is kept, since all sequences should have matching length
        scatter_chunk_sizes = ar._scatter_lengths

    # submit the map tasks as an actual broadcast.
    # apply_async always yields an AsyncResult; plain `apply` follows
    # the view's `block` setting and would hand back bare results
    # when it is True, leaving nothing to re-wrap below.
    ar = self.apply_async(self._broadcast_map, f, *seq_names)
    ar.owner = False
    # re-wrap messages in an AsyncMapResult to get map API
    # this is where the 'gather' reconstruction happens
    amr = ipp.AsyncMapResult(
        self.client,
        ar._children,
        map_object,
        fname=getname(f),
        return_exceptions=return_exceptions,
        chunk_sizes={
            future.msg_id: chunk_size
            for future, chunk_size in zip(ar._children, scatter_chunk_sizes)
        },
    )

    if block:
        return amr.get()
    return amr
1019
1126
1020
1127
# scatter/gather cannot be coalescing yet
1021
1128
scatter = _not_coalescing (DirectView .scatter )
0 commit comments