 import sortedcontainers
 from tlz import (
     compose,
-    concat,
     first,
     groupby,
     merge,
@@ -5392,7 +5391,7 @@ async def scatter(
         return keys

     async def gather(self, comm=None, keys=None, serializers=None):
-        """Collect data in from workers"""
+        """Collect data from workers to the scheduler"""
         parent: SchedulerState = cast(SchedulerState, self)
         ws: WorkerState
         keys = list(keys)
@@ -5598,31 +5597,108 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
         )
         return d[worker]

-    async def _delete_worker_data(self, worker_address, keys):
+    async def _gather_on_worker(
+        self, worker_address: str, who_has: "dict[Hashable, list[str]]"
+    ) -> set:
+        """Peer-to-peer copy of keys from multiple workers to a single worker
+
+        Parameters
+        ----------
+        worker_address: str
+            Recipient worker address to copy keys to
+        who_has: dict[Hashable, list[str]]
+            {key: [sender address, sender address, ...], key: ...}
+
+        Returns
+        -------
+        returns:
+            set of keys that failed to be copied
+        """
+        try:
+            result = await retry_operation(
+                self.rpc(addr=worker_address).gather, who_has=who_has
+            )
+        except OSError as e:
+            # This can happen e.g. if the worker is going through controlled shutdown;
+            # it doesn't necessarily mean that it went unexpectedly missing
+            logger.warning(
+                f"Communication with worker {worker_address} failed during "
+                f"replication: {e.__class__.__name__}: {e}"
+            )
+            return set(who_has)
+
+        parent: SchedulerState = cast(SchedulerState, self)
+        ws: WorkerState = parent._workers_dv.get(worker_address)
+
+        if ws is None:
+            logger.warning(f"Worker {worker_address} lost during replication")
+            return set(who_has)
+        elif result["status"] == "OK":
+            keys_failed = set()
+            keys_ok = who_has.keys()
+        elif result["status"] == "partial-fail":
+            keys_failed = set(result["keys"])
+            keys_ok = who_has.keys() - keys_failed
+            logger.warning(
+                f"Worker {worker_address} failed to acquire keys: {result['keys']}"
+            )
+        else:  # pragma: nocover
+            raise ValueError(f"Unexpected message from {worker_address}: {result}")
+
+        for key in keys_ok:
+            ts: TaskState = parent._tasks.get(key)
+            if ts is None or ts._state != "memory":
+                logger.warning(f"Key lost during replication: {key}")
+                continue
+            if ts not in ws._has_what:
+                ws._nbytes += ts.get_nbytes()
+                ws._has_what[ts] = None
+                ts._who_has.add(ws)
+
+        return keys_failed
+
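For context, a usage sketch of the new helper; the worker addresses and keys
below are hypothetical, not taken from this diff:

    # Copy "x" and "y" onto one recipient, pulling each key from a worker
    # that already holds it in memory.
    failed = await scheduler._gather_on_worker(
        "tcp://127.0.0.1:40001",
        who_has={
            "x": ["tcp://127.0.0.1:40002"],
            "y": ["tcp://127.0.0.1:40003", "tcp://127.0.0.1:40004"],
        },
    )
    # On "partial-fail", scheduler state is updated only for the keys that
    # actually arrived; `failed` holds the subset that did not.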
+    async def _delete_worker_data(self, worker_address: str, keys: "list[str]") -> None:
         """Delete data from a worker and update the corresponding worker/task states

         Parameters
         ----------
         worker_address: str
             Worker address to delete keys from
-        keys: List[str]
+        keys: list[str]
             List of keys to delete on the specified worker
         """
         parent: SchedulerState = cast(SchedulerState, self)

-        await retry_operation(
-            self.rpc(addr=worker_address).free_keys,
-            keys=list(keys),
-            reason="rebalance/replicate",
-        )
+        try:
+            await retry_operation(
+                self.rpc(addr=worker_address).free_keys,
+                keys=list(keys),
+                reason="rebalance/replicate",
+            )
+        except OSError as e:
+            # This can happen e.g. if the worker is going through controlled shutdown;
+            # it doesn't necessarily mean that it went unexpectedly missing
+            logger.warning(
+                f"Communication with worker {worker_address} failed during "
+                f"replication: {e.__class__.__name__}: {e}"
+            )
+            return
+
+        ws: WorkerState = parent._workers_dv.get(worker_address)
+        if ws is None:
+            return
+
+        for key in keys:
+            ts: TaskState = parent._tasks.get(key)
+            if ts is not None and ts in ws._has_what:
+                assert ts._state == "memory"
+                del ws._has_what[ts]
+                ts._who_has.remove(ws)
+                ws._nbytes -= ts.get_nbytes()
+                if not ts._who_has:
+                    # Last copy deleted
+                    self.transitions({key: "released"})

-        ws: WorkerState = parent._workers_dv[worker_address]
-        ts: TaskState
-        tasks: set = {parent._tasks[key] for key in keys}
-        for ts in tasks:
-            del ws._has_what[ts]
-            ts._who_has.remove(ws)
-            ws._nbytes -= ts.get_nbytes()
         self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys})

     async def rebalance(
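A minimal sketch of the new bookkeeping, assuming a hypothetical cluster where
"x" has its only in-memory copy on the target worker:

    # Freeing the last replica leaves ts._who_has empty, so the task is
    # transitioned to "released" instead of lingering in "memory" with no
    # replicas, which validate_task_state() would reject.
    await scheduler._delete_worker_data("tcp://127.0.0.1:40002", ["x"])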
@@ -5717,14 +5793,18 @@ async def rebalance(
             if k not in parent._tasks or not parent._tasks[k].who_has
         ]
         if missing_data:
-            return {"status": "missing-data", "keys": missing_data}
+            return {"status": "partial-fail", "keys": missing_data}

         msgs = self._rebalance_find_msgs(keys, workers)
         if not msgs:
             return {"status": "OK"}

         async with self._lock:
-            return await self._rebalance_move_data(msgs)
+            result = await self._rebalance_move_data(msgs)
+            if result["status"] == "partial-fail" and keys is None:
+                # Only return failed keys if the client explicitly asked for them
+                result = {"status": "OK"}
+            return result

     def _rebalance_find_msgs(
         self: SchedulerState,
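The client-facing effect of the branch above, sketched under the assumption
that rebalance() is invoked with explicit keys:

    result = await scheduler.rebalance(keys=["x", "y"])
    if result["status"] == "partial-fail":
        # Reachable only when keys were requested explicitly; with keys=None
        # a partial failure is deliberately reported as {"status": "OK"}.
        print("keys that could not be moved:", result["keys"])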
@@ -5881,7 +5961,7 @@ def _rebalance_find_msgs(
                 # move on to the next task of the same sender.
                 continue

-            # Schedule task for transfer from sender to receiver
+            # Schedule task for transfer from sender to recipient
             msgs.append((snd_ws, rec_ws, ts))

             # *_bytes_max/min are all negative for heap sorting
@@ -5902,7 +5982,7 @@ def _rebalance_find_msgs(
             else:
                 heapq.heappop(senders)

-            # If receiver still has bytes to gain, push it back into the receivers
+            # If recipient still has bytes to gain, push it back into the recipients
             # heap; it may or may not come back on top again.
             if rec_bytes_min < 0:
                 # See definition of recipients above
@@ -5927,29 +6007,46 @@ async def _rebalance_move_data(
         self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]"
     ) -> dict:
         """Perform the actual transfer of data across the network in rebalance().
-        Takes in input the output of _rebalance_find_msgs().
+        Takes as input the output of _rebalance_find_msgs(), which is a list of tuples:
+
+        - sender worker
+        - recipient worker
+        - task to be transferred

         FIXME this method is not robust when the cluster is not idle.
         """
-        ts: TaskState
         snd_ws: WorkerState
         rec_ws: WorkerState
+        ts: TaskState

         to_recipients = defaultdict(lambda: defaultdict(list))
+        for snd_ws, rec_ws, ts in msgs:
+            to_recipients[rec_ws.address][ts._key].append(snd_ws.address)
+        failed_keys_by_recipient = dict(
+            zip(
+                to_recipients,
+                await asyncio.gather(
+                    *(
+                        # Note: this never raises exceptions
+                        self._gather_on_worker(w, who_has)
+                        for w, who_has in to_recipients.items()
+                    )
+                ),
+            )
+        )
+
         to_senders = defaultdict(list)
-        for sender, recipient, ts in msgs:
-            to_recipients[recipient.address][ts._key].append(sender.address)
-            to_senders[sender.address].append(ts._key)
+        for snd_ws, rec_ws, ts in msgs:
+            if ts._key not in failed_keys_by_recipient[rec_ws.address]:
+                to_senders[snd_ws.address].append(ts._key)

-        result = await asyncio.gather(
-            *(
-                retry_operation(self.rpc(addr=r).gather, who_has=v)
-                for r, v in to_recipients.items()
-            )
+        # Note: this never raises exceptions
+        await asyncio.gather(
+            *(self._delete_worker_data(r, v) for r, v in to_senders.items())
         )
+
         for r, v in to_recipients.items():
             self.log_event(r, {"action": "rebalance", "who_has": v})
-
         self.log_event(
             "all",
             {
@@ -5960,31 +6057,11 @@ async def _rebalance_move_data(
             },
         )

-        if any(r["status"] != "OK" for r in result):
-            return {
-                "status": "missing-data",
-                "keys": list(
-                    concat(
-                        r["keys"].keys()
-                        for r in result
-                        if r["status"] == "missing-data"
-                    )
-                ),
-            }
-
-        for snd_ws, rec_ws, ts in msgs:
-            assert ts._state == "memory"
-            ts._who_has.add(rec_ws)
-            rec_ws._has_what[ts] = None
-            rec_ws.nbytes += ts.get_nbytes()
-            self.log.append(
-                ("rebalance", ts._key, time(), snd_ws.address, rec_ws.address)
-            )
-
-        await asyncio.gather(
-            *(self._delete_worker_data(r, v) for r, v in to_senders.items())
-        )
-        return {"status": "OK"}
+        missing_keys = {k for r in failed_keys_by_recipient.values() for k in r}
+        if missing_keys:
+            return {"status": "partial-fail", "keys": list(missing_keys)}
+        else:
+            return {"status": "OK"}

     async def replicate(
         self,
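The dict(zip(...)) pairing above relies on asyncio.gather() returning results
in argument order and on dicts iterating in insertion order; a self-contained
toy of the same pattern, with made-up addresses and a stand-in helper:

    import asyncio

    async def fake_gather_on_worker(addr, who_has):
        # Stand-in for Scheduler._gather_on_worker: returns the failed keys
        return {"y"} if addr == "tcp://b" else set()

    async def main():
        to_recipients = {
            "tcp://a": {"x": ["tcp://c"]},
            "tcp://b": {"y": ["tcp://c"]},
        }
        failed_keys_by_recipient = dict(
            zip(
                to_recipients,
                await asyncio.gather(
                    *(
                        fake_gather_on_worker(w, wh)
                        for w, wh in to_recipients.items()
                    )
                ),
            )
        )
        assert failed_keys_by_recipient == {"tcp://a": set(), "tcp://b": {"y"}}

    asyncio.run(main())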
@@ -6035,7 +6112,7 @@ async def replicate(
         tasks = {parent._tasks[k] for k in keys}
         missing_data = [ts._key for ts in tasks if not ts._who_has]
         if missing_data:
-            return {"status": "missing-data", "keys": missing_data}
+            return {"status": "partial-fail", "keys": missing_data}

         # Delete extraneous data
         if delete:
@@ -6048,6 +6125,7 @@ async def replicate(
             ):
                 del_worker_tasks[ws].add(ts)

+            # Note: this never raises exceptions
             await asyncio.gather(
                 *[
                     self._delete_worker_data(ws._address, [t.key for t in tasks])
@@ -6077,19 +6155,15 @@ async def replicate(
                     wws._address for wws in ts._who_has
                 ]

-            results = await asyncio.gather(
+            await asyncio.gather(
                 *(
-                    retry_operation(self.rpc(addr=w).gather, who_has=who_has)
+                    # Note: this never raises exceptions
+                    self._gather_on_worker(w, who_has)
                     for w, who_has in gathers.items()
                 )
             )
-            for w, v in zip(gathers, results):
-                if v["status"] == "OK":
-                    self.add_keys(worker=w, keys=list(gathers[w]))
-                else:
-                    logger.warning("Communication failed during replication: %s", v)
-
-                self.log_event(w, {"action": "replicate-add", "keys": gathers[w]})
+            for r, v in gathers.items():
+                self.log_event(r, {"action": "replicate-add", "who_has": v})

         self.log_event(
             "all",
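For reference, the shape of the gathers mapping that feeds the fan-out above,
with hypothetical addresses; one concurrent _gather_on_worker() call is made
per recipient, and failures are logged inside the helper instead of aborting
the whole replicate():

    gathers = {
        "tcp://127.0.0.1:40001": {"x": ["tcp://127.0.0.1:40002"]},
        "tcp://127.0.0.1:40003": {"x": ["tcp://127.0.0.1:40002"]},
    }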
@@ -7655,7 +7729,7 @@ def validate_task_state(ts: TaskState):
         assert dts._state != "forgotten"

     assert (ts._processing_on is not None) == (ts._state == "processing")
-    assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has)
+    assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state)

     if ts._state == "processing":
         assert all([dts._who_has for dts in ts._dependencies]), (