Skip to content

Commit b77a93f

Browse files
authored
Python client: new AsyncBucket interface (#17)
New client interface to support async usage. This commit implements async versions of write() and private_read(); all other calls fall back to blocking synchronous versions.
1 parent 7b65562 commit b77a93f

File tree

7 files changed

+231
-30
lines changed

7 files changed

+231
-30
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"python.analysis.typeCheckingMode": "basic"
3+
}

python/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "blyss-client-python"
3-
version = "0.1.0"
3+
version = "0.1.7"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

python/blyss/api.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from typing import Any, Optional, Union
99
import requests
10+
import httpx
11+
import gzip
12+
import asyncio
1013
import json
1114
import logging
1215
import base64
@@ -58,6 +61,24 @@ def _get_data(api_key: Optional[str], url: str) -> bytes:
5861
return resp.content
5962

6063

64+
async def _async_get_data(
65+
api_key: Optional[str], url: str, decode_json: bool = True
66+
) -> Any:
67+
headers = {}
68+
if api_key:
69+
headers["x-api-key"] = api_key
70+
71+
logging.info(f"GET {url} {headers}")
72+
async with httpx.AsyncClient() as client:
73+
r = await client.get(url, headers=headers)
74+
_check_http_error(r)
75+
76+
if decode_json:
77+
return r.json()
78+
else:
79+
return r.content
80+
81+
6182
def _get_data_json(api_key: str, url: str) -> dict[Any, Any]:
6283
"""Perform an HTTP GET request, returning a JSON-parsed dict"""
6384
return json.loads(_get_data(api_key, url))
@@ -94,7 +115,37 @@ def _post_form_data(url: str, fields: dict[Any, Any], data: bytes):
94115
_check_http_error(resp)
95116

96117

97-
# API
118+
async def _async_post_data(
119+
api_key: str,
120+
url: str,
121+
data: Union[str, bytes],
122+
compress: bool = True,
123+
decode_json: bool = True,
124+
) -> Any:
125+
"""Perform an async HTTP POST request, returning a JSON-parsed dict response"""
126+
headers = {
127+
"x-api-key": api_key,
128+
}
129+
if type(data) == str:
130+
headers["Content-Type"] = "application/json"
131+
data = data.encode("utf-8")
132+
else:
133+
headers["Content-Type"] = "application/octet-stream"
134+
assert type(data) == bytes
135+
136+
if compress:
137+
# apply gzip compression to data before sending
138+
data = gzip.compress(data)
139+
headers["Content-Encoding"] = "gzip"
140+
141+
async with httpx.AsyncClient(timeout=httpx.Timeout(5, read=None)) as client:
142+
r = await client.post(url, content=data, headers=headers)
143+
144+
_check_http_error(r) # type: ignore
145+
if decode_json:
146+
return r.json()
147+
else:
148+
return r.content
98149

99150

100151
class API:
@@ -138,6 +189,13 @@ def check(self, uuid: str) -> dict[Any, Any]:
138189
self.api_key, self._service_url_for("/" + uuid + CHECK_PATH)
139190
)
140191

192+
async def async_check(self, uuid: str) -> dict[Any, Any]:
193+
return await _async_get_data(
194+
self.api_key,
195+
self._service_url_for("/" + uuid + CHECK_PATH),
196+
decode_json=True,
197+
)
198+
141199
def list_buckets(self) -> dict[Any, Any]:
142200
"""List all buckets accessible to this API key.
143201
@@ -204,6 +262,12 @@ def write(self, bucket_name: str, data: bytes):
204262
"""Write some data to this bucket."""
205263
_post_data(self.api_key, self._url_for(bucket_name, WRITE_PATH), data)
206264

265+
async def async_write(self, bucket_name: str, data: str):
266+
"""Write JSON payload to this bucket."""
267+
await _async_post_data(
268+
self.api_key, self._url_for(bucket_name, WRITE_PATH), data, decode_json=True
269+
)
270+
207271
def delete_key(self, bucket_name: str, key: str):
208272
"""Delete a key in this bucket."""
209273
_post_data(
@@ -214,3 +278,12 @@ def private_read(self, bucket_name: str, data: bytes) -> bytes:
214278
"""Privately read data from this bucket."""
215279
val = _post_data(self.api_key, self._url_for(bucket_name, READ_PATH), data)
216280
return base64.b64decode(val)
281+
282+
async def async_private_read(self, bucket_name: str, data: bytes) -> bytes:
283+
"""Privately read data from this bucket."""
284+
val: bytes = await _async_post_data(
285+
self.api_key, self._url_for(bucket_name, READ_PATH), data, decode_json=False
286+
)
287+
# AWS APIGW encodes its responses as base64
288+
return base64.b64decode(val)
289+
# return self.private_read(bucket_name, data)

python/blyss/bucket.py

Lines changed: 130 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
from .blyss_lib import BlyssLib
1010

1111
import json
12+
import base64
1213
import bz2
1314
import time
15+
import asyncio
1416

1517

1618
def _chunk_parser(raw_data: bytes) -> Iterator[bytes]:
@@ -71,23 +73,72 @@ def _check(self, uuid: str) -> bool:
7173
else:
7274
raise e
7375

74-
def _private_read(self, keys: list[str]) -> list[tuple[bytes, Optional[dict[Any, Any]]]]:
75-
"""Performs the underlying private retrieval.
76+
async def _async_check(self, uuid: str) -> bool:
77+
try:
78+
await self.api.async_check(uuid)
79+
return True
80+
except api.ApiException as e:
81+
if e.code == 404:
82+
return False
83+
else:
84+
raise e
7685

77-
Args:
78-
keys (str): A list of keys to retrieve.
86+
def _split_into_chunks(
87+
self, kv_pairs: dict[str, bytes]
88+
) -> list[list[dict[str, str]]]:
89+
_MAX_PAYLOAD = 5 * 2**20 # 5 MiB
90+
91+
# 1. Bin keys by row index
92+
keys_by_index: dict[int, list[str]] = {}
93+
for k in kv_pairs.keys():
94+
i = self.lib.get_row(k)
95+
if i in keys_by_index:
96+
keys_by_index[i].append(k)
97+
else:
98+
keys_by_index[i] = [k]
99+
100+
# 2. Prepare chunks of items, where each is a JSON-ready structure.
101+
# Each chunk is less than the maximum payload size, and guarantees
102+
# zero overlap of rows across chunks.
103+
kv_chunks: list[list[dict[str, str]]] = []
104+
current_chunk: list[dict[str, str]] = []
105+
current_chunk_size = 0
106+
sorted_indices = sorted(keys_by_index.keys())
107+
for i in sorted_indices:
108+
keys = keys_by_index[i]
109+
# prepare all keys in this row
110+
row = []
111+
row_size = 0
112+
for key in keys:
113+
value = kv_pairs[key]
114+
value_str = base64.b64encode(value).decode("utf-8")
115+
fmt = {
116+
"key": key,
117+
"value": value_str,
118+
"content-type": "application/octet-stream",
119+
}
120+
row.append(fmt)
121+
row_size += int(72 + len(key) + len(value_str))
122+
123+
# if the new row doesn't fit into the current chunk, start a new one
124+
if current_chunk_size + row_size > _MAX_PAYLOAD:
125+
kv_chunks.append(current_chunk)
126+
current_chunk = row
127+
current_chunk_size = row_size
128+
else:
129+
current_chunk.extend(row)
130+
current_chunk_size += row_size
79131

80-
Returns:
81-
tuple[bytes, Optional[dict]]: Returns a tuple of (value, optional_metadata).
82-
"""
83-
if not self.public_uuid or not self._check(self.public_uuid):
84-
self.setup()
85-
assert self.public_uuid
132+
# add the last chunk
133+
if len(current_chunk) > 0:
134+
kv_chunks.append(current_chunk)
135+
136+
return kv_chunks
86137

138+
def _generate_query_stream(self, keys: list[str]) -> bytes:
87139
# generate encrypted queries
88140
queries: list[bytes] = [
89-
self.lib.generate_query(self.public_uuid, self.lib.get_row(k))
90-
for k in keys
141+
self.lib.generate_query(self.public_uuid, self.lib.get_row(k)) for k in keys
91142
]
92143
# interleave the queries with their lengths (uint64_t)
93144
query_lengths = [len(q).to_bytes(8, "little") for q in queries]
@@ -96,18 +147,43 @@ def _private_read(self, keys: list[str]) -> list[tuple[bytes, Optional[dict[Any,
96147
lengths_and_queries.insert(0, len(queries).to_bytes(8, "little"))
97148
# serialize the queries
98149
multi_query = b"".join(lengths_and_queries)
99-
100-
start = time.perf_counter()
101-
multi_result = self.api.private_read(self.name, multi_query)
102-
self.exfil = time.perf_counter() - start
150+
return multi_query
103151

104-
retrievals = []
105-
for key, result in zip(keys, _chunk_parser(multi_result)):
152+
def _unpack_query_result(
153+
self, keys: list[str], raw_result: bytes, parse_metadata: bool = True
154+
) -> list[bytes]:
155+
retrievals = []
156+
for key, result in zip(keys, _chunk_parser(raw_result)):
106157
decrypted_result = self.lib.decode_response(result)
107158
decompressed_result = bz2.decompress(decrypted_result)
108159
extracted_result = self.lib.extract_result(key, decompressed_result)
109-
output = serializer.deserialize(extracted_result)
160+
if parse_metadata:
161+
output = serializer.deserialize(extracted_result)
162+
else:
163+
output = extracted_result
110164
retrievals.append(output)
165+
return retrievals
166+
167+
def _private_read(self, keys: list[str]) -> list[bytes]:
168+
"""Performs the underlying private retrieval.
169+
170+
Args:
171+
keys (str): A list of keys to retrieve.
172+
173+
Returns:
174+
tuple[bytes, Optional[dict]]: Returns a tuple of (value, optional_metadata).
175+
"""
176+
if not self.public_uuid or not self._check(self.public_uuid):
177+
self.setup()
178+
assert self.public_uuid
179+
180+
multi_query = self._generate_query_stream(keys)
181+
182+
start = time.perf_counter()
183+
multi_result = self.api.private_read(self.name, multi_query)
184+
self.exfil = time.perf_counter() - start
185+
186+
retrievals = self._unpack_query_result(keys, multi_result)
111187

112188
return retrievals
113189

@@ -184,11 +260,11 @@ def private_read(self, keys: Union[str, list[str]]) -> Union[bytes, list[bytes]]
184260
185261
Args:
186262
keys (str): A key or list of keys to privately read.
187-
If a list of keys is supplied,
263+
If a list of keys is supplied,
188264
results will be returned in the same order.
189265
190266
Returns:
191-
bytes: The value found for the key in the bucket,
267+
bytes: The value found for the key in the bucket,
192268
or None if the key was not found.
193269
"""
194270
single_query = False
@@ -202,7 +278,6 @@ def private_read(self, keys: Union[str, list[str]]) -> Union[bytes, list[bytes]]
202278

203279
return results
204280

205-
206281
def private_key_intersect(self, keys: list[str]) -> list[str]:
207282
"""Privately intersects the given set of keys with the keys in this bucket,
208283
returning the keys that intersected. This is generally slower than a single
@@ -217,3 +292,36 @@ def private_key_intersect(self, keys: list[str]) -> list[str]:
217292
bloom_filter = self.api.bloom(self.name)
218293
present_keys = list(filter(bloom_filter.lookup, keys))
219294
return present_keys
295+
296+
297+
class AsyncBucket(Bucket):
298+
def __init__(self, *args, **kwargs):
299+
super().__init__(*args, **kwargs)
300+
301+
async def write(self, kv_pairs: dict[str, bytes], MAX_CONCURRENCY=8):
302+
# Split the key-value pairs into chunks not exceeding max payload size.
303+
kv_chunks = self._split_into_chunks(kv_pairs)
304+
# Make one write call per chunk, while respecting a max concurrency limit.
305+
sem = asyncio.Semaphore(MAX_CONCURRENCY)
306+
307+
async def _paced_writer(chunk):
308+
async with sem:
309+
await self.api.async_write(self.name, json.dumps(chunk))
310+
311+
_tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
312+
await asyncio.gather(*_tasks)
313+
314+
async def private_read(self, keys: list[str]) -> list[bytes]:
315+
if not self.public_uuid or not await self._async_check(self.public_uuid):
316+
self.setup()
317+
assert self.public_uuid
318+
319+
multi_query = self._generate_query_stream(keys)
320+
321+
start = time.perf_counter()
322+
multi_result = await self.api.async_private_read(self.name, multi_query)
323+
self.exfil = time.perf_counter() - start
324+
325+
retrievals = self._unpack_query_result(keys, multi_result, parse_metadata=False)
326+
327+
return retrievals

python/blyss/bucket_service.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __init__(self, api_config: Union[str, ApiConfig]):
3232
self.api = api.API(self.api_config["api_key"], self.service_endpoint)
3333

3434
def connect(
35-
self, bucket_name: str, secret_seed: Optional[str] = None
35+
self,
36+
bucket_name: str,
37+
secret_seed: Optional[str] = None,
3638
) -> bucket.Bucket:
3739
"""Connect to an existing Blyss bucket.
3840
@@ -47,10 +49,22 @@ def connect(
4749
"""
4850
if secret_seed is None:
4951
secret_seed = seed.get_random_seed()
50-
5152
b = bucket.Bucket(self.api, bucket_name, secret_seed=secret_seed)
5253
return b
5354

55+
def connect_async(
56+
self, bucket_name: str, secret_seed: Optional[str] = None
57+
) -> bucket.AsyncBucket:
58+
"""Connect to an existing Blyss bucket, using an asyncio-ready interface.
59+
60+
Args:
61+
see connect()
62+
63+
Returns:
64+
bucket.AsyncBucket: An object representing an asyncio-ready client to the Blyss bucket.
65+
"""
66+
return bucket.AsyncBucket(self.api, bucket_name, secret_seed=secret_seed)
67+
5468
def create(
5569
self,
5670
bucket_name: str,

python/pyproject.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ build-backend = "maturin"
44

55
[project]
66
name = "blyss"
7-
requires-python = ">=3.7"
7+
requires-python = ">=3.8"
88
classifiers = [
99
"Programming Language :: Rust",
1010
"Programming Language :: Python :: Implementation :: CPython",
1111
"Programming Language :: Python :: Implementation :: PyPy",
1212
]
13-
14-
13+
dependencies = [
14+
"requests",
15+
"httpx",
16+
]
17+
dynamic = ["version"]

0 commit comments

Comments (0)