Skip to content

Commit bb7f906

Browse files
hekaisheng and 继盛 authored
Fix potential data leak for shuffle tasks (#2975)
Co-authored-by: 继盛 <[email protected]>
1 parent dddc201 commit bb7f906

File tree

6 files changed

+67
-7
lines changed

6 files changed

+67
-7
lines changed

mars/deploy/oscar/session.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1238,6 +1238,7 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
12381238
return result
12391239

12401240
async def decref(self, *tileable_keys):
1241+
logger.debug("Decref tileables on client: %s", tileable_keys)
12411242
return await self._lifecycle_api.decref_tileables(list(tileable_keys))
12421243

12431244
async def _get_ref_counts(self) -> Dict[str, int]:

mars/deploy/oscar/tests/test_local.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -230,6 +230,14 @@ async def test_execute(create_cluster, config):
230230

231231
del a, b
232232

233+
if not isinstance(session._isolated_session, _IsolatedWebSession):
234+
worker_pools = session.client._cluster._worker_pools
235+
await session.destroy()
236+
for worker_pool in worker_pools:
237+
_assert_storage_cleaned(
238+
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
239+
)
240+
233241

234242
@pytest.mark.asyncio
235243
async def test_iterative_tiling(create_cluster):
@@ -254,6 +262,14 @@ async def test_iterative_tiling(create_cluster):
254262
assert df2.index_value.min_val >= 1
255263
assert df2.index_value.max_val <= 30
256264

265+
if not isinstance(session._isolated_session, _IsolatedWebSession):
266+
worker_pools = session.client._cluster._worker_pools
267+
await session.destroy()
268+
for worker_pool in worker_pools:
269+
_assert_storage_cleaned(
270+
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
271+
)
272+
257273

258274
@pytest.mark.asyncio
259275
async def test_execute_describe(create_cluster):
@@ -271,6 +287,14 @@ async def test_execute_describe(create_cluster):
271287
res = await session.fetch(r)
272288
pd.testing.assert_frame_equal(res, raw.describe())
273289

290+
if not isinstance(session._isolated_session, _IsolatedWebSession):
291+
worker_pools = session.client._cluster._worker_pools
292+
await session.destroy()
293+
for worker_pool in worker_pools:
294+
_assert_storage_cleaned(
295+
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
296+
)
297+
274298

275299
@pytest.mark.asyncio
276300
async def test_sync_execute_in_async(create_cluster):
@@ -395,6 +419,12 @@ async def test_web_session(create_cluster, config):
395419
await session.destroy()
396420
await _run_web_session_test(web_address)
397421

422+
worker_pools = client._cluster._worker_pools
423+
for worker_pool in worker_pools:
424+
_assert_storage_cleaned(
425+
session.session_id, worker_pool.external_address, StorageLevel.MEMORY
426+
)
427+
398428

399429
@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}])
400430
def test_sync_execute(config):
@@ -546,6 +576,26 @@ def test_decref(setup_session):
546576
ref_counts = session._get_ref_counts()
547577
assert len(ref_counts) == 0
548578

579+
with tempfile.TemporaryDirectory() as tempdir:
580+
file_path = os.path.join(tempdir, "test.csv")
581+
pdf = pd.DataFrame(
582+
np.random.RandomState(0).rand(100, 10),
583+
columns=[f"col{i}" for i in range(10)],
584+
)
585+
pdf.to_csv(file_path, index=False)
586+
587+
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
588+
df2 = df.head(10)
589+
590+
result = df2.execute().fetch()
591+
expected = pdf.head(10)
592+
pd.testing.assert_frame_equal(result, expected)
593+
594+
del df, df2
595+
596+
ref_counts = session._get_ref_counts()
597+
assert len(ref_counts) == 0
598+
549599
worker_addr = session._session.client._cluster._worker_pools[0].external_address
550600
_assert_storage_cleaned(session.session_id, worker_addr, StorageLevel.MEMORY)
551601

mars/services/lifecycle/supervisor/tracker.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,10 @@ def _check_ref_counts(cls, keys: List[str], ref_counts: List[int]):
7878
)
7979

8080
def incref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
81-
logger.debug("Increase reference count for chunks %s", chunk_keys)
81+
logger.debug(
82+
"Increase reference count for chunks %s",
83+
{ck: self._chunk_ref_counts[ck] for ck in chunk_keys},
84+
)
8285
self._check_ref_counts(chunk_keys, counts)
8386
counts = counts if counts is not None else itertools.repeat(1)
8487
for chunk_key, count in zip(chunk_keys, counts):

mars/services/storage/api/oscar.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,9 +185,11 @@ async def fetch(
185185
error: str
186186
raise or ignore
187187
"""
188-
await self._storage_handler_ref.fetch_batch(
188+
fetch_key = await self._storage_handler_ref.fetch_batch(
189189
self._session_id, [data_key], level, band_name, remote_address, error
190190
)
191+
if fetch_key:
192+
return fetch_key
191193

192194
@fetch.batch
193195
async def batch_fetch(self, args_list, kwargs_list):
@@ -201,7 +203,7 @@ async def batch_fetch(self, args_list, kwargs_list):
201203
assert extracted_args == (level, band_name, dest_address, error)
202204
extracted_args = (level, band_name, dest_address, error)
203205
data_keys.append(data_key)
204-
await self._storage_handler_ref.fetch_batch(
206+
return await self._storage_handler_ref.fetch_batch(
205207
self._session_id, data_keys, *extracted_args
206208
)
207209

mars/services/storage/transfer.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,11 @@ async def send_batch_data(
224224
)
225225
await self._data_manager_ref.unpin.batch(*unpin_tasks)
226226
logger.debug(
227-
"Finish sending data (%s, %s) to %s", session_id, data_keys, address
227+
"Finish sending data (%s, %s) to %s, total size is %s",
228+
session_id,
229+
data_keys,
230+
address,
231+
sum(data_sizes),
228232
)
229233

230234

mars/services/task/execution/mars/executor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -388,9 +388,9 @@ def _get_decref_stage_chunk_key_to_counts(
388388
for inp_subtask in subtask_graph.predecessors(subtask):
389389
for c in inp_subtask.chunk_graph.results:
390390
decref_chunk_key_to_counts[c.key] += 1
391-
# decref result of chunk graphs
392-
for c in stage_processor.chunk_graph.results:
393-
decref_chunk_key_to_counts[c.key] += 1
391+
# decref result of chunk graphs
392+
for c in stage_processor.chunk_graph.results:
393+
decref_chunk_key_to_counts[c.key] += 1
394394
return decref_chunk_key_to_counts
395395

396396
@mo.extensible

0 commit comments

Comments
 (0)