Skip to content

Commit e889f9c

Browse files
committed
Fix unit tests
1 parent 70f0c96 commit e889f9c

File tree

2 files changed

+66
-42
lines changed

2 files changed

+66
-42
lines changed

mars/services/storage/handler.py

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ...typing import BandType
2424
from ...utils import calc_data_size, lazy_import
2525
from ..cluster import ClusterAPI, StorageInfo
26-
from ..meta import MetaAPI
26+
from ..meta import MetaAPI, WorkerMetaAPI
2727
from .core import (
2828
StorageQuotaActor,
2929
DataManagerActor,
@@ -39,7 +39,7 @@
3939
logger = logging.getLogger(__name__)
4040

4141

42-
class StorageHandlerActor(mo.Actor):
42+
class StorageHandlerActor(mo.StatelessActor):
4343
"""
4444
Storage handler actor, provide methods like `get`, `put`, etc.
4545
This actor is stateful and created on worker's sub pools.
@@ -468,29 +468,37 @@ async def _fetch_via_transfer(
468468
)
469469
logger.debug("Finish fetching %s from band %s", data_keys, remote_band)
470470

471-
async def fetch_batch(
472-
self,
473-
session_id: str,
474-
data_keys: List[str],
475-
level: StorageLevel,
476-
band_name: str,
477-
address: str,
478-
error: str,
471+
async def _get_remote_bands(
472+
self, session_id: str, data_keys: List, fetch_keys: List, band_name: str
479473
):
480-
if error not in ("raise", "ignore"): # pragma: no cover
481-
raise ValueError("error must be raise or ignore")
482-
474+
# For mapper data, we need to return fetch keys to remove them after
475+
# execution(see #2922). However, after introducing prefetch, some
476+
# mapper data is fetched in advance and not in missing keys, so here
477+
# we need to query worker meta service to check if it is on local.
483478
meta_api = await self._get_meta_api(session_id)
484-
remote_keys = defaultdict(set)
479+
mapper_main_keys = defaultdict(list)
485480
missing_keys = []
486-
get_metas = []
487481
get_info_delays = []
488482
for data_key in data_keys:
483+
if isinstance(data_key, tuple):
484+
mapper_main_keys[data_key[0]].append(data_key)
489485
get_info_delays.append(
490486
self._data_manager_ref.get_data_info.delay(
491487
session_id, data_key, band_name, error="ignore"
492488
)
493489
)
490+
# use worker meta api to check if mapper data is executed locally.
491+
worker_meta_api = await WorkerMetaAPI.create(session_id, self.address)
492+
493+
worker_meta_tasks = [
494+
worker_meta_api.get_chunk_meta.delay(key, error="ignore")
495+
for key in mapper_main_keys
496+
]
497+
mapper_metas = await worker_meta_api.get_chunk_meta.batch(*worker_meta_tasks)
498+
for key, meta in zip(mapper_main_keys, mapper_metas):
499+
if meta is None:
500+
fetch_keys.extend(mapper_main_keys[key])
501+
494502
data_infos = await self._data_manager_ref.get_data_info.batch(*get_info_delays)
495503
pin_delays = []
496504
for data_key, info in zip(data_keys, data_infos):
@@ -507,31 +515,50 @@ async def fetch_batch(
507515
else:
508516
# Not exists in local, fetch from remote worker
509517
missing_keys.append(data_key)
510-
if address is None or band_name is None:
511-
# some mapper keys are absent, specify error='ignore'
512-
# remember that meta only records those main keys
513-
get_metas = [
514-
(
515-
meta_api.get_chunk_meta.delay(
516-
data_key[0] if isinstance(data_key, tuple) else data_key,
517-
fields=["bands"],
518-
error="ignore",
519-
)
518+
# some mapper keys are absent, specify error='ignore'
519+
# remember that meta only records those main keys
520+
get_metas = [
521+
(
522+
meta_api.get_chunk_meta.delay(
523+
data_key[0] if isinstance(data_key, tuple) else data_key,
524+
fields=["bands"],
525+
error="ignore",
520526
)
521-
for data_key in missing_keys
522-
]
527+
)
528+
for data_key in missing_keys
529+
]
523530
await self._data_manager_ref.pin.batch(*pin_delays)
524-
525-
if get_metas:
526-
metas = await meta_api.get_chunk_meta.batch(*get_metas)
527-
else: # pragma: no cover
528-
metas = [{"bands": [(address, band_name)]}] * len(missing_keys)
531+
metas = await meta_api.get_chunk_meta.batch(*get_metas)
529532
assert len(metas) == len(missing_keys)
530-
for data_key, bands in zip(missing_keys, metas):
533+
return missing_keys, metas
534+
535+
async def fetch_batch(
536+
self,
537+
session_id: str,
538+
data_keys: List[str],
539+
level: StorageLevel,
540+
band_name: str,
541+
address: str,
542+
error: str,
543+
):
544+
if error not in ("raise", "ignore"): # pragma: no cover
545+
raise ValueError("error must be raise or ignore")
546+
547+
fetch_keys = []
548+
if address and band_name:
549+
# if specify address, we don't need to query meta
550+
missing_keys = data_keys
551+
bands = [{"bands": [(address, band_name)]}] * len(missing_keys)
552+
else:
553+
missing_keys, bands = await self._get_remote_bands(
554+
session_id, data_keys, fetch_keys, band_name
555+
)
556+
557+
remote_keys = defaultdict(set)
558+
for data_key, bands in zip(missing_keys, bands):
531559
if bands is not None:
532560
remote_keys[bands["bands"][0]].add(data_key)
533561
transfer_tasks = []
534-
fetch_keys = []
535562
for band, keys in remote_keys.items():
536563
if StorageLevel.REMOTE in self._quota_refs:
537564
# if storage support remote level, just fetch object id
@@ -550,6 +577,7 @@ async def fetch_batch(
550577
await asyncio.gather(*transfer_tasks)
551578

552579
append_bands_delays = []
580+
meta_api = await self._get_meta_api(session_id)
553581
for data_key in fetch_keys:
554582
# skip shuffle keys
555583
if isinstance(data_key, tuple):

mars/services/subtask/worker/processor.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -675,17 +675,13 @@ async def run(self, subtask: Subtask, wait_post_run: bool = False):
675675
try:
676676
result = yield self._running_aio_task
677677
logger.info("Finished subtask: %s", subtask.subtask_id)
678-
# post run with actor tell which will not block
679-
if not wait_post_run:
680-
await self.ref().post_run.tell(processor)
681-
else:
682-
await self.post_run(processor)
678+
if wait_post_run:
679+
await processor.post_run()
683680
raise mo.Return(result)
684681
finally:
685682
self._processor = self._running_aio_task = None
686-
687-
async def post_run(self, processor: SubtaskProcessor):
688-
await processor.post_run()
683+
if not wait_post_run:
684+
yield processor.post_run()
689685

690686
async def wait(self):
691687
return self._processor.is_done.wait()

0 commit comments

Comments
 (0)