 from .blyss_lib import BlyssLib
 
 import json
+import base64
 import bz2
 import time
+import asyncio
 
 
 def _chunk_parser(raw_data: bytes) -> Iterator[bytes]:
@@ -71,23 +73,72 @@ def _check(self, uuid: str) -> bool:
             else:
                 raise e
 
-    def _private_read(self, keys: list[str]) -> list[tuple[bytes, Optional[dict[Any, Any]]]]:
-        """Performs the underlying private retrieval.
+    async def _async_check(self, uuid: str) -> bool:
+        try:
+            await self.api.async_check(uuid)
+            return True
+        except api.ApiException as e:
+            if e.code == 404:
+                return False
+            else:
+                raise e
 
-        Args:
-            keys (str): A list of keys to retrieve.
+    def _split_into_chunks(
+        self, kv_pairs: dict[str, bytes]
+    ) -> list[list[dict[str, str]]]:
+        _MAX_PAYLOAD = 5 * 2**20  # 5 MiB
+
+        # 1. Bin keys by row index
+        keys_by_index: dict[int, list[str]] = {}
+        for k in kv_pairs.keys():
+            i = self.lib.get_row(k)
+            if i in keys_by_index:
+                keys_by_index[i].append(k)
+            else:
+                keys_by_index[i] = [k]
+
+        # 2. Prepare chunks of items, where each is a JSON-ready structure.
+        # Each chunk is less than the maximum payload size, and guarantees
+        # zero overlap of rows across chunks.
+        kv_chunks: list[list[dict[str, str]]] = []
+        current_chunk: list[dict[str, str]] = []
+        current_chunk_size = 0
+        sorted_indices = sorted(keys_by_index.keys())
+        for i in sorted_indices:
+            keys = keys_by_index[i]
+            # prepare all keys in this row
+            row = []
+            row_size = 0
+            for key in keys:
+                value = kv_pairs[key]
+                value_str = base64.b64encode(value).decode("utf-8")
+                fmt = {
+                    "key": key,
+                    "value": value_str,
+                    "content-type": "application/octet-stream",
+                }
+                row.append(fmt)
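+                # approximate serialized size of this item: key + base64 value,
+                # plus ~72 bytes, presumably per-item JSON overhead (field
+                # names, quotes, punctuation, and the content-type value)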
+                row_size += int(72 + len(key) + len(value_str))
+
+            # if the new row doesn't fit into the current chunk, start a new one
+            if current_chunk_size + row_size > _MAX_PAYLOAD:
+                kv_chunks.append(current_chunk)
+                current_chunk = row
+                current_chunk_size = row_size
+            else:
+                current_chunk.extend(row)
+                current_chunk_size += row_size
 
-        Returns:
-            tuple[bytes, Optional[dict]]: Returns a tuple of (value, optional_metadata).
-        """
-        if not self.public_uuid or not self._check(self.public_uuid):
-            self.setup()
-        assert self.public_uuid
+        # add the last chunk
+        if len(current_chunk) > 0:
+            kv_chunks.append(current_chunk)
+
+        return kv_chunks
 
+    def _generate_query_stream(self, keys: list[str]) -> bytes:
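+        # Wire format: query count (u64 LE), then each encrypted query
+        # prefixed by its length (u64 LE).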
         # generate encrypted queries
         queries: list[bytes] = [
-            self.lib.generate_query(self.public_uuid, self.lib.get_row(k))
-            for k in keys
+            self.lib.generate_query(self.public_uuid, self.lib.get_row(k)) for k in keys
         ]
         # interleave the queries with their lengths (uint64_t)
         query_lengths = [len(q).to_bytes(8, "little") for q in queries]
@@ -96,18 +147,43 @@ def _private_read(self, keys: list[str]) -> list[tuple[bytes, Optional[dict[Any,
         lengths_and_queries.insert(0, len(queries).to_bytes(8, "little"))
         # serialize the queries
         multi_query = b"".join(lengths_and_queries)
-
-        start = time.perf_counter()
-        multi_result = self.api.private_read(self.name, multi_query)
-        self.exfil = time.perf_counter() - start
+        return multi_query
 
-        retrievals = []
-        for key, result in zip(keys, _chunk_parser(multi_result)):
+    def _unpack_query_result(
+        self, keys: list[str], raw_result: bytes, parse_metadata: bool = True
+    ) -> list[bytes]:
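+        # Each chunk is decrypted, bz2-decompressed, and the value for the
+        # corresponding key is extracted; metadata parsing is optional.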
+        retrievals = []
+        for key, result in zip(keys, _chunk_parser(raw_result)):
             decrypted_result = self.lib.decode_response(result)
             decompressed_result = bz2.decompress(decrypted_result)
             extracted_result = self.lib.extract_result(key, decompressed_result)
-            output = serializer.deserialize(extracted_result)
+            if parse_metadata:
+                output = serializer.deserialize(extracted_result)
+            else:
+                output = extracted_result
             retrievals.append(output)
+        return retrievals
+
+    def _private_read(self, keys: list[str]) -> list[bytes]:
+        """Performs the underlying private retrieval.
+
+        Args:
+            keys (list[str]): The keys to retrieve.
+
+        Returns:
+            list[bytes]: One result per key, in the same order as the input keys.
+        """
+        if not self.public_uuid or not self._check(self.public_uuid):
+            self.setup()
+        assert self.public_uuid
+
+        multi_query = self._generate_query_stream(keys)
+
+        start = time.perf_counter()
+        multi_result = self.api.private_read(self.name, multi_query)
+        self.exfil = time.perf_counter() - start
+
+        retrievals = self._unpack_query_result(keys, multi_result)
 
         return retrievals
 
@@ -184,11 +260,11 @@ def private_read(self, keys: Union[str, list[str]]) -> Union[bytes, list[bytes]]
 
         Args:
             keys (str): A key or list of keys to privately read.
-                If a list of keys is supplied,
+                If a list of keys is supplied,
                 results will be returned in the same order.
 
         Returns:
-            bytes: The value found for the key in the bucket,
+            bytes: The value found for the key in the bucket,
                 or None if the key was not found.
         """
         single_query = False
@@ -202,7 +278,6 @@ def private_read(self, keys: Union[str, list[str]]) -> Union[bytes, list[bytes]]
 
         return results
 
-
     def private_key_intersect(self, keys: list[str]) -> list[str]:
         """Privately intersects the given set of keys with the keys in this bucket,
         returning the keys that intersected. This is generally slower than a single
@@ -217,3 +292,36 @@ def private_key_intersect(self, keys: list[str]) -> list[str]:
         bloom_filter = self.api.bloom(self.name)
         present_keys = list(filter(bloom_filter.lookup, keys))
         return present_keys
+
+
+class AsyncBucket(Bucket):
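+    """Bucket variant whose writes and private reads go through the async API client."""
+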
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def write(self, kv_pairs: dict[str, bytes], MAX_CONCURRENCY=8):
+        # Split the key-value pairs into chunks not exceeding max payload size.
+        kv_chunks = self._split_into_chunks(kv_pairs)
+        # Make one write call per chunk, while respecting a max concurrency limit.
+        sem = asyncio.Semaphore(MAX_CONCURRENCY)
+
+        async def _paced_writer(chunk):
+            async with sem:
+                await self.api.async_write(self.name, json.dumps(chunk))
+
+        _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
+        await asyncio.gather(*_tasks)
+
+    async def private_read(self, keys: list[str]) -> list[bytes]:
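+        # Mirrors Bucket._private_read, but awaits the async API client and
+        # skips metadata parsing, returning the raw extracted bytes.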
+        if not self.public_uuid or not await self._async_check(self.public_uuid):
+            self.setup()
+        assert self.public_uuid
+
+        multi_query = self._generate_query_stream(keys)
+
+        start = time.perf_counter()
+        multi_result = await self.api.async_private_read(self.name, multi_query)
+        self.exfil = time.perf_counter() - start
+
+        retrievals = self._unpack_query_result(keys, multi_result, parse_metadata=False)
+
+        return retrievals