Skip to content

refactor!: Refactor AlloyDBVectorStore to depend on PGVectorstore #435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/vector_store.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@
"await custom_store.aadd_texts(all_texts, metadatas=metadatas, ids=ids)\n",
"\n",
"# Use filter on search\n",
"docs = await custom_store.asimilarity_search(query, filter=\"len >= 6\")\n",
"docs = await custom_store.asimilarity_search(query, filter={\"len\": {\"$gte\": 6}})\n",
"\n",
"print(docs)"
]
Expand Down Expand Up @@ -774,7 +774,7 @@
"source": [
"import uuid\n",
"\n",
"docs = await custom_store.asimilarity_search(query, filter=\"price_usd > 100\")\n",
"docs = await custom_store.asimilarity_search(query, filter={\"price_usd\": {\"$gte\": 100}})\n",
"\n",
"print(docs)"
]
Expand Down
1,226 changes: 11 additions & 1,215 deletions src/langchain_google_alloydb_pg/async_vectorstore.py

Large diffs are not rendered by default.

751 changes: 31 additions & 720 deletions src/langchain_google_alloydb_pg/vectorstore.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions tests/test_async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())
CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4())
IMAGE_TABLE = "test_image_table" + str(uuid.uuid4())
CUSTOM_TABLE = "custom" + str(uuid.uuid4())
IMAGE_TABLE = "image" + str(uuid.uuid4())
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -111,6 +111,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name):
yield engine
await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{IMAGE_TABLE}"')
await engine.close()

@pytest_asyncio.fixture(scope="class")
Expand Down
22 changes: 20 additions & 2 deletions tests/test_async_vectorstore_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
DistanceStrategy,
HNSWIndex,
IVFFlatIndex,
IVFIndex,
)

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX
UUID_STR = str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "test_table" + UUID_STR
DEFAULT_INDEX_NAME = DEFAULT_INDEX_NAME_SUFFIX + UUID_STR
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -109,6 +111,22 @@ async def vs(self, engine):
await vs.adrop_vector_index()
yield vs

async def test_aapply_vector_index_ivf(self, vs):
index = IVFIndex(
name=DEFAULT_INDEX_NAME,
distance_strategy=DistanceStrategy.EUCLIDEAN,
)
await vs.aapply_vector_index(index, concurrently=True)
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
index = IVFIndex(
name="secondindex",
distance_strategy=DistanceStrategy.INNER_PRODUCT,
)
await vs.aapply_vector_index(index)
assert await vs.is_valid_index("secondindex")
await vs.adrop_vector_index("secondindex")
await vs.adrop_vector_index()

async def test_aapply_vector_index(self, vs):
index = HNSWIndex()
await vs.aapply_vector_index(index)
Expand Down
20 changes: 8 additions & 12 deletions tests/test_async_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_")
IMAGE_TABLE = "image" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_")
VECTOR_SIZE = 768
sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."

Expand Down Expand Up @@ -118,6 +118,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name):
await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}")
await aexecute(engine, f"DROP TABLE IF EXISTS {IMAGE_TABLE}")
await engine.close()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -248,15 +249,15 @@ async def test_asimilarity_search(self, vs):
results = await vs.asimilarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'")
results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"})
assert results == [Document(page_content="bar", id=ids[1])]

async def test_asimilarity_search_scann(self, vs_custom_scann_query_option):
results = await vs_custom_scann_query_option.asimilarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs_custom_scann_query_option.asimilarity_search(
"foo", k=1, filter="mycontent = 'bar'"
"foo", k=1, filter={"mycontent": "bar"}
)
assert results == [Document(page_content="bar", id=ids[1])]

Expand Down Expand Up @@ -333,7 +334,7 @@ async def test_amax_marginal_relevance_search(self, vs):
results = await vs.amax_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = await vs.amax_marginal_relevance_search(
"bar", filter="content = 'boo'"
"bar", filter={"content": "boo"}
)
assert results[0] == Document(page_content="boo", id=ids[3])

Expand All @@ -359,7 +360,7 @@ async def test_similarity_search(self, vs_custom):
assert len(results) == 1
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs_custom.asimilarity_search(
"foo", k=1, filter="mycontent = 'bar'"
"foo", k=1, filter={"mycontent": "bar"}
)
assert results == [Document(page_content="bar", id=ids[1])]

Expand All @@ -386,7 +387,7 @@ async def test_max_marginal_relevance_search(self, vs_custom):
results = await vs_custom.amax_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = await vs_custom.amax_marginal_relevance_search(
"bar", filter="mycontent = 'boo'"
"bar", filter={"mycontent": "boo"}
)
assert results[0] == Document(page_content="boo", id=ids[3])

Expand Down Expand Up @@ -419,11 +420,6 @@ async def test_aget_by_ids_custom_vs(self, vs_custom):

assert results[0] == Document(page_content="foo", id=ids[0])

def test_get_by_ids(self, vs):
test_ids = [ids[0]]
with pytest.raises(Exception, match=sync_method_exception_str):
vs.get_by_ids(ids=test_ids)

@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
async def test_vectorstore_with_metadata_filters(
self,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_standard_test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column

DEFAULT_TABLE = "test_table_standard_test_suite" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync_standard_test_suite" + str(uuid.uuid4())
DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())


def get_env_var(key: str, desc: str) -> str:
Expand Down
22 changes: 9 additions & 13 deletions tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@

DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())
CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4())
IMAGE_TABLE = "test_image_table" + str(uuid.uuid4())
IMAGE_TABLE_SYNC = "test_image_table_sync" + str(uuid.uuid4())
CUSTOM_TABLE = "custom" + str(uuid.uuid4())
IMAGE_TABLE = "image" + str(uuid.uuid4())
IMAGE_TABLE_SYNC = "image_sync" + str(uuid.uuid4())
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -387,20 +387,18 @@ async def test_aadd_images_store_uri_only(self, engine_sync, image_uris):
results = await afetch(engine_sync, f'SELECT * FROM "{table_name}"')
assert len(results) == len(image_uris)
for i, result_row in enumerate(results):
assert (
result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i]
)
assert result_row[vs._PGVectorStore__vs.content_column] == image_uris[i]
uri_embedding = embeddings_service.embed_query(image_uris[i])
image_embedding = image_embedding_service.embed_image([image_uris[i]])[0]
actual_embedding = json.loads(
result_row[vs._AlloyDBVectorStore__vs.embedding_column]
result_row[vs._PGVectorStore__vs.embedding_column]
)
assert actual_embedding != pytest.approx(uri_embedding)
assert actual_embedding == pytest.approx(image_embedding)
assert result_row["image_id"] == str(i)
assert result_row["source"] == "google.com"
assert (
result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"]
result_row[vs._PGVectorStore__vs.metadata_json_column]["image_uri"]
== image_uris[i]
)
await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{table_name}"')
Expand Down Expand Up @@ -475,20 +473,18 @@ async def test_add_images_store_uri_only(self, engine_sync, image_uris):
results = await afetch(engine_sync, (f'SELECT * FROM "{table_name}"'))
assert len(results) == len(image_uris)
for i, result_row in enumerate(results):
assert (
result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i]
)
assert result_row[vs._PGVectorStore__vs.content_column] == image_uris[i]
uri_embedding = embeddings_service.embed_query(image_uris[i])
image_embedding = image_embedding_service.embed_image([image_uris[i]])[0]
actual_embedding = json.loads(
result_row[vs._AlloyDBVectorStore__vs.embedding_column]
result_row[vs._PGVectorStore__vs.embedding_column]
)
assert actual_embedding != pytest.approx(uri_embedding)
assert actual_embedding == pytest.approx(image_embedding)
assert result_row["image_id"] == str(i)
assert result_row["source"] == "google.com"
assert (
result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"]
result_row[vs._PGVectorStore__vs.metadata_json_column]["image_uri"]
== image_uris[i]
)
await vs.adelete(ids)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_vectorstore_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def test_asimilarity_search(self, vs):
results = await vs.asimilarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'")
results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"})
assert results == [Document(page_content="bar", id=ids[1])]

async def test_asimilarity_search_score(self, vs):
Expand Down Expand Up @@ -242,7 +242,7 @@ async def test_amax_marginal_relevance_search(self, vs):
results = await vs.amax_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = await vs.amax_marginal_relevance_search(
"bar", filter="content = 'boo'"
"bar", filter={"content": "boo"}
)
assert results[0] == Document(page_content="boo", id=ids[3])

Expand Down Expand Up @@ -342,8 +342,8 @@ def test_similarity_search(self, vs_custom):
results = vs_custom.similarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo", id=ids[0])]
results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'")
assert results == [Document(page_content="bar", id=ids[1])]
results = vs_custom.similarity_search("foo", k=1, filter={"mycontent": "boo"})
assert results == [Document(page_content="boo", id=ids[3])]

def test_similarity_search_score(self, vs_custom):
results = vs_custom.similarity_search_with_score("foo")
Expand All @@ -364,7 +364,7 @@ def test_max_marginal_relevance_search(self, vs_custom):
results = vs_custom.max_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = vs_custom.max_marginal_relevance_search(
"bar", filter="mycontent = 'boo'"
"bar", filter={"mycontent": "boo"}
)
assert results[0] == Document(page_content="boo", id=ids[3])

Expand Down
43 changes: 32 additions & 11 deletions tests/test_vectorstore_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
ScaNNIndex,
)

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_OMNI = "test_table" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX
DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX
DEFAULT_INDEX_NAME_OMNI = DEFAULT_TABLE_OMNI + DEFAULT_INDEX_NAME_SUFFIX
DEFAULT_TABLE_UUID = str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_ASYNC_UUID = str(uuid.uuid4()).replace("-", "_")
OMNI_UUID = str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "test_table" + DEFAULT_TABLE_UUID
DEFAULT_TABLE_ASYNC = "test_table" + DEFAULT_TABLE_ASYNC_UUID
DEFAULT_TABLE_OMNI = "test_table" + OMNI_UUID
DEFAULT_INDEX_NAME = DEFAULT_INDEX_NAME_SUFFIX + DEFAULT_TABLE_UUID
DEFAULT_INDEX_NAME_ASYNC = DEFAULT_INDEX_NAME_SUFFIX + DEFAULT_TABLE_ASYNC_UUID
DEFAULT_INDEX_NAME_OMNI = DEFAULT_INDEX_NAME_SUFFIX + OMNI_UUID
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -159,6 +161,21 @@ async def test_is_valid_index(self, vs):
is_valid = vs.is_valid_index("invalid_index")
assert is_valid == False

async def test_aapply_vector_index_ivf(self, vs):
index = IVFIndex(
name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN
)
vs.apply_vector_index(index, concurrently=True)
assert vs.is_valid_index(DEFAULT_INDEX_NAME)
index = IVFIndex(
name="secondindex",
distance_strategy=DistanceStrategy.INNER_PRODUCT,
)
vs.apply_vector_index(index)
assert vs.is_valid_index("secondindex")
vs.drop_vector_index("secondindex")
vs.drop_vector_index(DEFAULT_INDEX_NAME)


@pytest.mark.asyncio(loop_scope="class")
class TestAsyncIndex:
Expand Down Expand Up @@ -284,7 +301,9 @@ async def test_is_valid_index(self, vs):
assert is_valid == False

async def test_aapply_vector_index_ivf(self, vs):
index = IVFIndex(distance_strategy=DistanceStrategy.EUCLIDEAN)
index = IVFIndex(
name=DEFAULT_INDEX_NAME_ASYNC, distance_strategy=DistanceStrategy.EUCLIDEAN
)
await vs.aapply_vector_index(index, concurrently=True)
assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC)
index = IVFIndex(
Expand All @@ -294,10 +313,12 @@ async def test_aapply_vector_index_ivf(self, vs):
await vs.aapply_vector_index(index)
assert await vs.ais_valid_index("secondindex")
await vs.adrop_vector_index("secondindex")
await vs.adrop_vector_index()
await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC)

async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs):
index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN)
index = ScaNNIndex(
name=DEFAULT_INDEX_NAME_OMNI, distance_strategy=DistanceStrategy.EUCLIDEAN
)
await omni_vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE)
await omni_vs.aapply_vector_index(index, concurrently=True)
assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_OMNI)
Expand All @@ -307,4 +328,4 @@ async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs):
await omni_vs.aapply_vector_index(index)
assert await omni_vs.ais_valid_index("secondindex")
await omni_vs.adrop_vector_index("secondindex")
await omni_vs.adrop_vector_index()
await omni_vs.adrop_vector_index(DEFAULT_INDEX_NAME_OMNI)
Loading