Skip to content

Commit cf2b552

Browse files
feat(aadd_images): add support for store_uri_only parameter (#428)
* feat(aadd_images): add support for store_uri_only parameter * chore: balck reformating * Delete get-pip.py * fix: tests for store_uri_only --------- Co-authored-by: Vishwaraj Anand <[email protected]>
1 parent 762698e commit cf2b552

File tree

4 files changed

+148
-12
lines changed

4 files changed

+148
-12
lines changed

src/langchain_google_alloydb_pg/async_vectorstore.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -397,29 +397,42 @@ async def aadd_images(
397397
uris: list[str],
398398
metadatas: Optional[list[dict]] = None,
399399
ids: Optional[list[str]] = None,
400+
store_uri_only: bool = False,
400401
**kwargs: Any,
401402
) -> list[str]:
402403
"""Embed images and add to the table.
403404
404405
Args:
405-
uris (list[str]): List of local image URIs to add to the table.
406+
uris (list[str]): List of image URIs to add to the table.
406407
metadatas (Optional[list[dict]]): List of metadatas to add to table records.
407408
ids: (Optional[list[str]]): List of IDs to add to table records.
409+
store_uri_only (bool): If True, stores the URI in the content column
410+
instead of the base64 encoded image. Defaults to False.
411+
**kwargs: Any other arguments to pass to the embedding service.
408412
409413
Returns:
410414
List of record IDs added.
411415
"""
412-
encoded_images = []
413416
if metadatas is None:
417+
# Ensure URI is always in metadata if not explicitly provided elsewhere
414418
metadatas = [{"image_uri": uri} for uri in uris]
419+
elif store_uri_only:
420+
# If storing URI only and metadatas are provided, ensure image_uri is present
421+
for i, m in enumerate(metadatas):
422+
if "image_uri" not in m: # Add if not already provided by user
423+
m["image_uri"] = uris[i]
424+
425+
texts_for_content_column: list[str]
426+
if store_uri_only:
427+
texts_for_content_column = uris
428+
else:
429+
texts_for_content_column = [self._encode_image(uri) for uri in uris]
415430

416-
for uri in uris:
417-
encoded_image = self._encode_image(uri)
418-
encoded_images.append(encoded_image)
419-
431+
# Embeddings are always generated from the actual image content via URIs
420432
embeddings = self._images_embedding_helper(uris)
433+
421434
ids = await self.aadd_embeddings(
422-
encoded_images, embeddings, metadatas=metadatas, ids=ids, **kwargs
435+
texts_for_content_column, embeddings, metadatas=metadatas, ids=ids, **kwargs
423436
)
424437
return ids
425438

src/langchain_google_alloydb_pg/vectorstore.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,14 @@ async def aadd_images(
231231
uris: list[str],
232232
metadatas: Optional[list[dict]] = None,
233233
ids: Optional[list[str]] = None,
234+
store_uri_only: bool = False,
234235
**kwargs: Any,
235236
) -> list[str]:
236237
"""Embed images and add to the table."""
237238
return await self._engine._run_as_async(
238-
self.__vs.aadd_images(uris, metadatas, ids, **kwargs)
239+
self.__vs.aadd_images(
240+
uris, metadatas, ids, store_uri_only=store_uri_only, **kwargs
241+
)
239242
)
240243

241244
def add_embeddings(
@@ -287,11 +290,14 @@ def add_images(
287290
uris: list[str],
288291
metadatas: Optional[list[dict]] = None,
289292
ids: Optional[list[str]] = None,
293+
store_uri_only: bool = False,
290294
**kwargs: Any,
291295
) -> list[str]:
292296
"""Embed images and add to the table."""
293297
return self._engine._run_as_sync(
294-
self.__vs.aadd_images(uris, metadatas, ids, **kwargs)
298+
self.__vs.aadd_images(
299+
uris, metadatas, ids, store_uri_only=store_uri_only, **kwargs
300+
)
295301
)
296302

297303
async def adelete(

tests/test_async_vectorstore.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import os
1617
import uuid
1718
from typing import Sequence
@@ -48,7 +49,7 @@
4849
class FakeImageEmbedding(DeterministicFakeEmbedding):
4950

5051
def embed_image(self, image_paths: list[str]) -> list[list[float]]:
51-
return [self.embed_query(path) for path in image_paths]
52+
return [self.embed_query(f"Image Path: {path}") for path in image_paths]
5253

5354

5455
image_embedding_service = FakeImageEmbedding(size=VECTOR_SIZE)
@@ -154,6 +155,7 @@ async def image_vs(self, engine):
154155
Column("image_id", "TEXT"),
155156
Column("source", "TEXT"),
156157
],
158+
metadata_json_column="mymeta",
157159
)
158160
vs = await AsyncAlloyDBVectorStore.create(
159161
engine,
@@ -278,6 +280,32 @@ async def test_aadd_images(self, engine, image_vs, image_uris):
278280
assert results[0]["source"] == "google.com"
279281
await aexecute(engine, (f'TRUNCATE TABLE "{IMAGE_TABLE}"'))
280282

283+
async def test_aadd_images_store_uri_only(self, engine, image_vs, image_uris):
284+
ids = [str(uuid.uuid4()) for i in range(len(image_uris))]
285+
metadatas = [
286+
{"image_id": str(i), "source": "google.com"} for i in range(len(image_uris))
287+
]
288+
await image_vs.aadd_images(image_uris, metadatas, ids, store_uri_only=True)
289+
results = await afetch(engine, (f'SELECT * FROM "{IMAGE_TABLE}"'))
290+
assert len(results) == len(image_uris)
291+
# Check that content column stores the URI
292+
for i, result_row in enumerate(results):
293+
assert result_row[image_vs.content_column] == image_uris[i]
294+
# Check that embedding is not an embedding of the URI string itself (basic check)
295+
uri_embedding = embeddings_service.embed_query(image_uris[i])
296+
image_embedding = image_embedding_service.embed_image([image_uris[i]])[0]
297+
actual_embedding = json.loads(result_row[image_vs.embedding_column])
298+
assert actual_embedding != pytest.approx(uri_embedding)
299+
assert actual_embedding == pytest.approx(image_embedding)
300+
assert result_row["image_id"] == str(i)
301+
assert result_row["source"] == "google.com"
302+
# Check that the original URI is also in the metadata (json column)
303+
assert (
304+
result_row[image_vs.metadata_json_column]["image_uri"] == image_uris[i]
305+
)
306+
307+
await aexecute(engine, (f'TRUNCATE TABLE "{IMAGE_TABLE}"'))
308+
281309
async def test_adelete(self, engine, vs):
282310
ids = [str(uuid.uuid4()) for i in range(len(texts))]
283311
await vs.aadd_texts(texts, ids=ids)

tests/test_vectorstore.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import json
1617
import os
1718
import uuid
1819
from threading import Thread
@@ -52,7 +53,7 @@
5253
class FakeImageEmbedding(DeterministicFakeEmbedding):
5354

5455
def embed_image(self, image_paths: list[str]) -> list[list[float]]:
55-
return [self.embed_query(path) for path in image_paths]
56+
return [self.embed_query(f"Image Path: {path}") for path in image_paths]
5657

5758

5859
image_embedding_service = FakeImageEmbedding(size=VECTOR_SIZE)
@@ -357,7 +358,52 @@ async def test_aadd_images(self, engine_sync, image_uris):
357358
assert len(results) == len(image_uris)
358359
assert results[0]["image_id"] == "0"
359360
assert results[0]["source"] == "google.com"
360-
await aexecute(engine_sync, f'TRUNCATE TABLE "{IMAGE_TABLE}"')
361+
await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{IMAGE_TABLE}"')
362+
363+
async def test_aadd_images_store_uri_only(self, engine_sync, image_uris):
364+
table_name = IMAGE_TABLE_SYNC + "_store_uri_only"
365+
engine_sync.init_vectorstore_table(
366+
table_name,
367+
VECTOR_SIZE,
368+
metadata_columns=[
369+
Column("image_id", "TEXT"),
370+
Column("source", "TEXT"),
371+
],
372+
metadata_json_column="mymeta",
373+
)
374+
vs = AlloyDBVectorStore.create_sync(
375+
engine_sync,
376+
embedding_service=image_embedding_service,
377+
table_name=table_name,
378+
metadata_columns=["image_id", "source"],
379+
metadata_json_column="mymeta",
380+
)
381+
ids = [str(uuid.uuid4()) for i in range(len(image_uris))]
382+
metadatas = [
383+
{"image_id": str(i), "source": "google.com"} for i in range(len(image_uris))
384+
]
385+
# Test the async method on the sync class
386+
await vs.aadd_images(image_uris, metadatas, ids, store_uri_only=True)
387+
results = await afetch(engine_sync, f'SELECT * FROM "{table_name}"')
388+
assert len(results) == len(image_uris)
389+
for i, result_row in enumerate(results):
390+
assert (
391+
result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i]
392+
)
393+
uri_embedding = embeddings_service.embed_query(image_uris[i])
394+
image_embedding = image_embedding_service.embed_image([image_uris[i]])[0]
395+
actual_embedding = json.loads(
396+
result_row[vs._AlloyDBVectorStore__vs.embedding_column]
397+
)
398+
assert actual_embedding != pytest.approx(uri_embedding)
399+
assert actual_embedding == pytest.approx(image_embedding)
400+
assert result_row["image_id"] == str(i)
401+
assert result_row["source"] == "google.com"
402+
assert (
403+
result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"]
404+
== image_uris[i]
405+
)
406+
await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{table_name}"')
361407

362408
async def test_adelete_custom(self, engine, vs_custom):
363409
ids = [str(uuid.uuid4()) for i in range(len(texts))]
@@ -405,6 +451,49 @@ async def test_add_images(self, engine_sync, image_uris):
405451
await vs.adelete(ids)
406452
await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{IMAGE_TABLE_SYNC}"')
407453

454+
async def test_add_images_store_uri_only(self, engine_sync, image_uris):
455+
table_name = IMAGE_TABLE_SYNC + "_store_uri_only"
456+
engine_sync.init_vectorstore_table(
457+
table_name,
458+
VECTOR_SIZE,
459+
metadata_columns=[Column("image_id", "TEXT"), Column("source", "TEXT")],
460+
metadata_json_column="mymeta",
461+
)
462+
vs = AlloyDBVectorStore.create_sync(
463+
engine_sync,
464+
embedding_service=image_embedding_service,
465+
table_name=table_name,
466+
metadata_columns=["image_id", "source"],
467+
metadata_json_column="mymeta",
468+
)
469+
470+
ids = [str(uuid.uuid4()) for i in range(len(image_uris))]
471+
metadatas = [
472+
{"image_id": str(i), "source": "google.com"} for i in range(len(image_uris))
473+
]
474+
vs.add_images(image_uris, metadatas, ids, store_uri_only=True)
475+
results = await afetch(engine_sync, (f'SELECT * FROM "{table_name}"'))
476+
assert len(results) == len(image_uris)
477+
for i, result_row in enumerate(results):
478+
assert (
479+
result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i]
480+
)
481+
uri_embedding = embeddings_service.embed_query(image_uris[i])
482+
image_embedding = image_embedding_service.embed_image([image_uris[i]])[0]
483+
actual_embedding = json.loads(
484+
result_row[vs._AlloyDBVectorStore__vs.embedding_column]
485+
)
486+
assert actual_embedding != pytest.approx(uri_embedding)
487+
assert actual_embedding == pytest.approx(image_embedding)
488+
assert result_row["image_id"] == str(i)
489+
assert result_row["source"] == "google.com"
490+
assert (
491+
result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"]
492+
== image_uris[i]
493+
)
494+
await vs.adelete(ids)
495+
await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{table_name}"')
496+
408497
async def test_cross_env(self, engine_sync, vs_sync):
409498
ids = [str(uuid.uuid4()) for i in range(len(texts))]
410499
await vs_sync.aadd_texts(texts, ids=ids)

0 commit comments

Comments
 (0)