2323from ...typing import BandType
2424from ...utils import calc_data_size , lazy_import
2525from ..cluster import ClusterAPI , StorageInfo
26- from ..meta import MetaAPI
26+ from ..meta import MetaAPI , WorkerMetaAPI
2727from .core import (
2828 StorageQuotaActor ,
2929 DataManagerActor ,
3939logger = logging .getLogger (__name__ )
4040
4141
42- class StorageHandlerActor (mo .Actor ):
42+ class StorageHandlerActor (mo .StatelessActor ):
4343 """
4444 Storage handler actor, provide methods like `get`, `put`, etc.
4545 This actor is stateful and created on worker's sub pools.
@@ -468,29 +468,37 @@ async def _fetch_via_transfer(
468468 )
469469 logger .debug ("Finish fetching %s from band %s" , data_keys , remote_band )
470470
471- async def fetch_batch (
472- self ,
473- session_id : str ,
474- data_keys : List [str ],
475- level : StorageLevel ,
476- band_name : str ,
477- address : str ,
478- error : str ,
471+ async def _get_remote_bands (
472+ self , session_id : str , data_keys : List , fetch_keys : List , band_name : str
479473 ):
480- if error not in ("raise" , "ignore" ): # pragma: no cover
481- raise ValueError ("error must be raise or ignore" )
482-
474+ # For mapper data, we need to return fetch keys to remove them after
475+ # execution(see #2922). However, after introducing prefetch, some
476+ # mapper data is fetched in advance and not in missing keys, so here
477+ # we need to query worker meta service to check if it is on local.
483478 meta_api = await self ._get_meta_api (session_id )
484- remote_keys = defaultdict (set )
479+ mapper_main_keys = defaultdict (list )
485480 missing_keys = []
486- get_metas = []
487481 get_info_delays = []
488482 for data_key in data_keys :
483+ if isinstance (data_key , tuple ):
484+ mapper_main_keys [data_key [0 ]].append (data_key )
489485 get_info_delays .append (
490486 self ._data_manager_ref .get_data_info .delay (
491487 session_id , data_key , band_name , error = "ignore"
492488 )
493489 )
490+ # use worker meta api to check if mapper data is executed locally.
491+ worker_meta_api = await WorkerMetaAPI .create (session_id , self .address )
492+
493+ worker_meta_tasks = [
494+ worker_meta_api .get_chunk_meta .delay (key , error = "ignore" )
495+ for key in mapper_main_keys
496+ ]
497+ mapper_metas = await worker_meta_api .get_chunk_meta .batch (* worker_meta_tasks )
498+ for key , meta in zip (mapper_main_keys , mapper_metas ):
499+ if meta is None :
500+ fetch_keys .extend (mapper_main_keys [key ])
501+
494502 data_infos = await self ._data_manager_ref .get_data_info .batch (* get_info_delays )
495503 pin_delays = []
496504 for data_key , info in zip (data_keys , data_infos ):
@@ -507,31 +515,50 @@ async def fetch_batch(
507515 else :
508516 # Not exists in local, fetch from remote worker
509517 missing_keys .append (data_key )
510- if address is None or band_name is None :
511- # some mapper keys are absent, specify error='ignore'
512- # remember that meta only records those main keys
513- get_metas = [
514- (
515- meta_api .get_chunk_meta .delay (
516- data_key [0 ] if isinstance (data_key , tuple ) else data_key ,
517- fields = ["bands" ],
518- error = "ignore" ,
519- )
518+ # some mapper keys are absent, specify error='ignore'
519+ # remember that meta only records those main keys
520+ get_metas = [
521+ (
522+ meta_api .get_chunk_meta .delay (
523+ data_key [0 ] if isinstance (data_key , tuple ) else data_key ,
524+ fields = ["bands" ],
525+ error = "ignore" ,
520526 )
521- for data_key in missing_keys
522- ]
527+ )
528+ for data_key in missing_keys
529+ ]
523530 await self ._data_manager_ref .pin .batch (* pin_delays )
524-
525- if get_metas :
526- metas = await meta_api .get_chunk_meta .batch (* get_metas )
527- else : # pragma: no cover
528- metas = [{"bands" : [(address , band_name )]}] * len (missing_keys )
531+ metas = await meta_api .get_chunk_meta .batch (* get_metas )
529532 assert len (metas ) == len (missing_keys )
530- for data_key , bands in zip (missing_keys , metas ):
533+ return missing_keys , metas
534+
535+ async def fetch_batch (
536+ self ,
537+ session_id : str ,
538+ data_keys : List [str ],
539+ level : StorageLevel ,
540+ band_name : str ,
541+ address : str ,
542+ error : str ,
543+ ):
544+ if error not in ("raise" , "ignore" ): # pragma: no cover
545+ raise ValueError ("error must be raise or ignore" )
546+
547+ fetch_keys = []
548+ if address and band_name :
549+ # if specify address, we don't need to query meta
550+ missing_keys = data_keys
551+ bands = [{"bands" : [(address , band_name )]}] * len (missing_keys )
552+ else :
553+ missing_keys , bands = await self ._get_remote_bands (
554+ session_id , data_keys , fetch_keys , band_name
555+ )
556+
557+ remote_keys = defaultdict (set )
558+ for data_key , bands in zip (missing_keys , bands ):
531559 if bands is not None :
532560 remote_keys [bands ["bands" ][0 ]].add (data_key )
533561 transfer_tasks = []
534- fetch_keys = []
535562 for band , keys in remote_keys .items ():
536563 if StorageLevel .REMOTE in self ._quota_refs :
537564 # if storage support remote level, just fetch object id
@@ -550,6 +577,7 @@ async def fetch_batch(
550577 await asyncio .gather (* transfer_tasks )
551578
552579 append_bands_delays = []
580+ meta_api = await self ._get_meta_api (session_id )
553581 for data_key in fetch_keys :
554582 # skip shuffle keys
555583 if isinstance (data_key , tuple ):
0 commit comments