Skip to content

Commit 7732dc5

Browse files
committed
Merge remote-tracking branch 'upstream/main' into dashboard_address
2 parents 680f132 + 5f01fe6 commit 7732dc5

File tree

6 files changed

+416
-154
lines changed

6 files changed

+416
-154
lines changed

distributed/client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,11 +3095,9 @@ async def _rebalance(self, futures=None, workers=None):
30953095
else:
30963096
keys = None
30973097
result = await self.scheduler.rebalance(keys=keys, workers=workers)
3098-
if result["status"] == "missing-data":
3099-
raise KeyError(
3100-
f"During rebalance {len(result['keys'])} keys were found to be missing"
3101-
)
3102-
assert result["status"] == "OK"
3098+
if result["status"] == "partial-fail":
3099+
raise KeyError(f"Could not rebalance keys: {result['keys']}")
3100+
assert result["status"] == "OK", result
31033101

31043102
def rebalance(self, futures=None, workers=None, **kwargs):
31053103
"""Rebalance data within network

distributed/scheduler.py

Lines changed: 141 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import sortedcontainers
2525
from tlz import (
2626
compose,
27-
concat,
2827
first,
2928
groupby,
3029
merge,
@@ -5392,7 +5391,7 @@ async def scatter(
53925391
return keys
53935392

53945393
async def gather(self, comm=None, keys=None, serializers=None):
5395-
"""Collect data in from workers"""
5394+
"""Collect data from workers to the scheduler"""
53965395
parent: SchedulerState = cast(SchedulerState, self)
53975396
ws: WorkerState
53985397
keys = list(keys)
@@ -5598,31 +5597,108 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None):
55985597
)
55995598
return d[worker]
56005599

5601-
async def _delete_worker_data(self, worker_address, keys):
5600+
async def _gather_on_worker(
5601+
self, worker_address: str, who_has: "dict[Hashable, list[str]]"
5602+
) -> set:
5603+
"""Peer-to-peer copy of keys from multiple workers to a single worker
5604+
5605+
Parameters
5606+
----------
5607+
worker_address: str
5608+
Recipient worker address to copy keys to
5609+
who_has: dict[Hashable, list[str]]
5610+
{key: [sender address, sender address, ...], key: ...}
5611+
5612+
Returns
5613+
-------
5614+
returns:
5615+
set of keys that failed to be copied
5616+
"""
5617+
try:
5618+
result = await retry_operation(
5619+
self.rpc(addr=worker_address).gather, who_has=who_has
5620+
)
5621+
except OSError as e:
5622+
# This can happen e.g. if the worker is going through controlled shutdown;
5623+
# it doesn't necessarily mean that it went unexpectedly missing
5624+
logger.warning(
5625+
f"Communication with worker {worker_address} failed during "
5626+
f"replication: {e.__class__.__name__}: {e}"
5627+
)
5628+
return set(who_has)
5629+
5630+
parent: SchedulerState = cast(SchedulerState, self)
5631+
ws: WorkerState = parent._workers_dv.get(worker_address)
5632+
5633+
if ws is None:
5634+
logger.warning(f"Worker {worker_address} lost during replication")
5635+
return set(who_has)
5636+
elif result["status"] == "OK":
5637+
keys_failed = set()
5638+
keys_ok = who_has.keys()
5639+
elif result["status"] == "partial-fail":
5640+
keys_failed = set(result["keys"])
5641+
keys_ok = who_has.keys() - keys_failed
5642+
logger.warning(
5643+
f"Worker {worker_address} failed to acquire keys: {result['keys']}"
5644+
)
5645+
else: # pragma: nocover
5646+
raise ValueError(f"Unexpected message from {worker_address}: {result}")
5647+
5648+
for key in keys_ok:
5649+
ts: TaskState = parent._tasks.get(key)
5650+
if ts is None or ts._state != "memory":
5651+
logger.warning(f"Key lost during replication: {key}")
5652+
continue
5653+
if ts not in ws._has_what:
5654+
ws._nbytes += ts.get_nbytes()
5655+
ws._has_what[ts] = None
5656+
ts._who_has.add(ws)
5657+
5658+
return keys_failed
5659+
5660+
async def _delete_worker_data(self, worker_address: str, keys: "list[str]") -> None:
56025661
"""Delete data from a worker and update the corresponding worker/task states
56035662
56045663
Parameters
56055664
----------
56065665
worker_address: str
56075666
Worker address to delete keys from
5608-
keys: List[str]
5667+
keys: list[str]
56095668
List of keys to delete on the specified worker
56105669
"""
56115670
parent: SchedulerState = cast(SchedulerState, self)
56125671

5613-
await retry_operation(
5614-
self.rpc(addr=worker_address).free_keys,
5615-
keys=list(keys),
5616-
reason="rebalance/replicate",
5617-
)
5672+
try:
5673+
await retry_operation(
5674+
self.rpc(addr=worker_address).free_keys,
5675+
keys=list(keys),
5676+
reason="rebalance/replicate",
5677+
)
5678+
except OSError as e:
5679+
# This can happen e.g. if the worker is going through controlled shutdown;
5680+
# it doesn't necessarily mean that it went unexpectedly missing
5681+
logger.warning(
5682+
f"Communication with worker {worker_address} failed during "
5683+
f"replication: {e.__class__.__name__}: {e}"
5684+
)
5685+
return
5686+
5687+
ws: WorkerState = parent._workers_dv.get(worker_address)
5688+
if ws is None:
5689+
return
5690+
5691+
for key in keys:
5692+
ts: TaskState = parent._tasks.get(key)
5693+
if ts is not None and ts in ws._has_what:
5694+
assert ts._state == "memory"
5695+
del ws._has_what[ts]
5696+
ts._who_has.remove(ws)
5697+
ws._nbytes -= ts.get_nbytes()
5698+
if not ts._who_has:
5699+
# Last copy deleted
5700+
self.transitions({key: "released"})
56185701

5619-
ws: WorkerState = parent._workers_dv[worker_address]
5620-
ts: TaskState
5621-
tasks: set = {parent._tasks[key] for key in keys}
5622-
for ts in tasks:
5623-
del ws._has_what[ts]
5624-
ts._who_has.remove(ws)
5625-
ws._nbytes -= ts.get_nbytes()
56265702
self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys})
56275703

56285704
async def rebalance(
@@ -5717,14 +5793,18 @@ async def rebalance(
57175793
if k not in parent._tasks or not parent._tasks[k].who_has
57185794
]
57195795
if missing_data:
5720-
return {"status": "missing-data", "keys": missing_data}
5796+
return {"status": "partial-fail", "keys": missing_data}
57215797

57225798
msgs = self._rebalance_find_msgs(keys, workers)
57235799
if not msgs:
57245800
return {"status": "OK"}
57255801

57265802
async with self._lock:
5727-
return await self._rebalance_move_data(msgs)
5803+
result = await self._rebalance_move_data(msgs)
5804+
if result["status"] == "partial-fail" and keys is None:
5805+
# Only return failed keys if the client explicitly asked for them
5806+
result = {"status": "OK"}
5807+
return result
57285808

57295809
def _rebalance_find_msgs(
57305810
self: SchedulerState,
@@ -5881,7 +5961,7 @@ def _rebalance_find_msgs(
58815961
# move on to the next task of the same sender.
58825962
continue
58835963

5884-
# Schedule task for transfer from sender to receiver
5964+
# Schedule task for transfer from sender to recipient
58855965
msgs.append((snd_ws, rec_ws, ts))
58865966

58875967
# *_bytes_max/min are all negative for heap sorting
@@ -5902,7 +5982,7 @@ def _rebalance_find_msgs(
59025982
else:
59035983
heapq.heappop(senders)
59045984

5905-
# If receiver still has bytes to gain, push it back into the receivers
5985+
# If recipient still has bytes to gain, push it back into the recipients
59065986
# heap; it may or may not come back on top again.
59075987
if rec_bytes_min < 0:
59085988
# See definition of recipients above
@@ -5927,29 +6007,46 @@ async def _rebalance_move_data(
59276007
self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]"
59286008
) -> dict:
59296009
"""Perform the actual transfer of data across the network in rebalance().
5930-
Takes in input the output of _rebalance_find_msgs().
6010+
Takes in input the output of _rebalance_find_msgs(), that is a list of tuples:
6011+
6012+
- sender worker
6013+
- recipient worker
6014+
- task to be transferred
59316015
59326016
FIXME this method is not robust when the cluster is not idle.
59336017
"""
5934-
ts: TaskState
59356018
snd_ws: WorkerState
59366019
rec_ws: WorkerState
6020+
ts: TaskState
59376021

59386022
to_recipients = defaultdict(lambda: defaultdict(list))
6023+
for snd_ws, rec_ws, ts in msgs:
6024+
to_recipients[rec_ws.address][ts._key].append(snd_ws.address)
6025+
failed_keys_by_recipient = dict(
6026+
zip(
6027+
to_recipients,
6028+
await asyncio.gather(
6029+
*(
6030+
# Note: this never raises exceptions
6031+
self._gather_on_worker(w, who_has)
6032+
for w, who_has in to_recipients.items()
6033+
)
6034+
),
6035+
)
6036+
)
6037+
59396038
to_senders = defaultdict(list)
5940-
for sender, recipient, ts in msgs:
5941-
to_recipients[recipient.address][ts._key].append(sender.address)
5942-
to_senders[sender.address].append(ts._key)
6039+
for snd_ws, rec_ws, ts in msgs:
6040+
if ts._key not in failed_keys_by_recipient[rec_ws.address]:
6041+
to_senders[snd_ws.address].append(ts._key)
59436042

5944-
result = await asyncio.gather(
5945-
*(
5946-
retry_operation(self.rpc(addr=r).gather, who_has=v)
5947-
for r, v in to_recipients.items()
5948-
)
6043+
# Note: this never raises exceptions
6044+
await asyncio.gather(
6045+
*(self._delete_worker_data(r, v) for r, v in to_senders.items())
59496046
)
6047+
59506048
for r, v in to_recipients.items():
59516049
self.log_event(r, {"action": "rebalance", "who_has": v})
5952-
59536050
self.log_event(
59546051
"all",
59556052
{
@@ -5960,31 +6057,11 @@ async def _rebalance_move_data(
59606057
},
59616058
)
59626059

5963-
if any(r["status"] != "OK" for r in result):
5964-
return {
5965-
"status": "missing-data",
5966-
"keys": list(
5967-
concat(
5968-
r["keys"].keys()
5969-
for r in result
5970-
if r["status"] == "missing-data"
5971-
)
5972-
),
5973-
}
5974-
5975-
for snd_ws, rec_ws, ts in msgs:
5976-
assert ts._state == "memory"
5977-
ts._who_has.add(rec_ws)
5978-
rec_ws._has_what[ts] = None
5979-
rec_ws.nbytes += ts.get_nbytes()
5980-
self.log.append(
5981-
("rebalance", ts._key, time(), snd_ws.address, rec_ws.address)
5982-
)
5983-
5984-
await asyncio.gather(
5985-
*(self._delete_worker_data(r, v) for r, v in to_senders.items())
5986-
)
5987-
return {"status": "OK"}
6060+
missing_keys = {k for r in failed_keys_by_recipient.values() for k in r}
6061+
if missing_keys:
6062+
return {"status": "partial-fail", "keys": list(missing_keys)}
6063+
else:
6064+
return {"status": "OK"}
59886065

59896066
async def replicate(
59906067
self,
@@ -6035,7 +6112,7 @@ async def replicate(
60356112
tasks = {parent._tasks[k] for k in keys}
60366113
missing_data = [ts._key for ts in tasks if not ts._who_has]
60376114
if missing_data:
6038-
return {"status": "missing-data", "keys": missing_data}
6115+
return {"status": "partial-fail", "keys": missing_data}
60396116

60406117
# Delete extraneous data
60416118
if delete:
@@ -6048,6 +6125,7 @@ async def replicate(
60486125
):
60496126
del_worker_tasks[ws].add(ts)
60506127

6128+
# Note: this never raises exceptions
60516129
await asyncio.gather(
60526130
*[
60536131
self._delete_worker_data(ws._address, [t.key for t in tasks])
@@ -6077,19 +6155,15 @@ async def replicate(
60776155
wws._address for wws in ts._who_has
60786156
]
60796157

6080-
results = await asyncio.gather(
6158+
await asyncio.gather(
60816159
*(
6082-
retry_operation(self.rpc(addr=w).gather, who_has=who_has)
6160+
# Note: this never raises exceptions
6161+
self._gather_on_worker(w, who_has)
60836162
for w, who_has in gathers.items()
60846163
)
60856164
)
6086-
for w, v in zip(gathers, results):
6087-
if v["status"] == "OK":
6088-
self.add_keys(worker=w, keys=list(gathers[w]))
6089-
else:
6090-
logger.warning("Communication failed during replication: %s", v)
6091-
6092-
self.log_event(w, {"action": "replicate-add", "keys": gathers[w]})
6165+
for r, v in gathers.items():
6166+
self.log_event(r, {"action": "replicate-add", "who_has": v})
60936167

60946168
self.log_event(
60956169
"all",
@@ -7655,7 +7729,7 @@ def validate_task_state(ts: TaskState):
76557729
assert dts._state != "forgotten"
76567730

76577731
assert (ts._processing_on is not None) == (ts._state == "processing")
7658-
assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has)
7732+
assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state)
76597733

76607734
if ts._state == "processing":
76617735
assert all([dts._who_has for dts in ts._dependencies]), (

distributed/tests/test_client.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,21 +2993,12 @@ async def test_rebalance_unprepared(c, s, a, b):
29932993
s.validate_state()
29942994

29952995

2996-
@gen_cluster(client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1 GiB"})
2997-
async def test_rebalance_raises_missing_data(c, s, *_):
2998-
a, b = s.workers
2999-
futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a])
3000-
await wait(futures)
3001-
# Wait for heartbeats
3002-
while s.memory.process < 2 ** 29:
3003-
await asyncio.sleep(0.1)
3004-
3005-
# Descoping the futures enqueues a coroutine to release the data on the server
3006-
del futures
3007-
with pytest.raises(KeyError, match="keys were found to be missing"):
3008-
# During the synchronous part of rebalance, the futures still exist, but they
3009-
# will be (partially) gone by the time the actual transferring happens.
3010-
await c.rebalance()
2996+
@gen_cluster(client=True)
2997+
async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b):
2998+
"""rebalance() raises KeyError if explicitly listed futures disappear"""
2999+
f = Future("x", client=c, state="memory")
3000+
with pytest.raises(KeyError, match="Could not rebalance keys:"):
3001+
await c.rebalance(futures=[f])
30113002

30123003

30133004
@gen_cluster(client=True)

0 commit comments

Comments
 (0)