2929from numbers import Integral
3030from urllib .parse import urlparse
3131from weakref import ref , WeakKeyDictionary , WeakSet
32- from typing import Any , Callable , Coroutine , Dict , List , Optional , Tuple , Type , Union
32+ from typing import Any , Callable , Coroutine , Dict , List , Optional , Tuple , Union
3333
3434import numpy as np
3535
5454from ...services .mutable import MutableAPI , MutableTensor
5555from ...services .storage import StorageAPI
5656from ...services .task import AbstractTaskAPI , TaskAPI , TaskResult
57+ from ...services .task .execution .api import Fetcher
5758from ...services .web import OscarWebAPI
5859from ...tensor .utils import slice_split
5960from ...typing import ClientType , BandType
@@ -441,7 +442,7 @@ def init(
441442 cls ,
442443 address : str ,
443444 session_id : str ,
444- backend : str = "oscar " ,
445+ backend : str = "mars " ,
445446 new : bool = True ,
446447 ** kwargs ,
447448 ) -> "AbstractSession" :
@@ -658,14 +659,6 @@ def fetch_log(
658659 return fetch (tileables , self , offsets = offsets , sizes = sizes )
659660
660661
661- _type_name_to_session_cls : Dict [str , Type [AbstractAsyncSession ]] = dict ()
662-
663-
664- def register_session_cls (session_cls : Type [AbstractAsyncSession ]):
665- _type_name_to_session_cls [session_cls .name ] = session_cls
666- return session_cls
667-
668-
669662@dataclass
670663class ChunkFetchInfo :
671664 tileable : TileableType
@@ -755,14 +748,12 @@ def gen_submit_tileable_graph(
755748 return graph , to_execute_tileables
756749
757750
758- @register_session_cls
759751class _IsolatedSession (AbstractAsyncSession ):
760- name = "oscar"
761-
762752 def __init__ (
763753 self ,
764754 address : str ,
765755 session_id : str ,
756+ backend : str ,
766757 session_api : AbstractSessionAPI ,
767758 meta_api : AbstractMetaAPI ,
768759 lifecycle_api : AbstractLifecycleAPI ,
@@ -775,6 +766,7 @@ def __init__(
775766 request_rewriter : Callable = None ,
776767 ):
777768 super ().__init__ (address , session_id )
769+ self ._backend = backend
778770 self ._session_api = session_api
779771 self ._task_api = task_api
780772 self ._meta_api = meta_api
@@ -800,7 +792,12 @@ def __init__(
800792
801793 @classmethod
802794 async def _init (
803- cls , address : str , session_id : str , new : bool = True , timeout : float = None
795+ cls ,
796+ address : str ,
797+ session_id : str ,
798+ backend : str ,
799+ new : bool = True ,
800+ timeout : float = None ,
804801 ):
805802 session_api = await SessionAPI .create (address )
806803 if new :
@@ -820,6 +817,7 @@ async def _init(
820817 return cls (
821818 address ,
822819 session_id ,
820+ backend ,
823821 session_api ,
824822 meta_api ,
825823 lifecycle_api ,
@@ -836,6 +834,7 @@ async def init(
836834 cls ,
837835 address : str ,
838836 session_id : str ,
837+ backend : str ,
839838 new : bool = True ,
840839 timeout : float = None ,
841840 ** kwargs ,
@@ -859,12 +858,19 @@ async def init(
859858 return await _IsolatedWebSession ._init (
860859 address ,
861860 session_id ,
861+ backend ,
862862 new = new ,
863863 timeout = timeout ,
864864 request_rewriter = request_rewriter ,
865865 )
866866 else :
867- return await cls ._init (address , session_id , new = new , timeout = timeout )
867+ return await cls ._init (
868+ address ,
869+ session_id ,
870+ backend ,
871+ new = new ,
872+ timeout = timeout ,
873+ )
868874
869875 async def _update_progress (self , task_id : str , progress : Progress ):
870876 zero_acc_time = 0
@@ -1084,6 +1090,8 @@ async def fetch(self, *tileables, **kwargs) -> list:
10841090 unexpected_keys = ", " .join (list (kwargs .keys ()))
10851091 raise TypeError (f"`fetch` got unexpected arguments: { unexpected_keys } " )
10861092
1093+ fetcher = Fetcher .create (self ._backend , get_storage_api = self ._get_storage_api )
1094+
10871095 with enter_mode (build = True ):
10881096 chunks = []
10891097 get_chunk_metas = []
@@ -1099,7 +1107,10 @@ async def fetch(self, *tileables, **kwargs) -> list:
10991107 continue
11001108 chunks .append (chunk )
11011109 get_chunk_metas .append (
1102- self ._meta_api .get_chunk_meta .delay (chunk .key , fields = ["bands" ])
1110+ self ._meta_api .get_chunk_meta .delay (
1111+ chunk .key ,
1112+ fields = fetcher .required_meta_keys ,
1113+ )
11031114 )
11041115 indexes = (
11051116 chunk_to_slice [chunk ] if chunk_to_slice is not None else None
@@ -1108,29 +1119,17 @@ async def fetch(self, *tileables, **kwargs) -> list:
11081119 ChunkFetchInfo (tileable = tileable , chunk = chunk , indexes = indexes )
11091120 )
11101121 fetch_infos_list .append (fetch_infos )
1111- chunk_metas = await self ._meta_api .get_chunk_meta .batch (* get_chunk_metas )
1112- chunk_to_band = {
1113- chunk : meta ["bands" ][0 ] for chunk , meta in zip (chunks , chunk_metas )
1114- }
11151122
1116- storage_api_to_gets = defaultdict (list )
1117- storage_api_to_fetch_infos = defaultdict (list )
1118- for fetch_info in itertools .chain (* fetch_infos_list ):
1119- conditions = fetch_info .indexes
1120- chunk = fetch_info .chunk
1121- band = chunk_to_band [chunk ]
1122- storage_api = await self ._get_storage_api (band )
1123- storage_api_to_gets [storage_api ].append (
1124- storage_api .get .delay (chunk .key , conditions = conditions )
1125- )
1126- storage_api_to_fetch_infos [storage_api ].append (fetch_info )
1127- for storage_api in storage_api_to_gets :
1128- fetched_data = await storage_api .get .batch (
1129- * storage_api_to_gets [storage_api ]
1130- )
1131- infos = storage_api_to_fetch_infos [storage_api ]
1132- for info , data in zip (infos , fetched_data ):
1133- info .data = data
1123+ chunk_metas = await self ._meta_api .get_chunk_meta .batch (* get_chunk_metas )
1124+ for chunk , meta , fetch_info in zip (
1125+ chunks , chunk_metas , itertools .chain (* fetch_infos_list )
1126+ ):
1127+ await fetcher .append (chunk .key , meta , fetch_info .indexes )
1128+ fetched_data = await fetcher .get ()
1129+ for fetch_info , data in zip (
1130+ itertools .chain (* fetch_infos_list ), fetched_data
1131+ ):
1132+ fetch_info .data = data
11341133
11351134 result = []
11361135 for tileable , fetch_infos in zip (tileables , fetch_infos_list ):
@@ -1317,6 +1316,7 @@ async def _init(
13171316 cls ,
13181317 address : str ,
13191318 session_id : str ,
1319+ backend : str ,
13201320 new : bool = True ,
13211321 timeout : float = None ,
13221322 request_rewriter : Callable = None ,
@@ -1341,6 +1341,7 @@ async def _init(
13411341 return cls (
13421342 address ,
13431343 session_id ,
1344+ backend ,
13441345 session_api ,
13451346 meta_api ,
13461347 lifecycle_api ,
@@ -1415,13 +1416,12 @@ async def init(
14151416 cls ,
14161417 address : str ,
14171418 session_id : str ,
1418- backend : str = "oscar " ,
1419+ backend : str = "mars " ,
14191420 new : bool = True ,
14201421 ** kwargs ,
14211422 ) -> "AbstractSession" :
1422- session_cls = _type_name_to_session_cls [backend ]
14231423 isolation = ensure_isolation_created (kwargs )
1424- coro = session_cls .init (address , session_id , new = new , ** kwargs )
1424+ coro = _IsolatedSession .init (address , session_id , backend , new = new , ** kwargs )
14251425 fut = asyncio .run_coroutine_threadsafe (coro , isolation .loop )
14261426 isolated_session = await asyncio .wrap_future (fut )
14271427 return AsyncSession (address , session_id , isolated_session , isolation )
@@ -1587,13 +1587,12 @@ def init(
15871587 cls ,
15881588 address : str ,
15891589 session_id : str ,
1590- backend : str = "oscar " ,
1590+ backend : str = "mars " ,
15911591 new : bool = True ,
15921592 ** kwargs ,
15931593 ) -> "AbstractSession" :
1594- session_cls = _type_name_to_session_cls [backend ]
15951594 isolation = ensure_isolation_created (kwargs )
1596- coro = session_cls .init (address , session_id , new = new , ** kwargs )
1595+ coro = _IsolatedSession .init (address , session_id , backend , new = new , ** kwargs )
15971596 fut = asyncio .run_coroutine_threadsafe (coro , isolation .loop )
15981597 isolated_session = fut .result ()
15991598 return SyncSession (address , session_id , isolated_session , isolation )
@@ -1963,7 +1962,7 @@ def _new_session_id():
19631962async def _new_session (
19641963 address : str ,
19651964 session_id : str = None ,
1966- backend : str = "oscar " ,
1965+ backend : str = "mars " ,
19671966 default : bool = False ,
19681967 ** kwargs ,
19691968) -> AbstractSession :
@@ -1981,7 +1980,7 @@ async def _new_session(
19811980def new_session (
19821981 address : str = None ,
19831982 session_id : str = None ,
1984- backend : str = "oscar " ,
1983+ backend : str = "mars " ,
19851984 default : bool = True ,
19861985 new : bool = True ,
19871986 ** kwargs ,
0 commit comments