Skip to content

Commit 919ec27

Browse files
committed
clean up storage api
1 parent 8ea4447 commit 919ec27

File tree

3 files changed

+70
-76
lines changed

3 files changed

+70
-76
lines changed

pymongo_voyageai/client.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import io
34
import logging
45
from collections.abc import Mapping, Sequence
56
from time import monotonic, sleep
@@ -206,7 +207,19 @@ def image_to_storage(self, document: ImageDocument | Image.Image) -> StoredDocum
206207
"""
207208
if isinstance(document, Image.Image):
208209
document = ImageDocument(image=document)
209-
return self._storage.save_image(document)
210+
object_name = f"{ObjectId()}.png"
211+
fd = io.BytesIO()
212+
document.image.save(fd, "png")
213+
fd.seek(0)
214+
self._storage.save_data(fd, object_name)
215+
return StoredDocument(
216+
root_location=self._storage.root_location,
217+
object_name=object_name,
218+
page_number=document.page_number,
219+
source_url=document.source_url,
220+
name=document.name,
221+
metadata=document.metadata,
222+
)
210223

211224
async def aimage_to_storage(self, document: ImageDocument | Image.Image) -> StoredDocument:
212225
"""Convert an image to a stored document.
@@ -232,7 +245,15 @@ def storage_to_image(self, document: StoredDocument | str) -> ImageDocument:
232245
document = StoredDocument(
233246
root_location=self._storage.root_location, object_name=document
234247
)
235-
return self._storage.load_image(document=document)
248+
buffer = self._storage.read_data(document.object_name)
249+
image = Image.open(buffer)
250+
return ImageDocument(
251+
image=image,
252+
source_url=document.source_url,
253+
page_number=document.page_number,
254+
metadata=document.metadata,
255+
name=document.name,
256+
)
236257

237258
async def astorage_to_image(self, document: StoredDocument | str) -> ImageDocument:
238259
"""Convert a stored document to an image document.
@@ -470,7 +491,7 @@ def delete_many(
470491
self._expand_doc(obj, False)
471492
for inp in obj["inputs"]:
472493
if isinstance(inp, StoredDocument):
473-
self._storage.delete_image(inp)
494+
self._storage.delete_data(inp.object_name)
474495
return self._coll.delete_many(filter=filter, **kwargs).acknowledged
475496

476497
async def adelete_many(

pymongo_voyageai/storage.py

Lines changed: 44 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,36 @@
22

33
import boto3 # type:ignore[import-untyped]
44
import botocore # type:ignore[import-untyped]
5-
from bson import ObjectId
6-
from PIL import Image
7-
8-
from .document import ImageDocument, StoredDocument
95

106

117
class ObjectStorage:
12-
"""A class used store image documents."""
8+
"""A class used to store binary data."""
139

1410
root_location: str
15-
"""The root location to use in the object store."""
11+
"""The default root location to use in the object store."""
1612

1713
url_prefixes: list[str] | None
1814
"""The url prefixes used by the object store, for reading data from a url."""
1915

20-
def save_image(self, image: ImageDocument) -> StoredDocument:
21-
"""Save an image document to the object store."""
16+
def save_data(self, data: io.BytesIO, object_name: str) -> None:
17+
"""Save data to the object store."""
2218
raise NotImplementedError
2319

24-
def load_image(self, document: StoredDocument) -> ImageDocument:
25-
"""Load an image document from the object store."""
20+
def read_data(self, object_name: str) -> io.BytesIO:
21+
"""Read data from the object store."""
2622
raise NotImplementedError
2723

28-
def read_from_url(self, url: str) -> io.BytesIO:
29-
"""Read data from a url into a BytesIO object."""
24+
def load_url(self, url: str) -> io.BytesIO:
25+
"""Load data from a url."""
3026
raise NotImplementedError
3127

32-
def delete_image(self, document: StoredDocument) -> None:
33-
"""Remove an image document from the object store."""
28+
def delete_data(self, object_name: str) -> None:
29+
"""Delete data from the object store."""
3430
raise NotImplementedError
3531

36-
def close(self) -> None:
32+
def close(self):
3733
"""Close the object store."""
38-
raise NotImplementedError
34+
pass
3935

4036

4137
class S3Storage(ObjectStorage):
@@ -59,41 +55,26 @@ def __init__(
5955
self.client = client or boto3.client("s3", region_name=region_name)
6056
self.root_location = bucket_name
6157

62-
def save_image(self, image: ImageDocument) -> StoredDocument:
63-
object_name = f"{ObjectId()}.png"
64-
fd = io.BytesIO()
65-
image.image.save(fd, "png")
66-
fd.seek(0)
67-
self.client.upload_fileobj(fd, self.root_location, object_name)
68-
return StoredDocument(
69-
root_location=self.root_location,
70-
object_name=object_name,
71-
page_number=image.page_number,
72-
source_url=image.source_url,
73-
name=image.name,
74-
metadata=image.metadata,
75-
)
76-
77-
def load_image(self, document: StoredDocument) -> ImageDocument:
58+
def save_data(self, data: io.BytesIO, object_name: str) -> None:
59+
"""Save data to the object store."""
60+
self.client.upload_fileobj(data, self.root_location, object_name)
61+
62+
def read_data(self, object_name: str) -> io.BytesIO:
63+
"""Read data using the object store."""
7864
buffer = io.BytesIO()
79-
self.client.download_fileobj(document.root_location, document.object_name, buffer)
80-
image = Image.open(buffer)
81-
return ImageDocument(
82-
image=image,
83-
source_url=document.source_url,
84-
page_number=document.page_number,
85-
metadata=document.metadata,
86-
name=document.name,
87-
)
88-
89-
def read_from_url(self, url: str) -> io.BytesIO:
90-
bucket, key = url.replace("s3://", "").split("/")
65+
self.client.download_fileobj(self.root_location, object_name, buffer)
66+
return buffer
67+
68+
def load_url(self, url: str) -> io.BytesIO:
69+
"""Load data from a url."""
70+
bucket, _, object_name = url.replace("s3://", "").partition("/")
9171
buffer = io.BytesIO()
92-
self.client.download_fileobj(bucket, key, buffer)
72+
self.client.download_fileobj(bucket, object_name, buffer)
9373
return buffer
9474

95-
def delete_image(self, document: StoredDocument) -> None:
96-
self.client.delete_object(Bucket=document.root_location, Key=document.object_name)
75+
def delete_data(self, object_name: str) -> None:
76+
"""Delete data from the object store."""
77+
self.client.delete_object(Bucket=self.root_location, Key=object_name)
9778

9879
def close(self) -> None:
9980
self.client.close()
@@ -106,29 +87,21 @@ class MemoryStorage(ObjectStorage):
10687

10788
def __init__(self) -> None:
10889
self.root_location = "foo"
109-
self.storage: dict[str, ImageDocument] = dict()
110-
111-
def save_image(self, image: ImageDocument) -> StoredDocument:
112-
object_name = str(ObjectId())
113-
self.storage[object_name] = image
114-
return StoredDocument(
115-
root_location=self.root_location,
116-
name=image.name,
117-
object_name=object_name,
118-
source_url=image.source_url,
119-
page_number=image.page_number,
120-
)
121-
122-
def load_image(self, document: StoredDocument) -> ImageDocument:
123-
return self.storage[document.object_name]
124-
125-
def read_from_url(self, url: str) -> io.BytesIO:
126-
with open(url.replace("file://", ""), "rb") as fid:
127-
data = fid.read()
128-
return io.BytesIO(data)
90+
self.storage: dict[str, io.BytesIO] = dict()
12991

130-
def delete_image(self, document: StoredDocument) -> None:
131-
self.storage.pop(document.object_name, None)
92+
def save_data(self, data: io.BytesIO, object_name: str) -> None:
93+
"""Save data to the object store."""
94+
self.storage[object_name] = data
13295

133-
def close(self):
134-
pass
96+
def read_data(self, object_name: str) -> io.BytesIO:
97+
"""Read data using the object store."""
98+
return self.storage[object_name]
99+
100+
def load_url(self, url: str) -> io.BytesIO:
101+
"""Load data from a url."""
102+
with open(url.replace("file://", ""), "rb") as fid:
103+
return io.BytesIO(fid.read())
104+
105+
def delete_data(self, object_name: str) -> None:
106+
"""Delete data from the object store."""
107+
self.storage.pop(object_name, None)

pymongo_voyageai/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ def url_to_images(
9292
if storage and storage.url_prefixes:
9393
for pattern in storage.url_prefixes:
9494
if url.startswith(pattern):
95-
source = storage.read_from_url(url)
95+
source = storage.load_url(url)
9696
break
9797
# For parquet files that are not loaded by the storage object, let pandas handle the download.
9898
if source is None and url.endswith(".parquet"):
9999
source = url
100100
# For s3 files that are not loaded by the storage object, create a temp S3Storage object.
101101
if source is None and url.startswith("s3://"):
102102
storage = S3Storage("")
103-
source = storage.read_from_url(url)
103+
source = storage.load_url(url)
104104
storage.close()
105105
# For all other files, use the native download.
106106
if source is None:

0 commit comments

Comments
 (0)