1515import asyncio
1616from dataclasses import dataclass
1717import logging
18- from typing import Dict , Union , Any , List
18+ from typing import Dict , Union , Any , List , Tuple
1919
2020from ... import oscar as mo
2121from ...lib .aio import alru_cache
2828)
2929from ...storage import StorageLevel
3030from ...utils import dataslots
31- from .core import DataManagerActor , WrappedStorageFileObject
31+ from .core import DataManagerActor , WrappedStorageFileObject , DataInfo
3232from .handler import StorageHandlerActor
3333
3434DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024 ** 2
@@ -96,6 +96,7 @@ async def _send_data(
9696 receiver_ref : Union [mo .ActorRef ],
9797 session_id : str ,
9898 data_keys : List [str ],
99+ data_infos : List [DataInfo ],
99100 level : StorageLevel ,
100101 block_size : int ,
101102 ):
@@ -129,11 +130,13 @@ async def send(self, buffer, eof_mark, key):
129130
130131 sender = BufferedSender ()
131132 open_reader_tasks = []
132- for data_key in data_keys :
133+ for data_key , info in zip ( data_keys , data_infos ) :
133134 open_reader_tasks .append (
134- self ._storage_handler .open_reader .delay (session_id , data_key )
135+ self ._storage_handler .open_reader_by_info .delay (info )
135136 )
136- readers = await self ._storage_handler .open_reader .batch (* open_reader_tasks )
137+ readers = await self ._storage_handler .open_reader_by_info .batch (
138+ * open_reader_tasks
139+ )
137140
138141 for data_key , reader in zip (data_keys , readers ):
139142 while True :
@@ -152,7 +155,61 @@ async def send(self, buffer, eof_mark, key):
152155 break
153156 await sender .flush ()
154157
155- @mo .extensible
158+ async def _send (
159+ self ,
160+ session_id : str ,
161+ data_keys : List [Union [str , Tuple ]],
162+ data_infos : List [DataInfo ],
163+ data_sizes : List [int ],
164+ block_size : int ,
165+ address : str ,
166+ band_name : str ,
167+ level : StorageLevel ,
168+ ):
169+ receiver_ref : Union [
170+ ReceiverManagerActor , mo .ActorRef
171+ ] = await self .get_receiver_ref (address , band_name )
172+ is_transferring_list = await receiver_ref .open_writers (
173+ session_id , data_keys , data_sizes , level
174+ )
175+ to_send_keys = []
176+ to_send_infos = []
177+ to_wait_keys = []
178+ for data_key , is_transferring , info in zip (
179+ data_keys , is_transferring_list , data_infos
180+ ):
181+ if is_transferring :
182+ to_wait_keys .append (data_key )
183+ else :
184+ to_send_keys .append (data_key )
185+ to_send_infos .append (info )
186+
187+ if to_send_keys :
188+ await self ._send_data (
189+ receiver_ref , session_id , to_send_keys , to_send_infos , level , block_size
190+ )
191+ if to_wait_keys :
192+ await receiver_ref .wait_transfer_done (session_id , to_wait_keys )
193+
194+ async def _send_small_objects (
195+ self ,
196+ session_id : str ,
197+ data_keys : List [Union [str , Tuple ]],
198+ data_infos : List [DataInfo ],
199+ address : str ,
200+ band_name : str ,
201+ level : StorageLevel ,
202+ ):
203+ # simple get all objects and send them all to receiver
204+ get_tasks = [
205+ self ._storage_handler .get_data_by_info .delay (info ) for info in data_infos
206+ ]
207+ data_list = await self ._storage_handler .get_data_by_info .batch (* get_tasks )
208+ receiver_ref : Union [
209+ ReceiverManagerActor , mo .ActorRef
210+ ] = await self .get_receiver_ref (address , band_name )
211+ await receiver_ref .put_small_objects (session_id , data_keys , data_list , level )
212+
156213 async def send_batch_data (
157214 self ,
158215 session_id : str ,
@@ -167,9 +224,6 @@ async def send_batch_data(
167224 "Begin to send data (%s, %s) to %s" , session_id , data_keys , address
168225 )
169226 block_size = block_size or self ._transfer_block_size
170- receiver_ref : Union [
171- ReceiverManagerActor , mo .ActorRef
172- ] = await self .get_receiver_ref (address , band_name )
173227 get_infos = []
174228 pin_tasks = []
175229 for data_key in data_keys :
@@ -198,23 +252,27 @@ async def send_batch_data(
198252 data_sizes = [info .store_size for info in infos ]
199253 if level is None :
200254 level = infos [0 ].level
201- is_transferring_list = await receiver_ref .open_writers (
202- session_id , data_keys , data_sizes , level
203- )
204- to_send_keys = []
205- to_wait_keys = []
206- for data_key , is_transferring in zip (data_keys , is_transferring_list ):
207- if is_transferring :
208- to_wait_keys .append (data_key )
209- else :
210- to_send_keys .append (data_key )
211-
212- if to_send_keys :
213- await self ._send_data (
214- receiver_ref , session_id , to_send_keys , level , block_size
255+ total_size = sum (data_sizes )
256+ if total_size > block_size :
257+ logger .debug ("Choose block method for sending data of %s bytes" , total_size )
258+ await self ._send (
259+ session_id ,
260+ data_keys ,
261+ infos ,
262+ data_sizes ,
263+ block_size ,
264+ address ,
265+ band_name ,
266+ level ,
267+ )
268+ else :
269+ logger .debug (
270+ "Choose send_small_objects method for sending data of %s bytes" ,
271+ total_size ,
272+ )
273+ await self ._send_small_objects (
274+ session_id , data_keys , infos , address , band_name , level
215275 )
216- if to_wait_keys :
217- await receiver_ref .wait_transfer_done (session_id , to_wait_keys )
218276 unpin_tasks = []
219277 for data_key in data_keys :
220278 unpin_tasks .append (
@@ -268,6 +326,15 @@ def _decref_writing_key(self, session_id: str, data_key: str):
268326 if self ._writing_infos [(session_id , data_key )].ref_counts == 0 :
269327 del self ._writing_infos [(session_id , data_key )]
270328
329+ async def put_small_objects (
330+ self , session_id : str , data_keys : List [str ], objects : List , level : StorageLevel
331+ ):
332+ tasks = [
333+ self ._storage_handler .put .delay (session_id , data_key , obj , level )
334+ for data_key , obj in zip (data_keys , objects )
335+ ]
336+ await self ._storage_handler .put .batch (* tasks )
337+
271338 async def create_writers (
272339 self ,
273340 session_id : str ,
0 commit comments