1515import asyncio
1616import logging
1717from dataclasses import dataclass
18- from typing import Dict , List
18+ from typing import Dict , List , Tuple , Union
1919
2020from ... import oscar as mo
2121from ...lib .aio import alru_cache
2222from ...storage import StorageLevel
2323from ...utils import dataslots
24- from .core import DataManagerActor , WrappedStorageFileObject
24+ from .core import DataManagerActor , WrappedStorageFileObject , DataInfo
2525from .handler import StorageHandlerActor
2626
2727DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024 ** 2
@@ -65,6 +65,7 @@ async def _send_data(
6565 receiver_ref : mo .ActorRefType ["ReceiverManagerActor" ],
6666 session_id : str ,
6767 data_keys : List [str ],
68+ data_infos : List [DataInfo ],
6869 level : StorageLevel ,
6970 block_size : int ,
7071 ):
@@ -93,11 +94,12 @@ async def send(self, buffer, eof_mark, key):
9394
9495 sender = BufferedSender ()
9596 open_reader_tasks = []
96- for data_key in data_keys :
97+ storage_client = await self ._storage_handler .get_client (level )
98+ for info in data_infos :
9799 open_reader_tasks .append (
98- self . _storage_handler . open_reader . delay ( session_id , data_key )
100+ storage_client . open_reader ( info . object_id )
99101 )
100- readers = await self . _storage_handler . open_reader . batch (* open_reader_tasks )
102+ readers = await asyncio . gather (* open_reader_tasks )
101103
102104 for data_key , reader in zip (data_keys , readers ):
103105 while True :
@@ -116,7 +118,58 @@ async def send(self, buffer, eof_mark, key):
116118 break
117119 await sender .flush ()
118120
119- @mo .extensible
async def _send(
    self,
    session_id: str,
    data_keys: List[Union[str, Tuple]],
    data_infos: List[DataInfo],
    data_sizes: List[int],
    block_size: int,
    address: str,
    band_name: str,
    level: StorageLevel,
):
    """Transfer the given keys to a remote receiver block by block.

    The receiver at ``(address, band_name)`` is asked to open writers for
    every key and answers with one flag per key.  Keys whose flag is falsy
    are streamed from here through ``_send_data``; keys whose flag is truthy
    are not sent again, this call only waits for the receiver to report
    their transfer as done.

    Args:
        session_id: Session the data belongs to.
        data_keys: Keys to transfer, aligned with ``data_infos`` and
            ``data_sizes``.
        data_infos: Info object for each key, forwarded to ``_send_data``.
        data_sizes: Size for each key, forwarded to ``open_writers``.
        block_size: Block size forwarded to ``_send_data``.
        address: Address used to look up the receiver.
        band_name: Band name used to look up the receiver.
        level: Storage level forwarded to the receiver and to ``_send_data``.
    """
    receiver: mo.ActorRefType["ReceiverManagerActor"] = await self.get_receiver_ref(
        address, band_name
    )
    transferring_flags = await receiver.open_writers(
        session_id, data_keys, data_sizes, level
    )

    # Partition the keys: the ones we must stream ourselves (together with
    # their infos) and the ones we only have to wait for.
    keys_to_send, infos_to_send, keys_to_wait = [], [], []
    for key, already_transferring, info in zip(
        data_keys, transferring_flags, data_infos
    ):
        if not already_transferring:
            keys_to_send.append(key)
            infos_to_send.append(info)
            continue
        keys_to_wait.append(key)

    if keys_to_send:
        await self._send_data(
            receiver, session_id, keys_to_send, infos_to_send, level, block_size
        )
    if keys_to_wait:
        await receiver.wait_transfer_done(session_id, keys_to_wait)
154+
async def _send_small_objects(
    self,
    session_id: str,
    data_keys: List[Union[str, Tuple]],
    data_infos: List[DataInfo],
    address: str,
    band_name: str,
    level: StorageLevel,
):
    """Fetch every object from storage and push them to the receiver at once.

    All objects are read concurrently from the storage client obtained for
    ``level``, then handed to the receiver in a single
    ``put_small_objects`` call instead of being streamed in blocks.

    Args:
        session_id: Session the data belongs to.
        data_keys: Keys to transfer, aligned with ``data_infos``.
        data_infos: Info object for each key; ``object_id`` is used to read
            the object from storage.
        address: Address used to look up the receiver.
        band_name: Band name used to look up the receiver.
        level: Storage level used to pick the storage client and forwarded
            to the receiver.
    """
    client = await self._storage_handler.get_client(level)
    # Issue all reads together; gather returns results in argument order, so
    # the objects stay aligned with ``data_keys``.
    fetched = await asyncio.gather(
        *(client.get(info.object_id) for info in data_infos)
    )
    objects = list(fetched)

    receiver: mo.ActorRefType["ReceiverManagerActor"] = await self.get_receiver_ref(
        address, band_name
    )
    await receiver.put_small_objects(session_id, data_keys, objects, level)
172+
120173 async def send_batch_data (
121174 self ,
122175 session_id : str ,
@@ -125,15 +178,13 @@ async def send_batch_data(
125178 level : StorageLevel ,
126179 band_name : str = "numa-0" ,
127180 block_size : int = None ,
181+ is_small_objects = None ,
128182 error : str = "raise" ,
129183 ):
130184 logger .debug (
131185 "Begin to send data (%s, %s) to %s" , session_id , data_keys , address
132186 )
133187 block_size = block_size or self ._transfer_block_size
134- receiver_ref : mo .ActorRefType [
135- ReceiverManagerActor
136- ] = await self .get_receiver_ref (address , band_name )
137188 get_infos = []
138189 pin_tasks = []
139190 for data_key in data_keys :
@@ -162,23 +213,29 @@ async def send_batch_data(
162213 data_sizes = [info .store_size for info in infos ]
163214 if level is None :
164215 level = infos [0 ].level
165- is_transferring_list = await receiver_ref .open_writers (
166- session_id , data_keys , data_sizes , level
167- )
168- to_send_keys = []
169- to_wait_keys = []
170- for data_key , is_transferring in zip (data_keys , is_transferring_list ):
171- if is_transferring :
172- to_wait_keys .append (data_key )
173- else :
174- to_send_keys .append (data_key )
175-
176- if to_send_keys :
177- await self ._send_data (
178- receiver_ref , session_id , to_send_keys , level , block_size
216+ total_size = sum (data_sizes )
217+ if is_small_objects is None :
218+ is_small_objects = total_size <= block_size
219+ if is_small_objects :
220+ logger .debug (
221+ "Choose send_small_objects method for sending data of %s bytes" ,
222+ total_size ,
223+ )
224+ await self ._send_small_objects (
225+ session_id , data_keys , infos , address , band_name , level
226+ )
227+ else :
228+ logger .debug ("Choose block method for sending data of %s bytes" , total_size )
229+ await self ._send (
230+ session_id ,
231+ data_keys ,
232+ infos ,
233+ data_sizes ,
234+ block_size ,
235+ address ,
236+ band_name ,
237+ level ,
179238 )
180- if to_wait_keys :
181- await receiver_ref .wait_transfer_done (session_id , to_wait_keys )
182239 unpin_tasks = []
183240 for data_key in data_keys :
184241 unpin_tasks .append (
@@ -232,6 +289,15 @@ def _decref_writing_key(self, session_id: str, data_key: str):
232289 if self ._writing_infos [(session_id , data_key )].ref_counts == 0 :
233290 del self ._writing_infos [(session_id , data_key )]
234291
async def put_small_objects(
    self, session_id: str, data_keys: List[str], objects: List, level: StorageLevel
):
    """Store a batch of already-materialized objects through the storage handler.

    One delayed ``put`` call is built per ``(data_key, object)`` pair and all
    of them are submitted together with a single batch call.

    Args:
        session_id: Session the data belongs to.
        data_keys: Keys to store, aligned with ``objects``.
        objects: Objects to store.
        level: Storage level passed to every ``put``.
    """
    handler = self._storage_handler
    delayed_puts = []
    for key, value in zip(data_keys, objects):
        delayed_puts.append(handler.put.delay(session_id, key, value, level))
    await handler.put.batch(*delayed_puts)
300+
235301 async def create_writers (
236302 self ,
237303 session_id : str ,
0 commit comments