From cd80093ed88519f65bfc5cb4545cc23aa0c25efc Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 08:28:39 +0000 Subject: [PATCH 01/12] chore: Refactor AlloyDBEngine to depend on PGEngine --- pyproject.toml | 3 +- requirements.txt | 1 + .../async_vectorstore.py | 11 +- src/langchain_google_alloydb_pg/engine.py | 273 ++---------------- src/langchain_google_alloydb_pg/indexes.py | 116 +------- .../vectorstore.py | 2 +- tests/test_engine.py | 22 ++ tests/test_indexes.py | 40 --- 8 files changed, 52 insertions(+), 416 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 21e9808c..7716147d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ dependencies = [ "numpy>=1.24.4, <=2.2.6; python_version == '3.10'", "numpy>=1.24.4, <=2.0.2; python_version <= '3.9'", "pgvector>=0.2.5, <1.0.0", - "SQLAlchemy[asyncio]>=2.0.25, <3.0.0" + "SQLAlchemy[asyncio]>=2.0.25, <3.0.0", + "langchain-postgres>=0.0.15", ] classifiers = [ diff --git a/requirements.txt b/requirements.txt index 4fb24b63..0748c427 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ numpy==2.0.2; python_version <= "3.9" pgvector==0.4.1 SQLAlchemy[asyncio]==2.0.41 langgraph==0.5.0 +langchain-postgres==0.0.15 \ No newline at end of file diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index e44393df..ef25658a 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -28,19 +28,18 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore, utils -from sqlalchemy import RowMapping, text -from sqlalchemy.ext.asyncio import AsyncEngine - -from .engine import AlloyDBEngine -from .indexes import ( +from langchain_postgres.v2.indexes import ( DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, BaseIndex, DistanceStrategy, ExactNearestNeighbor, QueryOptions, - ScaNNIndex, ) +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import AlloyDBEngine COMPARISONS_TO_NATIVE = { "$eq": "=", diff --git a/src/langchain_google_alloydb_pg/engine.py b/src/langchain_google_alloydb_pg/engine.py index 48799baf..840b2f3a 100644 --- a/src/langchain_google_alloydb_pg/engine.py +++ b/src/langchain_google_alloydb_pg/engine.py @@ -15,12 +15,10 @@ import asyncio from concurrent.futures import Future -from dataclasses import dataclass from threading import Thread from typing import ( TYPE_CHECKING, Any, - Awaitable, Mapping, Optional, TypeVar, @@ -35,10 +33,11 @@ IPTypes, RefreshStrategy, ) -from sqlalchemy import MetaData, RowMapping, Table, text +from langchain_postgres import Column, PGEngine +from sqlalchemy import MetaData, Table, text from sqlalchemy.engine import URL from sqlalchemy.exc import InvalidRequestError -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy.ext.asyncio import create_async_engine from .version import __version__ @@ -90,60 +89,10 @@ async def _get_iam_principal_email( return email.replace(".gserviceaccount.com", "") -@dataclass -class Column: - name: str - data_type: str - nullable: bool = True - - def __post_init__(self) -> None: - """Check if initialization parameters are valid. - - Raises: - ValueError: If Column name is not string. - ValueError: If data_type is not type string. - """ - - if not isinstance(self.name, str): - raise ValueError("Column name must be type string") - if not isinstance(self.data_type, str): - raise ValueError("Column data_type must be type string") - - -class AlloyDBEngine: +class AlloyDBEngine(PGEngine): """A class for managing connections to a AlloyDB database.""" _connector: Optional[AsyncConnector] = None - _default_loop: Optional[asyncio.AbstractEventLoop] = None - _default_thread: Optional[Thread] = None - __create_key = object() - - def __init__( - self, - key: object, - pool: AsyncEngine, - loop: Optional[asyncio.AbstractEventLoop], - thread: Optional[Thread], - ) -> None: - """AlloyDBEngine constructor. - - Args: - key (object): Prevent direct constructor usage. - engine (AsyncEngine): Async engine connection pool. - loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. - thread (Optional[Thread]): Thread used to create the engine async. - - Raises: - Exception: If the constructor is called directly by the user. - """ - - if key != AlloyDBEngine.__create_key: - raise Exception( - "Only create class through 'create' or 'create_sync' methods!" - ) - self._pool = pool - self._loop = loop - self._thread = thread @classmethod def __start_background_loop( @@ -317,7 +266,7 @@ async def getconn() -> asyncpg.Connection: async_creator=getconn, **engine_args, ) - return cls(cls.__create_key, engine, loop, thread) + return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore @classmethod async def afrom_instance( @@ -367,13 +316,21 @@ async def afrom_instance( return await asyncio.wrap_future(future) @classmethod - def from_engine( - cls: type[AlloyDBEngine], - engine: AsyncEngine, - loop: Optional[asyncio.AbstractEventLoop] = None, + def from_connection_string( + cls, + url: str | URL, + **kwargs: Any, ) -> AlloyDBEngine: - """Create an AlloyDBEngine instance from an AsyncEngine.""" - return cls(cls.__create_key, engine, loop, None) + """Create an AlloyDBEngine instance from arguments + Args: + url (Optional[str]): the URL used to connect to a database. Use url or set other arguments. + Raises: + ValueError: If not all database url arguments are specified + Returns: + AlloyDBEngine + """ + + return AlloyDBEngine.from_engine_args(url=url, **kwargs) @classmethod def from_engine_args( @@ -408,197 +365,7 @@ def from_engine_args( raise ValueError("Driver must be type 'postgresql+asyncpg'") engine = create_async_engine(url, **kwargs) - return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread) - - async def _run_as_async(self, coro: Awaitable[T]) -> T: - """Run an async coroutine asynchronously""" - # If a loop has not been provided, attempt to run in current thread - if not self._loop: - return await coro - # Otherwise, run in the background thread - return await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(coro, self._loop) - ) - - def _run_as_sync(self, coro: Awaitable[T]) -> T: - """Run an async coroutine synchronously""" - if not self._loop: - raise Exception( - "Engine was initialized without a background loop and cannot call sync methods." - ) - return asyncio.run_coroutine_threadsafe(coro, self._loop).result() - - async def close(self) -> None: - """Dispose of connection pool""" - await self._run_as_async(self._pool.dispose()) - - async def _ainit_vectorstore_table( - self, - table_name: str, - vector_size: int, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[Column] = [], - metadata_json_column: str = "langchain_metadata", - id_column: Union[str, Column] = "langchain_id", - overwrite_existing: bool = False, - store_metadata: bool = True, - ) -> None: - """ - Create a table for saving of vectors to be used with AlloyDBVectorStore. - - Args: - table_name (str): The Postgres database table name. - vector_size (int): Vector size for the embedding model to be used. - schema_name (str): The schema name. - Default: "public". - content_column (str): Name of the column to store document content. - Default: "page_content". - embedding_column (str) : Name of the column to store vector embeddings. - Default: "embedding". - metadata_columns (list[Column]): A list of Columns to create for custom - metadata. Default: []. Optional. - metadata_json_column (str): The column to store extra metadata in JSON format. - Default: "langchain_metadata". Optional. - id_column (Union[str, Column]) : Column to store ids. - Default: "langchain_id" column name with data type UUID. Optional. - overwrite_existing (bool): Whether to drop existing table. Default: False. - store_metadata (bool): Whether to store metadata in the table. - Default: True. - - Raises: - :class:`DuplicateTableError `: if table already exists. - :class:`UndefinedObjectError `: if the data type of the id column is not a postgreSQL data type. - """ - async with self._pool.connect() as conn: - await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - await conn.commit() - - if overwrite_existing: - async with self._pool.connect() as conn: - await conn.execute( - text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') - ) - await conn.commit() - - id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type - id_column_name = id_column if isinstance(id_column, str) else id_column.name - - query = f"""CREATE TABLE "{schema_name}"."{table_name}"( - "{id_column_name}" {id_data_type} PRIMARY KEY, - "{content_column}" TEXT NOT NULL, - "{embedding_column}" vector({vector_size}) NOT NULL""" - for column in metadata_columns: - nullable = "NOT NULL" if not column.nullable else "" - query += f',\n"{column.name}" {column.data_type} {nullable}' - if store_metadata: - query += f""",\n"{metadata_json_column}" JSON""" - query += "\n);" - - async with self._pool.connect() as conn: - await conn.execute(text(query)) - await conn.commit() - - async def ainit_vectorstore_table( - self, - table_name: str, - vector_size: int, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[Column] = [], - metadata_json_column: str = "langchain_metadata", - id_column: Union[str, Column] = "langchain_id", - overwrite_existing: bool = False, - store_metadata: bool = True, - ) -> None: - """ - Create a table for saving of vectors to be used with AlloyDBVectorStore. - - Args: - table_name (str): The database table name. - vector_size (int): Vector size for the embedding model to be used. - schema_name (str): The schema name. - Default: "public". - content_column (str): Name of the column to store document content. - Default: "page_content". - embedding_column (str) : Name of the column to store vector embeddings. - Default: "embedding". - metadata_columns (list[Column]): A list of Columns to create for custom - metadata. Default: []. Optional. - metadata_json_column (str): The column to store extra metadata in JSON format. - Default: "langchain_metadata". Optional. - id_column (Union[str, Column]) : Column to store ids. - Default: "langchain_id" column name with data type UUID. Optional. - overwrite_existing (bool): Whether to drop existing table. Default: False. - store_metadata (bool): Whether to store metadata in the table. - Default: True. - """ - await self._run_as_async( - self._ainit_vectorstore_table( - table_name, - vector_size, - schema_name, - content_column, - embedding_column, - metadata_columns, - metadata_json_column, - id_column, - overwrite_existing, - store_metadata, - ) - ) - - def init_vectorstore_table( - self, - table_name: str, - vector_size: int, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[Column] = [], - metadata_json_column: str = "langchain_metadata", - id_column: Union[str, Column] = "langchain_id", - overwrite_existing: bool = False, - store_metadata: bool = True, - ) -> None: - """ - Create a table for saving of vectors to be used with AlloyDBVectorStore. - - Args: - table_name (str): The database table name. - vector_size (int): Vector size for the embedding model to be used. - schema_name (str): The schema name. - Default: "public". - content_column (str): Name of the column to store document content. - Default: "page_content". - embedding_column (str) : Name of the column to store vector embeddings. - Default: "embedding". - metadata_columns (list[Column]): A list of Columns to create for custom - metadata. Default: []. Optional. - metadata_json_column (str): The column to store extra metadata in JSON format. - Default: "langchain_metadata". Optional. - id_column (Union[str, Column]) : Column to store ids. - Default: "langchain_id" column name with data type UUID. Optional. - overwrite_existing (bool): Whether to drop existing table. Default: False. - store_metadata (bool): Whether to store metadata in the table. - Default: True. - """ - self._run_as_sync( - self._ainit_vectorstore_table( - table_name, - vector_size, - schema_name, - content_column, - embedding_column, - metadata_columns, - metadata_json_column, - id_column, - overwrite_existing, - store_metadata, - ) - ) + return cls(PGEngine._PGEngine__create_key, engine, cls._default_loop, cls._default_thread) # type: ignore async def _ainit_chat_history_table( self, table_name: str, schema_name: str = "public" diff --git a/src/langchain_google_alloydb_pg/indexes.py b/src/langchain_google_alloydb_pg/indexes.py index 051de869..add2c75e 100644 --- a/src/langchain_google_alloydb_pg/indexes.py +++ b/src/langchain_google_alloydb_pg/indexes.py @@ -12,124 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import enum import warnings -from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional - -@dataclass -class StrategyMixin: - operator: str - search_function: str - index_function: str - - -class DistanceStrategy(StrategyMixin, enum.Enum): - """Enumerator of the Distance strategies.""" - - EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops" - COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops" - INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops" - - -DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE -DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex" - - -@dataclass -class BaseIndex(ABC): - name: Optional[str] = None - index_type: str = "base" - distance_strategy: DistanceStrategy = field( - default_factory=lambda: DistanceStrategy.COSINE_DISTANCE - ) - partial_indexes: Optional[list[str]] = None - extension_name: Optional[str] = None - - @abstractmethod - def index_options(self) -> str: - """Set index query options for vector store initialization.""" - raise NotImplementedError( - "index_options method must be implemented by subclass" - ) - - def get_index_function(self) -> str: - return self.distance_strategy.index_function - - -@dataclass -class ExactNearestNeighbor(BaseIndex): - index_type: str = "exactnearestneighbor" - - -@dataclass -class QueryOptions(ABC): - @abstractmethod - def to_parameter(self) -> list[str]: - """Convert index attributes to list of configurations.""" - raise NotImplementedError("to_parameter method must be implemented by subclass") - - @abstractmethod - def to_string(self) -> str: - """Convert index attributes to string.""" - raise NotImplementedError("to_string method must be implemented by subclass") - - -@dataclass -class HNSWIndex(BaseIndex): - index_type: str = "hnsw" - m: int = 16 - ef_construction: int = 64 - - def index_options(self) -> str: - """Set index query options for vector store initialization.""" - return f"(m = {self.m}, ef_construction = {self.ef_construction})" - - -@dataclass -class HNSWQueryOptions(QueryOptions): - ef_search: int = 40 - - def to_parameter(self) -> list[str]: - """Convert index attributes to list of configurations.""" - return [f"hnsw.ef_search = {self.ef_search}"] - - def to_string(self) -> str: - """Convert index attributes to string.""" - warnings.warn( - "to_string is deprecated, use to_parameter instead.", - DeprecationWarning, - ) - return f"hnsw.ef_search = {self.ef_search}" - - -@dataclass -class IVFFlatIndex(BaseIndex): - index_type: str = "ivfflat" - lists: int = 100 - - def index_options(self) -> str: - """Set index query options for vector store initialization.""" - return f"(lists = {self.lists})" - - -@dataclass -class IVFFlatQueryOptions(QueryOptions): - probes: int = 1 - - def to_parameter(self) -> list[str]: - """Convert index attributes to list of configurations.""" - return [f"ivfflat.probes = {self.probes}"] - - def to_string(self) -> str: - """Convert index attributes to string.""" - warnings.warn( - "to_string is deprecated, use to_parameter instead.", - DeprecationWarning, - ) - return f"ivfflat.probes = {self.probes}" +from langchain_postgres.v2.indexes import BaseIndex, DistanceStrategy, QueryOptions @dataclass diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index 4a083892..0666758d 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -23,7 +23,7 @@ from .async_vectorstore import AsyncAlloyDBVectorStore from .engine import AlloyDBEngine -from .indexes import ( +from langchain_postgres.v2.indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, DistanceStrategy, diff --git a/tests/test_engine.py b/tests/test_engine.py index a7764694..3faa1afa 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -244,6 +244,28 @@ async def getconn() -> asyncpg.Connection: await aexecute(engine, "SELECT 1") await engine.close() + async def test_from_connection_string( + self, + db_name, + user, + password, + ): + port = "5432" + url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" + engine = AlloyDBEngine.from_connection_string( + url, + echo=True, + poolclass=NullPool, + ) + await aexecute(engine, "SELECT 1") + await engine.close() + + engine = AlloyDBEngine.from_connection_string( + URL.create("postgresql+asyncpg", user, password, host, port, db_name) + ) + await aexecute(engine, "SELECT 1") + await engine.close() + async def test_from_engine_args_url( self, db_name, diff --git a/tests/test_indexes.py b/tests/test_indexes.py index a441058c..f3675eee 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -16,10 +16,6 @@ from langchain_google_alloydb_pg.indexes import ( DistanceStrategy, - HNSWIndex, - HNSWQueryOptions, - IVFFlatIndex, - IVFFlatQueryOptions, IVFIndex, IVFQueryOptions, ScaNNIndex, @@ -48,42 +44,6 @@ def test_distance_strategy(self): scann_index = ScaNNIndex(distance_strategy=DistanceStrategy.INNER_PRODUCT) assert scann_index.get_index_function() == "dot_prod" - def test_hnsw_index(self): - index = HNSWIndex(name="test_index", m=32, ef_construction=128) - assert index.index_type == "hnsw" - assert index.m == 32 - assert index.ef_construction == 128 - assert index.index_options() == "(m = 32, ef_construction = 128)" - - def test_hnsw_query_options(self): - options = HNSWQueryOptions(ef_search=80) - assert options.to_parameter() == ["hnsw.ef_search = 80"] - - with warnings.catch_warnings(record=True) as w: - options.to_string() - - assert len(w) == 1 - assert "to_string is deprecated, use to_parameter instead." in str( - w[-1].message - ) - - def test_ivfflat_index(self): - index = IVFFlatIndex(name="test_index", lists=200) - assert index.index_type == "ivfflat" - assert index.lists == 200 - assert index.index_options() == "(lists = 200)" - - def test_ivfflat_query_options(self): - options = IVFFlatQueryOptions(probes=2) - assert options.to_parameter() == ["ivfflat.probes = 2"] - - with warnings.catch_warnings(record=True) as w: - options.to_string() - assert len(w) == 1 - assert "to_string is deprecated, use to_parameter instead." in str( - w[-1].message - ) - def test_ivf_index(self): index = IVFIndex(name="test_index", lists=200) assert index.index_type == "ivf" From eab4316f9f0fe8d27ae8feeb014eebcc6e49ee1c Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 08:34:32 +0000 Subject: [PATCH 02/12] fix tests --- src/langchain_google_alloydb_pg/vectorstore.py | 6 +++--- tests/test_async_vectorstore_index.py | 10 +++++----- tests/test_async_vectorstore_search.py | 6 ++++-- tests/test_vectorstore_index.py | 10 ++++++---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index 0666758d..c9d9a8c1 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -20,9 +20,6 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore - -from .async_vectorstore import AsyncAlloyDBVectorStore -from .engine import AlloyDBEngine from langchain_postgres.v2.indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, @@ -30,6 +27,9 @@ QueryOptions, ) +from .async_vectorstore import AsyncAlloyDBVectorStore +from .engine import AlloyDBEngine + class AlloyDBVectorStore(VectorStore): """Google AlloyDB Vector Store class""" diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index 317f3559..4eab9e04 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -21,16 +21,16 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from sqlalchemy import text - -from langchain_google_alloydb_pg import AlloyDBEngine -from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore -from langchain_google_alloydb_pg.indexes import ( +from langchain_postgres.v2.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, IVFFlatIndex, ) +from sqlalchemy import text + +from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 50946c79..d1e69597 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -19,6 +19,10 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres.v2.indexes import ( + DistanceStrategy, + HNSWQueryOptions, +) from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS from PIL import Image from sqlalchemy import text @@ -26,8 +30,6 @@ from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( - DistanceStrategy, - HNSWQueryOptions, ScaNNQueryOptions, ) diff --git a/tests/test_vectorstore_index.py b/tests/test_vectorstore_index.py index c63c464f..ce48a423 100644 --- a/tests/test_vectorstore_index.py +++ b/tests/test_vectorstore_index.py @@ -22,14 +22,16 @@ import sqlalchemy from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from sqlalchemy import text - -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore -from langchain_google_alloydb_pg.indexes import ( +from langchain_postgres.v2.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, IVFFlatIndex, +) +from sqlalchemy import text + +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg.indexes import ( IVFIndex, ScaNNIndex, ) From 572b8e7556cc0a564c0ee4b7382033cbda8f922e Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 08:36:51 +0000 Subject: [PATCH 03/12] dependency fux --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7716147d..3654be57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "numpy>=1.24.4, <3.0.0; python_version >= '3.11'", "numpy>=1.24.4, <=2.2.6; python_version == '3.10'", "numpy>=1.24.4, <=2.0.2; python_version <= '3.9'", - "pgvector>=0.2.5, <1.0.0", + "pgvector>=0.2.5, <0.4.0", "SQLAlchemy[asyncio]>=2.0.25, <3.0.0", "langchain-postgres>=0.0.15", ] diff --git a/requirements.txt b/requirements.txt index 0748c427..36d75d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ langchain-core==0.3.67 numpy==2.3.1; python_version >= "3.11" numpy==2.2.6; python_version == "3.10" numpy==2.0.2; python_version <= "3.9" -pgvector==0.4.1 +pgvector==0.4.0 SQLAlchemy[asyncio]==2.0.41 langgraph==0.5.0 langchain-postgres==0.0.15 \ No newline at end of file From de1f104f265bf9598a02c5a3434616b9c0cb2ffe Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 08:38:28 +0000 Subject: [PATCH 04/12] dependency fix --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 36d75d52..01b49ef1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ langchain-core==0.3.67 numpy==2.3.1; python_version >= "3.11" numpy==2.2.6; python_version == "3.10" numpy==2.0.2; python_version <= "3.9" -pgvector==0.4.0 SQLAlchemy[asyncio]==2.0.41 langgraph==0.5.0 langchain-postgres==0.0.15 \ No newline at end of file From 27611b89f987eb37d540591ef405126b82a2879b Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 08:48:28 +0000 Subject: [PATCH 05/12] change imports --- tests/test_async_loader.py | 3 ++- tests/test_async_vectorstore.py | 3 ++- tests/test_async_vectorstore_from_methods.py | 3 ++- tests/test_async_vectorstore_search.py | 3 ++- tests/test_engine.py | 3 ++- tests/test_loader.py | 2 +- tests/test_standard_test_suite.py | 3 ++- tests/test_vectorstore.py | 3 ++- tests/test_vectorstore_embeddings.py | 4 ++-- tests/test_vectorstore_from_methods.py | 3 ++- tests/test_vectorstore_search.py | 7 ++++--- tests/util_tests/test_pgvector_migrator.py | 3 ++- 12 files changed, 25 insertions(+), 15 deletions(-) diff --git a/tests/test_async_loader.py b/tests/test_async_loader.py index 2bb887e1..b14cb402 100644 --- a/tests/test_async_loader.py +++ b/tests/test_async_loader.py @@ -19,9 +19,10 @@ import pytest import pytest_asyncio from langchain_core.documents import Document +from langchain_postgres import Column from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg import AlloyDBEngine from langchain_google_alloydb_pg.async_loader import ( AsyncAlloyDBDocumentSaver, AsyncAlloyDBLoader, diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 66858280..9d0d413b 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -21,11 +21,12 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from PIL import Image from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg import AlloyDBEngine from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) diff --git a/tests/test_async_vectorstore_from_methods.py b/tests/test_async_vectorstore_from_methods.py index 42f68d21..b55a48af 100644 --- a/tests/test_async_vectorstore_from_methods.py +++ b/tests/test_async_vectorstore_from_methods.py @@ -20,10 +20,11 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg import AlloyDBEngine from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index d1e69597..b1daf968 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -19,6 +19,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from langchain_postgres.v2.indexes import ( DistanceStrategy, HNSWQueryOptions, @@ -27,7 +28,7 @@ from PIL import Image from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg import AlloyDBEngine from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( ScaNNQueryOptions, diff --git a/tests/test_engine.py b/tests/test_engine.py index 3faa1afa..1b949642 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -21,13 +21,14 @@ import pytest_asyncio from google.cloud.alloydb.connector import AsyncConnector, IPTypes from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from sqlalchemy import VARCHAR, text from sqlalchemy.engine import URL from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool -from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg import AlloyDBEngine DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_loader.py b/tests/test_loader.py index 8c883195..a9dee1a4 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -19,13 +19,13 @@ import pytest import pytest_asyncio from langchain_core.documents import Document +from langchain_postgres import Column from sqlalchemy import text from langchain_google_alloydb_pg import ( AlloyDBDocumentSaver, AlloyDBEngine, AlloyDBLoader, - Column, ) project_id = os.environ["PROJECT_ID"] diff --git a/tests/test_standard_test_suite.py b/tests/test_standard_test_suite.py index 93942d66..c596741f 100644 --- a/tests/test_standard_test_suite.py +++ b/tests/test_standard_test_suite.py @@ -17,11 +17,12 @@ import pytest import pytest_asyncio +from langchain_postgres import Column from langchain_tests.integration_tests import VectorStoreIntegrationTests from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore DEFAULT_TABLE = "test_table_standard_test_suite" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync_standard_test_suite" + str(uuid.uuid4()) diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index 43c873ed..787e9d37 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -24,12 +24,13 @@ from google.cloud.alloydb.connector import AsyncConnector, IPTypes from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from PIL import Image from sqlalchemy import text from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) diff --git a/tests/test_vectorstore_embeddings.py b/tests/test_vectorstore_embeddings.py index d65e0c14..284ec905 100644 --- a/tests/test_vectorstore_embeddings.py +++ b/tests/test_vectorstore_embeddings.py @@ -18,6 +18,8 @@ import pytest import pytest_asyncio from langchain_core.documents import Document +from langchain_postgres import Column +from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from sqlalchemy import text from langchain_google_alloydb_pg import ( @@ -25,9 +27,7 @@ AlloyDBEngine, AlloyDBModelManager, AlloyDBVectorStore, - Column, ) -from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_from_methods.py b/tests/test_vectorstore_from_methods.py index 975620af..928cec66 100644 --- a/tests/test_vectorstore_from_methods.py +++ b/tests/test_vectorstore_from_methods.py @@ -20,11 +20,12 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column from sqlalchemy import VARCHAR, text from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 721f1e95..ade340b1 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -19,12 +19,13 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from langchain_postgres import Column +from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS, NEGATIVE_TEST_CASES from PIL import Image -from sqlalchemy import RowMapping, Sequence, text +from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column -from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/util_tests/test_pgvector_migrator.py b/tests/util_tests/test_pgvector_migrator.py index 2ba97e16..c37e1d6f 100644 --- a/tests/util_tests/test_pgvector_migrator.py +++ b/tests/util_tests/test_pgvector_migrator.py @@ -22,9 +22,10 @@ import pytest import pytest_asyncio from langchain_core.embeddings import FakeEmbeddings +from langchain_postgres import Column from sqlalchemy import RowMapping, text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore from langchain_google_alloydb_pg.utils.pgvector_migrator import ( __concurrent_batch_insert, aextract_pgvector_collection, From ef0cce710bcd8ab0abf1ea27b8f5ab6ac4762fb6 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 09:26:38 +0000 Subject: [PATCH 06/12] change imports --- docs/document_loader.ipynb | 2 +- docs/vector_store.ipynb | 3 ++- samples/index_tuning_sample/create_vector_embeddings.py | 3 ++- samples/index_tuning_sample/index_search.py | 8 +++++--- samples/langchain_quick_start.ipynb | 6 ++++-- .../migrations/migrate_pinecone_vectorstore_to_alloydb.py | 2 +- 6 files changed, 15 insertions(+), 9 deletions(-) diff --git a/docs/document_loader.ipynb b/docs/document_loader.ipynb index bdd887e9..8b441bcc 100644 --- a/docs/document_loader.ipynb +++ b/docs/document_loader.ipynb @@ -415,7 +415,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import Column\n", + "from langchain_postgres import Column\n", "\n", "await engine.ainit_document_table(\n", " table_name=TABLE_NAME,\n", diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index 8b4778c6..d65b1131 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -606,7 +606,8 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import Column\n", + "from langchain_postgres import Column\n", + "\n", "\n", "# Set table name\n", "TABLE_NAME = \"vectorstore_custom\"\n", diff --git a/samples/index_tuning_sample/create_vector_embeddings.py b/samples/index_tuning_sample/create_vector_embeddings.py index a154d976..37c24cf5 100644 --- a/samples/index_tuning_sample/create_vector_embeddings.py +++ b/samples/index_tuning_sample/create_vector_embeddings.py @@ -19,8 +19,9 @@ import sqlalchemy from langchain_community.document_loaders import CSVLoader from langchain_google_vertexai import VertexAIEmbeddings +from langchain_postgres import Column -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore EMBEDDING_COUNT = 100000 VECTOR_SIZE = 768 diff --git a/samples/index_tuning_sample/index_search.py b/samples/index_tuning_sample/index_search.py index 0a0f107e..3dea0312 100644 --- a/samples/index_tuning_sample/index_search.py +++ b/samples/index_tuning_sample/index_search.py @@ -32,12 +32,14 @@ vector_table_name, ) from langchain_google_vertexai import VertexAIEmbeddings - -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore -from langchain_google_alloydb_pg.indexes import ( +from langchain_postgres.v2.indexes import ( HNSWIndex, HNSWQueryOptions, IVFFlatIndex, +) + +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg.indexes import ( IVFIndex, ScaNNIndex, ) diff --git a/samples/langchain_quick_start.ipynb b/samples/langchain_quick_start.ipynb index ba0db1a8..5b1836a5 100644 --- a/samples/langchain_quick_start.ipynb +++ b/samples/langchain_quick_start.ipynb @@ -601,7 +601,8 @@ }, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import AlloyDBEngine, Column, AlloyDBLoader\n", + "from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBLoader\n", + "from langchain_postgres.v2 import Column\n", "\n", "engine = AlloyDBEngine.from_instance(\n", " project_id=project_id,\n", @@ -708,7 +709,8 @@ }, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import AlloyDBEngine, Column\n", + "from langchain_google_alloydb_pg import AlloyDBEngine\n", + "from langchain_postgres import Column\n", "\n", "sample_vector_table_name = \"movie_vector_table_samples\"\n", "\n", diff --git a/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py b/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py index d873ddf0..d32312cf 100644 --- a/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py +++ b/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py @@ -175,7 +175,7 @@ async def main( print("Langchain AlloyDB client initiated.") # [START pinecone_vectorstore_alloydb_migration_create_table] - from langchain_google_alloydb_pg import Column + from langchain_postgres import Column await alloydb_engine.ainit_vectorstore_table( table_name=alloydb_table, From 55b7af3143f25b0ff1ba8f451797a996611b8bef Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 23 Jul 2025 09:27:44 +0000 Subject: [PATCH 07/12] change imports --- samples/langchain_quick_start.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/langchain_quick_start.ipynb b/samples/langchain_quick_start.ipynb index 5b1836a5..ff7aa094 100644 --- a/samples/langchain_quick_start.ipynb +++ b/samples/langchain_quick_start.ipynb @@ -602,7 +602,7 @@ "outputs": [], "source": [ "from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBLoader\n", - "from langchain_postgres.v2 import Column\n", + "from langchain_postgres import Column\n", "\n", "engine = AlloyDBEngine.from_instance(\n", " project_id=project_id,\n", From a54e62444e64cba89f9e03d712fa8b6c6506593b Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Mon, 4 Aug 2025 16:46:52 +0000 Subject: [PATCH 08/12] fix breaking imports --- docs/document_loader.ipynb | 2 +- docs/vector_store.ipynb | 2 +- .../create_vector_embeddings.py | 3 +-- samples/index_tuning_sample/index_search.py | 2 +- samples/langchain_quick_start.ipynb | 6 ++---- .../migrate_pinecone_vectorstore_to_alloydb.py | 2 +- src/langchain_google_alloydb_pg/__init__.py | 3 ++- .../async_vectorstore.py | 2 +- src/langchain_google_alloydb_pg/indexes.py | 14 +++++++++++++- src/langchain_google_alloydb_pg/vectorstore.py | 2 +- tests/test_async_loader.py | 3 +-- tests/test_async_vectorstore.py | 3 +-- tests/test_async_vectorstore_from_methods.py | 3 +-- tests/test_async_vectorstore_index.py | 2 +- tests/test_async_vectorstore_search.py | 5 ++--- tests/test_engine.py | 3 +-- tests/test_loader.py | 2 +- tests/test_standard_test_suite.py | 3 +-- tests/test_vectorstore.py | 3 +-- tests/test_vectorstore_embeddings.py | 4 ++-- tests/test_vectorstore_from_methods.py | 3 +-- tests/test_vectorstore_index.py | 2 +- tests/test_vectorstore_search.py | 5 ++--- tests/util_tests/test_pgvector_migrator.py | 3 +-- 24 files changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/document_loader.ipynb b/docs/document_loader.ipynb index 8b441bcc..bdd887e9 100644 --- a/docs/document_loader.ipynb +++ b/docs/document_loader.ipynb @@ -415,7 +415,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_postgres import Column\n", + "from langchain_google_alloydb_pg import Column\n", "\n", "await engine.ainit_document_table(\n", " table_name=TABLE_NAME,\n", diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index d65b1131..46fd80e9 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -606,7 +606,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_postgres import Column\n", + "from langchain_google_alloydb_pg import Column\n", "\n", "\n", "# Set table name\n", diff --git a/samples/index_tuning_sample/create_vector_embeddings.py b/samples/index_tuning_sample/create_vector_embeddings.py index 37c24cf5..a154d976 100644 --- a/samples/index_tuning_sample/create_vector_embeddings.py +++ b/samples/index_tuning_sample/create_vector_embeddings.py @@ -19,9 +19,8 @@ import sqlalchemy from langchain_community.document_loaders import CSVLoader from langchain_google_vertexai import VertexAIEmbeddings -from langchain_postgres import Column -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column EMBEDDING_COUNT = 100000 VECTOR_SIZE = 768 diff --git a/samples/index_tuning_sample/index_search.py b/samples/index_tuning_sample/index_search.py index 3dea0312..6fc47e2d 100644 --- a/samples/index_tuning_sample/index_search.py +++ b/samples/index_tuning_sample/index_search.py @@ -32,7 +32,7 @@ vector_table_name, ) from langchain_google_vertexai import VertexAIEmbeddings -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( HNSWIndex, HNSWQueryOptions, IVFFlatIndex, diff --git a/samples/langchain_quick_start.ipynb b/samples/langchain_quick_start.ipynb index ff7aa094..b13a5d57 100644 --- a/samples/langchain_quick_start.ipynb +++ b/samples/langchain_quick_start.ipynb @@ -601,8 +601,7 @@ }, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBLoader\n", - "from langchain_postgres import Column\n", + "from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBLoader, Column\n", "\n", "engine = AlloyDBEngine.from_instance(\n", " project_id=project_id,\n", @@ -709,8 +708,7 @@ }, "outputs": [], "source": [ - "from langchain_google_alloydb_pg import AlloyDBEngine\n", - "from langchain_postgres import Column\n", + "from langchain_google_alloydb_pg import AlloyDBEngine, Column\n", "\n", "sample_vector_table_name = \"movie_vector_table_samples\"\n", "\n", diff --git a/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py b/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py index d32312cf..d873ddf0 100644 --- a/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py +++ b/samples/migrations/migrate_pinecone_vectorstore_to_alloydb.py @@ -175,7 +175,7 @@ async def main( print("Langchain AlloyDB client initiated.") # [START pinecone_vectorstore_alloydb_migration_create_table] - from langchain_postgres import Column + from langchain_google_alloydb_pg import Column await alloydb_engine.ainit_vectorstore_table( table_name=alloydb_table, diff --git a/src/langchain_google_alloydb_pg/__init__.py b/src/langchain_google_alloydb_pg/__init__.py index 02eaa75e..96fc9705 100644 --- a/src/langchain_google_alloydb_pg/__init__.py +++ b/src/langchain_google_alloydb_pg/__init__.py @@ -15,11 +15,12 @@ from .chat_message_history import AlloyDBChatMessageHistory from .checkpoint import AlloyDBSaver from .embeddings import AlloyDBEmbeddings -from .engine import AlloyDBEngine, Column +from .engine import AlloyDBEngine from .loader import AlloyDBDocumentSaver, AlloyDBLoader from .model_manager import AlloyDBModel, AlloyDBModelManager from .vectorstore import AlloyDBVectorStore from .version import __version__ +from langchain_postgres import Column __all__ = [ "AlloyDBEngine", diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index ef25658a..e6c7f7fc 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -28,7 +28,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore, utils -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, BaseIndex, diff --git a/src/langchain_google_alloydb_pg/indexes.py b/src/langchain_google_alloydb_pg/indexes.py index add2c75e..7b1e5307 100644 --- a/src/langchain_google_alloydb_pg/indexes.py +++ b/src/langchain_google_alloydb_pg/indexes.py @@ -15,7 +15,19 @@ import warnings from dataclasses import dataclass, field -from langchain_postgres.v2.indexes import BaseIndex, DistanceStrategy, QueryOptions +from langchain_postgres.v2.indexes import ( + BaseIndex, + DistanceStrategy, + QueryOptions, + StrategyMixin, + DEFAULT_DISTANCE_STRATEGY, + DEFAULT_INDEX_NAME_SUFFIX, + ExactNearestNeighbor, + HNSWIndex, + HNSWQueryOptions, + IVFFlatIndex, + IVFFlatQueryOptions, +) @dataclass diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index c9d9a8c1..f7dc16b1 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -20,7 +20,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, DistanceStrategy, diff --git a/tests/test_async_loader.py b/tests/test_async_loader.py index b14cb402..2bb887e1 100644 --- a/tests/test_async_loader.py +++ b/tests/test_async_loader.py @@ -19,10 +19,9 @@ import pytest import pytest_asyncio from langchain_core.documents import Document -from langchain_postgres import Column from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_loader import ( AsyncAlloyDBDocumentSaver, AsyncAlloyDBLoader, diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 9d0d413b..66858280 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -21,12 +21,11 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column from PIL import Image from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) diff --git a/tests/test_async_vectorstore_from_methods.py b/tests/test_async_vectorstore_from_methods.py index b55a48af..42f68d21 100644 --- a/tests/test_async_vectorstore_from_methods.py +++ b/tests/test_async_vectorstore_from_methods.py @@ -20,11 +20,10 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index 4eab9e04..92a151d4 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -21,7 +21,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index b1daf968..0bd411dc 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -19,8 +19,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( DistanceStrategy, HNSWQueryOptions, ) @@ -28,7 +27,7 @@ from PIL import Image from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( ScaNNQueryOptions, diff --git a/tests/test_engine.py b/tests/test_engine.py index 1b949642..3faa1afa 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -21,14 +21,13 @@ import pytest_asyncio from google.cloud.alloydb.connector import AsyncConnector, IPTypes from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column from sqlalchemy import VARCHAR, text from sqlalchemy.engine import URL from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, Column DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_loader.py b/tests/test_loader.py index a9dee1a4..457dfc83 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -19,13 +19,13 @@ import pytest import pytest_asyncio from langchain_core.documents import Document -from langchain_postgres import Column from sqlalchemy import text from langchain_google_alloydb_pg import ( AlloyDBDocumentSaver, AlloyDBEngine, AlloyDBLoader, + Column ) project_id = os.environ["PROJECT_ID"] diff --git a/tests/test_standard_test_suite.py b/tests/test_standard_test_suite.py index c596741f..93942d66 100644 --- a/tests/test_standard_test_suite.py +++ b/tests/test_standard_test_suite.py @@ -17,12 +17,11 @@ import pytest import pytest_asyncio -from langchain_postgres import Column from langchain_tests.integration_tests import VectorStoreIntegrationTests from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column DEFAULT_TABLE = "test_table_standard_test_suite" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync_standard_test_suite" + str(uuid.uuid4()) diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index 787e9d37..43c873ed 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -24,13 +24,12 @@ from google.cloud.alloydb.connector import AsyncConnector, IPTypes from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column from PIL import Image from sqlalchemy import text from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) diff --git a/tests/test_vectorstore_embeddings.py b/tests/test_vectorstore_embeddings.py index 284ec905..c92bb298 100644 --- a/tests/test_vectorstore_embeddings.py +++ b/tests/test_vectorstore_embeddings.py @@ -18,8 +18,7 @@ import pytest import pytest_asyncio from langchain_core.documents import Document -from langchain_postgres import Column -from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions +from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions from sqlalchemy import text from langchain_google_alloydb_pg import ( @@ -27,6 +26,7 @@ AlloyDBEngine, AlloyDBModelManager, AlloyDBVectorStore, + Column ) DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_from_methods.py b/tests/test_vectorstore_from_methods.py index 928cec66..975620af 100644 --- a/tests/test_vectorstore_from_methods.py +++ b/tests/test_vectorstore_from_methods.py @@ -20,12 +20,11 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column from sqlalchemy import VARCHAR, text from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_index.py b/tests/test_vectorstore_index.py index ce48a423..857afb60 100644 --- a/tests/test_vectorstore_index.py +++ b/tests/test_vectorstore_index.py @@ -22,7 +22,7 @@ import sqlalchemy from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres.v2.indexes import ( +from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index ade340b1..f83e5d53 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -19,13 +19,12 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_postgres import Column -from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions +from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS, NEGATIVE_TEST_CASES from PIL import Image from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/util_tests/test_pgvector_migrator.py b/tests/util_tests/test_pgvector_migrator.py index c37e1d6f..2ba97e16 100644 --- a/tests/util_tests/test_pgvector_migrator.py +++ b/tests/util_tests/test_pgvector_migrator.py @@ -22,10 +22,9 @@ import pytest import pytest_asyncio from langchain_core.embeddings import FakeEmbeddings -from langchain_postgres import Column from sqlalchemy import RowMapping, text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column from langchain_google_alloydb_pg.utils.pgvector_migrator import ( __concurrent_batch_insert, aextract_pgvector_collection, From c7d016c162cfcac65cc4aef5eda5e646bf2753f4 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Tue, 5 Aug 2025 06:17:49 +0000 Subject: [PATCH 09/12] linter fix and add vector index tests --- samples/index_tuning_sample/index_search.py | 6 +-- src/langchain_google_alloydb_pg/__init__.py | 3 +- .../async_vectorstore.py | 5 ++- src/langchain_google_alloydb_pg/indexes.py | 22 +++++----- .../vectorstore.py | 1 + tests/test_async_vectorstore_index.py | 8 ++-- tests/test_async_vectorstore_search.py | 6 +-- tests/test_indexes.py | 42 ++++++++++++++++++- tests/test_loader.py | 2 +- tests/test_vectorstore_embeddings.py | 4 +- tests/test_vectorstore_index.py | 8 ++-- tests/test_vectorstore_search.py | 2 +- 12 files changed, 73 insertions(+), 36 deletions(-) diff --git a/samples/index_tuning_sample/index_search.py b/samples/index_tuning_sample/index_search.py index 6fc47e2d..0a0f107e 100644 --- a/samples/index_tuning_sample/index_search.py +++ b/samples/index_tuning_sample/index_search.py @@ -32,14 +32,12 @@ vector_table_name, ) from langchain_google_vertexai import VertexAIEmbeddings + +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( HNSWIndex, HNSWQueryOptions, IVFFlatIndex, -) - -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore -from langchain_google_alloydb_pg.indexes import ( IVFIndex, ScaNNIndex, ) diff --git a/src/langchain_google_alloydb_pg/__init__.py b/src/langchain_google_alloydb_pg/__init__.py index 96fc9705..20ceb71b 100644 --- a/src/langchain_google_alloydb_pg/__init__.py +++ b/src/langchain_google_alloydb_pg/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from langchain_postgres import Column + from .chat_message_history import AlloyDBChatMessageHistory from .checkpoint import AlloyDBSaver from .embeddings import AlloyDBEmbeddings @@ -20,7 +22,6 @@ from .model_manager import AlloyDBModel, AlloyDBModelManager from .vectorstore import AlloyDBVectorStore from .version import __version__ -from langchain_postgres import Column __all__ = [ "AlloyDBEngine", diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index e6c7f7fc..c425619f 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -28,6 +28,9 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore, utils +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine + from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, @@ -36,8 +39,6 @@ ExactNearestNeighbor, QueryOptions, ) -from sqlalchemy import RowMapping, text -from sqlalchemy.ext.asyncio import AsyncEngine from .engine import AlloyDBEngine diff --git a/src/langchain_google_alloydb_pg/indexes.py b/src/langchain_google_alloydb_pg/indexes.py index 7b1e5307..48f5974f 100644 --- a/src/langchain_google_alloydb_pg/indexes.py +++ b/src/langchain_google_alloydb_pg/indexes.py @@ -16,17 +16,17 @@ from dataclasses import dataclass, field from langchain_postgres.v2.indexes import ( - BaseIndex, - DistanceStrategy, - QueryOptions, - StrategyMixin, - DEFAULT_DISTANCE_STRATEGY, - DEFAULT_INDEX_NAME_SUFFIX, - ExactNearestNeighbor, - HNSWIndex, - HNSWQueryOptions, - IVFFlatIndex, - IVFFlatQueryOptions, + DEFAULT_DISTANCE_STRATEGY, + DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, + DistanceStrategy, + ExactNearestNeighbor, + HNSWIndex, + HNSWQueryOptions, + IVFFlatIndex, + IVFFlatQueryOptions, + QueryOptions, + StrategyMixin, ) diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index f7dc16b1..39a25587 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -20,6 +20,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore + from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index 92a151d4..317f3559 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -21,16 +21,16 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from sqlalchemy import text + +from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, IVFFlatIndex, ) -from sqlalchemy import text - -from langchain_google_alloydb_pg import AlloyDBEngine -from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 0bd411dc..50946c79 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -19,10 +19,6 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_google_alloydb_pg.indexes import ( - DistanceStrategy, - HNSWQueryOptions, -) from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS from PIL import Image from sqlalchemy import text @@ -30,6 +26,8 @@ from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( + DistanceStrategy, + HNSWQueryOptions, ScaNNQueryOptions, ) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index f3675eee..bc1d04de 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -14,8 +14,12 @@ import warnings -from langchain_google_alloydb_pg.indexes import ( +from langchain_google_alloydb_pg.indexes import ( # type: ignore DistanceStrategy, + HNSWIndex, + HNSWQueryOptions, + IVFFlatIndex, + IVFFlatQueryOptions, IVFIndex, IVFQueryOptions, ScaNNIndex, @@ -44,6 +48,42 @@ def test_distance_strategy(self): scann_index = ScaNNIndex(distance_strategy=DistanceStrategy.INNER_PRODUCT) assert scann_index.get_index_function() == "dot_prod" + def test_ivfflat_index(self): + index = IVFFlatIndex(name="test_index", lists=200) + assert index.index_type == "ivfflat" + assert index.lists == 200 + assert index.index_options() == "(lists = 200)" + + def test_ivfflat_query_options(self): + options = IVFFlatQueryOptions(probes=2) + assert options.to_parameter() == ["ivfflat.probes = 2"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + + def test_hnsw_index(self): + index = HNSWIndex(name="test_index", m=32, ef_construction=128) + assert index.index_type == "hnsw" + assert index.m == 32 + assert index.ef_construction == 128 + assert index.index_options() == "(m = 32, ef_construction = 128)" + + def test_hnsw_query_options(self): + options = HNSWQueryOptions(ef_search=80) + assert options.to_parameter() == ["hnsw.ef_search = 80"] + + with warnings.catch_warnings(record=True) as w: + options.to_string() + + assert len(w) == 1 + assert "to_string is deprecated, use to_parameter instead." in str( + w[-1].message + ) + def test_ivf_index(self): index = IVFIndex(name="test_index", lists=200) assert index.index_type == "ivf" diff --git a/tests/test_loader.py b/tests/test_loader.py index 457dfc83..8c883195 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -25,7 +25,7 @@ AlloyDBDocumentSaver, AlloyDBEngine, AlloyDBLoader, - Column + Column, ) project_id = os.environ["PROJECT_ID"] diff --git a/tests/test_vectorstore_embeddings.py b/tests/test_vectorstore_embeddings.py index c92bb298..d65e0c14 100644 --- a/tests/test_vectorstore_embeddings.py +++ b/tests/test_vectorstore_embeddings.py @@ -18,7 +18,6 @@ import pytest import pytest_asyncio from langchain_core.documents import Document -from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions from sqlalchemy import text from langchain_google_alloydb_pg import ( @@ -26,8 +25,9 @@ AlloyDBEngine, AlloyDBModelManager, AlloyDBVectorStore, - Column + Column, ) +from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_index.py b/tests/test_vectorstore_index.py index 857afb60..c63c464f 100644 --- a/tests/test_vectorstore_index.py +++ b/tests/test_vectorstore_index.py @@ -22,16 +22,14 @@ import sqlalchemy from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from sqlalchemy import text + +from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, IVFFlatIndex, -) -from sqlalchemy import text - -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore -from langchain_google_alloydb_pg.indexes import ( IVFIndex, ScaNNIndex, ) diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index f83e5d53..875f5840 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -19,12 +19,12 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS, NEGATIVE_TEST_CASES from PIL import Image from sqlalchemy import text from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") From 825fa1c753e73f09c9a2a02967ad43c32bba9626 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:53:10 +0000 Subject: [PATCH 10/12] refactor!: Refactor AlloyDBVectorStore to depend on PGVectorstore (#435) * refactor!: Refactor AlloyDBVectorStore to depend on PGVectorstore * Linter fix * Fix tests * Fix tests * fix tests * linter fix * fix vectorstore * add all existing tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix dependecies * Add tests for hybrid search * linter fix * fix test * Add better documentation for breaking change * minor fix * update deps * Re-expose hybrid search config * add header * linter fix --- docs/vector_store.ipynb | 128 +- pyproject.toml | 3 - requirements.txt | 2 - .../async_vectorstore.py | 1226 +---------------- .../hybrid_search_config.py | 19 + .../vectorstore.py | 751 +--------- tests/test_async_vectorstore.py | 5 +- tests/test_async_vectorstore_index.py | 91 +- tests/test_async_vectorstore_search.py | 401 +++++- tests/test_engine.py | 55 + tests/test_standard_test_suite.py | 4 +- tests/test_vectorstore.py | 22 +- tests/test_vectorstore_embeddings.py | 10 +- tests/test_vectorstore_index.py | 39 +- tests/test_vectorstore_search.py | 86 +- 15 files changed, 837 insertions(+), 2005 deletions(-) create mode 100644 src/langchain_google_alloydb_pg/hybrid_search_config.py diff --git a/docs/vector_store.ipynb b/docs/vector_store.ipynb index 46fd80e9..5a35d820 100644 --- a/docs/vector_store.ipynb +++ b/docs/vector_store.ipynb @@ -650,8 +650,45 @@ "all_texts = [\"Apples and oranges\", \"Cars and airplanes\", \"Pineapple\", \"Train\", \"Banana\"]\n", "metadatas = [{\"len\": len(t)} for t in all_texts]\n", "ids = [str(uuid.uuid4()) for _ in all_texts]\n", - "await custom_store.aadd_texts(all_texts, metadatas=metadatas, ids=ids)\n", + "await custom_store.aadd_texts(all_texts, metadatas=metadatas, ids=ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.13.0+\n", "\n", + "**Important Update:** Support for string filters has been deprecated. Please use dictionaries to add filters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use filter on search\n", + "docs = await custom_store.asimilarity_search(query, filter={\"len\": {\"$gte\": 6}})\n", + "\n", + "print(docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.12.0 and under\n", + "\n", + "You can make use of the string filters to filter on metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "# Use filter on search\n", "docs = await custom_store.asimilarity_search(query, filter=\"len >= 6\")\n", "\n", @@ -766,6 +803,37 @@ "Since price_usd is one of the metadata_columns, we can use price filter while searching" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.13.0+\n", + "\n", + "**Important Update:** Support for string filters has been deprecated. Please use dictionaries to add filters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "\n", + "docs = await custom_store.asimilarity_search(query, filter={\"price_usd\": {\"$gte\": 100}})\n", + "\n", + "print(docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.12.0 and under\n", + "\n", + "You can make use of the string filters to filter on metadata" + ] + }, { "cell_type": "code", "execution_count": null, @@ -783,8 +851,62 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Search for documents with json filter\n", - "Since category is added in json metadata, we can use filter on JSON fields while searching\n" + "### Search for documents with json filter\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For v0.13.0+\n", + "\n", + "**Important Update:** Support for string filters has been deprecated. To filter data on the JSON metadata, you must first create a new column for the specific key you wish to filter on. Use the following SQL command to set this up." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "sql" + } + }, + "outputs": [], + "source": [ + "ALTER TABLE vectorstore_table ADD COLUMN category VARCHAR;\n", + "UPDATE vectorstore_table\n", + "SET\n", + " category = metadata ->> 'category';" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that you've added the new column, you must update the Vectorstore instance to recognize it. After which the new column is available for filtering operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "\n", + "docs = await custom_store.asimilarity_search(query, filter={\"category\": \"Electronics\"})\n", + "\n", + "print(docs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "#### For v0.12.0 and under\n", + "\n", + "Since category is added in json metadata, we can use filter on JSON fields using string filters while searching." ] }, { diff --git a/pyproject.toml b/pyproject.toml index a8af1aa1..75f5c03c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,12 +11,9 @@ authors = [ dependencies = [ "google-cloud-alloydb-connector[asyncpg]>=1.2.0, <2.0.0", "google-cloud-storage>=2.18.2, <4.0.0", - "langchain-core>=0.2.36, <1.0.0", "numpy>=1.24.4, <3.0.0; python_version >= '3.11'", "numpy>=1.24.4, <=2.2.6; python_version == '3.10'", "numpy>=1.24.4, <=2.0.2; python_version <= '3.9'", - "pgvector>=0.2.5, <0.4.0", - "SQLAlchemy[asyncio]>=2.0.25, <3.0.0", "langchain-postgres>=0.0.15", ] diff --git a/requirements.txt b/requirements.txt index fa5b9023..8afe1644 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ google-cloud-alloydb-connector[asyncpg]==1.9.0 google-cloud-storage==3.1.1 -langchain-core==0.3.67 numpy==2.3.1; python_version >= "3.11" numpy==2.2.6; python_version == "3.10" numpy==2.0.2; python_version <= "3.9" -SQLAlchemy[asyncio]==2.0.41 langgraph==0.6.0 langchain-postgres==0.0.15 \ No newline at end of file diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index c425619f..7437f5a3 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -16,362 +16,22 @@ from __future__ import annotations import base64 -import copy -import json import re -import uuid -from typing import Any, Callable, Iterable, Optional, Sequence +from typing import Any, Optional import numpy as np import requests from google.cloud import storage # type: ignore from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore, utils -from sqlalchemy import RowMapping, text -from sqlalchemy.ext.asyncio import AsyncEngine +from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore +from sqlalchemy import text -from langchain_google_alloydb_pg.indexes import ( - DEFAULT_DISTANCE_STRATEGY, - DEFAULT_INDEX_NAME_SUFFIX, - BaseIndex, - DistanceStrategy, - ExactNearestNeighbor, - QueryOptions, -) -from .engine import AlloyDBEngine - -COMPARISONS_TO_NATIVE = { - "$eq": "=", - "$ne": "!=", - "$lt": "<", - "$lte": "<=", - "$gt": ">", - "$gte": ">=", -} - -SPECIAL_CASED_OPERATORS = { - "$in", - "$nin", - "$between", - "$exists", -} - -TEXT_OPERATORS = { - "$like", - "$ilike", -} - -LOGICAL_OPERATORS = {"$and", "$or", "$not"} - -SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(TEXT_OPERATORS) - .union(LOGICAL_OPERATORS) - .union(SPECIAL_CASED_OPERATORS) -) - - -class AsyncAlloyDBVectorStore(VectorStore): +class AsyncAlloyDBVectorStore(AsyncPGVectorStore): """Google AlloyDB Vector Store class""" - __create_key = object() - - def __init__( - self, - key: object, - engine: AsyncEngine, - embedding_service: Embeddings, - table_name: str, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - id_column: str = "langchain_id", - metadata_json_column: Optional[str] = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - ): - """AsyncAlloyDBVectorStore constructor. - Args: - key (object): Prevent direct constructor usage. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - embedding_service (Embeddings): Text embedding model to use. - table_name (str): Name of the existing table or the table to be created. - schema_name (str, optional): Name of the database schema. Defaults to "public". - content_column (str): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str]): Column(s) that represent a document's metadata. - id_column (str): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - - Raises: - Exception: If called directly by user. - """ - if key != AsyncAlloyDBVectorStore.__create_key: - raise Exception( - "Only create class through 'create' or 'create_sync' methods!" - ) - - self.engine = engine - self.embedding_service = embedding_service - self.table_name = table_name - self.schema_name = schema_name - self.content_column = content_column - self.embedding_column = embedding_column - self.metadata_columns = metadata_columns - self.id_column = id_column - self.metadata_json_column = metadata_json_column - self.distance_strategy = distance_strategy - self.k = k - self.fetch_k = fetch_k - self.lambda_mult = lambda_mult - self.index_query_options = index_query_options - - @classmethod - async def create( - cls: type[AsyncAlloyDBVectorStore], - engine: AlloyDBEngine, - embedding_service: Embeddings, - table_name: str, - schema_name: str = "public", - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: Optional[str] = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - ) -> AsyncAlloyDBVectorStore: - """Create an AsyncAlloyDBVectorStore instance. - - Args: - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - embedding_service (Embeddings): Text embedding model to use. - table_name (str): Name of an existing table. - schema_name (str, optional): Name of the database schema. Defaults to "public". - content_column (str): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str]): Column(s) that represent a document's metadata. - ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Returns: - AsyncAlloyDBVectorStore - """ - if metadata_columns and ignore_metadata_columns: - raise ValueError( - "Can not use both metadata_columns and ignore_metadata_columns." - ) - # Get field type information - stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" - async with engine._pool.connect() as conn: - result = await conn.execute(text(stmt)) - result_map = result.mappings() - results = result_map.fetchall() - columns = {} - for field in results: - columns[field["column_name"]] = field["data_type"] - - # Check columns - if id_column not in columns: - raise ValueError(f"Id column, {id_column}, does not exist.") - if content_column not in columns: - raise ValueError(f"Content column, {content_column}, does not exist.") - content_type = columns[content_column] - if content_type != "text" and "char" not in content_type: - raise ValueError( - f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." - ) - if embedding_column not in columns: - raise ValueError(f"Embedding column, {embedding_column}, does not exist.") - if columns[embedding_column] != "USER-DEFINED": - raise ValueError( - f"Embedding column, {embedding_column}, is not type Vector." - ) - - metadata_json_column = ( - None if metadata_json_column not in columns else metadata_json_column - ) - - # If using metadata_columns check to make sure column exists - for column in metadata_columns: - if column not in columns: - raise ValueError(f"Metadata column, {column}, does not exist.") - - # If using ignore_metadata_columns, filter out known columns and set known metadata columns - all_columns = columns - if ignore_metadata_columns: - for column in ignore_metadata_columns: - del all_columns[column] - - del all_columns[id_column] - del all_columns[content_column] - del all_columns[embedding_column] - metadata_columns = [k for k in all_columns.keys()] - - return cls( - cls.__create_key, - engine._pool, - embedding_service, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - id_column=id_column, - metadata_json_column=metadata_json_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - ) - - @property - def embeddings(self) -> Embeddings: - return self.embedding_service - - async def aadd_embeddings( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Add data along with embeddings to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - if not ids: - ids = [str(uuid.uuid4()) for _ in texts] - else: - # This is done to fill in any missing ids - ids = [id if id is not None else str(uuid.uuid4()) for id in ids] - if not metadatas: - metadatas = [{} for _ in texts] - # Insert embeddings - for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): - metadata_col_names = ( - ", " + ", ".join(f'"{col}"' for col in self.metadata_columns) - if len(self.metadata_columns) > 0 - else "" - ) - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' - values = { - "langchain_id": id, - "content": content, - "embedding": str([float(dimension) for dimension in embedding]), - } - values_stmt = "VALUES (:langchain_id, :content, :embedding" - inline_embed_func = getattr( - self.embedding_service, "embed_query_inline", None - ) - if not embedding and callable(inline_embed_func): - values_stmt = f"VALUES (:langchain_id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore - - # Add metadata - extra = copy.deepcopy(metadata) - for metadata_column in self.metadata_columns: - if metadata_column in metadata: - values_stmt += f", :{metadata_column}" - values[metadata_column] = metadata[metadata_column] - del extra[metadata_column] - else: - values_stmt += ",null" - - # Add JSON column and/or close statement - insert_stmt += ( - f""", "{self.metadata_json_column}")""" - if self.metadata_json_column - else ")" - ) - if self.metadata_json_column: - values_stmt += ", :extra)" - values["extra"] = json.dumps(extra) - else: - values_stmt += ")" - - upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' - - if self.metadata_json_column: - upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' - - for column in self.metadata_columns: - upsert_stmt += f', "{column}" = EXCLUDED."{column}"' - - upsert_stmt += ";" - - query = insert_stmt + values_stmt + upsert_stmt - async with self.engine.connect() as conn: - await conn.execute(text(query), values) - await conn.commit() - - return ids - - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed texts and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - # Check for inline embedding query - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) - if callable(inline_embed_func): - embeddings: list[list[float]] = [[] for _ in list(texts)] - else: - embeddings = await self.embedding_service.aembed_documents(list(texts)) - - ids = await self.aadd_embeddings( - texts, embeddings, metadatas=metadatas, ids=ids, **kwargs - ) - return ids - - async def aadd_documents( - self, - documents: list[Document], - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed documents and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - if not ids: - ids = [doc.id for doc in documents] - ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) - return ids + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _encode_image(self, uri: str) -> str: """Get base64 string from a image URI.""" @@ -434,236 +94,9 @@ async def aadd_images( ids = await self.aadd_embeddings( texts_for_content_column, embeddings, metadatas=metadatas, ids=ids, **kwargs ) - return ids - - async def adelete( - self, - ids: Optional[list] = None, - **kwargs: Any, - ) -> Optional[bool]: - """Delete records from the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - if not ids: - return False - - id_list = ", ".join([f"'{id}'" for id in ids]) - query = f'DELETE FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} in ({id_list})' - async with self.engine.connect() as conn: - await conn.execute(text(query)) - await conn.commit() - return True - - @classmethod - async def afrom_texts( # type: ignore[override] - cls: type[AsyncAlloyDBVectorStore], - texts: list[str], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AsyncAlloyDBVectorStore: - """Create an AsyncAlloyDBVectorStore instance from texts. - - Args: - texts (list[str]): Texts to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - metadatas (Optional[list[dict]]): List of metadatas to add to table records. - ids: (Optional[list[str]]): List of IDs to add to table records. - content_column (str): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str]): Column(s) that represent a document's metadata. - ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AsyncAlloyDBVectorStore - """ - vs = await cls.create( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - id_column=id_column, - metadata_json_column=metadata_json_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - ) - await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) - return vs - - @classmethod - async def afrom_documents( # type: ignore[override] - cls: type[AsyncAlloyDBVectorStore], - documents: list[Document], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AsyncAlloyDBVectorStore: - """Create an AsyncAlloyDBVectorStore instance from documents. - - Args: - documents (list[Document]): Documents to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - metadatas (Optional[list[dict]]): List of metadatas to add to table records. - ids: (Optional[list[str]]): List of IDs to add to table records. - content_column (str): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str]): Column(s) that represent a document's metadata. - ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AsyncAlloyDBVectorStore - """ - - vs = await cls.create( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - id_column=id_column, - metadata_json_column=metadata_json_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - ) - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) - return vs - - async def __query_collection( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> Sequence[RowMapping]: - """Perform similarity search query on database.""" - k = k if k else self.k - operator = self.distance_strategy.operator - search_function = self.distance_strategy.search_function - - columns = self.metadata_columns + [ - self.id_column, - self.content_column, - self.embedding_column, - ] - if self.metadata_json_column: - columns.append(self.metadata_json_column) - - column_names = ", ".join(f'"{col}"' for col in columns) - - if filter and isinstance(filter, dict): - filter = self._create_filter_clause(filter) - filter = f"WHERE {filter}" if filter else "" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) - if not embedding and callable(inline_embed_func) and "query" in kwargs: - query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore - else: - query_embedding = f"'{[float(dimension) for dimension in embedding]}'" - stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {query_embedding}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {query_embedding} LIMIT {k};' - if self.index_query_options: - async with self.engine.connect() as conn: - # Set each query option individually - for query_option in self.index_query_options.to_parameter(): - query_options_stmt = f"SET LOCAL {query_option};" - await conn.execute(text(query_options_stmt)) - result = await conn.execute(text(stmt)) - result_map = result.mappings() - results = result_map.fetchall() - else: - async with self.engine.connect() as conn: - result = await conn.execute(text(stmt)) - result_map = result.mappings() - results = result_map.fetchall() - return results - - async def asimilarity_search( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by similarity search on query.""" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) - embedding = ( - [] - if callable(inline_embed_func) - else await self.embedding_service.aembed_query(text=query) - ) - kwargs["query"] = query - - return await self.asimilarity_search_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs - ) + if ids: + return ids + return [] def _images_embedding_helper(self, image_uris: list[str]) -> list[list[float]]: # check if either `embed_images()` or `embed_image()` API is supported by the embedding service used @@ -691,7 +124,7 @@ async def asimilarity_search_image( self, image_uri: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -701,177 +134,6 @@ async def asimilarity_search_image( embedding=embedding, k=k, filter=filter, **kwargs ) - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """Select a relevance function based on distance strategy.""" - # Calculate distance strategy provided in - # vectorstore constructor - if self.distance_strategy == DistanceStrategy.COSINE_DISTANCE: - return self._cosine_relevance_score_fn - if self.distance_strategy == DistanceStrategy.INNER_PRODUCT: - return self._max_inner_product_relevance_score_fn - elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: - return self._euclidean_relevance_score_fn - - async def asimilarity_search_with_score( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on query.""" - inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) - embedding = ( - [] - if callable(inline_embed_func) - else await self.embedding_service.aembed_query(text=query) - ) - kwargs["query"] = query - - docs = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs - ) - return docs - - async def asimilarity_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by vector similarity search.""" - docs_and_scores = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter, **kwargs - ) - - return [doc for doc, _ in docs_and_scores] - - async def asimilarity_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by vector similarity search.""" - results = await self.__query_collection( - embedding=embedding, k=k, filter=filter, **kwargs - ) - - documents_with_scores = [] - for row in results: - metadata = ( - row[self.metadata_json_column] - if self.metadata_json_column and row[self.metadata_json_column] - else {} - ) - for col in self.metadata_columns: - metadata[col] = row[col] - documents_with_scores.append( - ( - Document( - page_content=row[self.content_column], - metadata=metadata, - id=str(row[self.id_column]), - ), - row["distance"], - ) - ) - - return documents_with_scores - - async def amax_marginal_relevance_search( - self, - query: str, - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - embedding = await self.embedding_service.aembed_query(text=query) - - return await self.amax_marginal_relevance_search_by_vector( - embedding=embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - - async def amax_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - docs_and_scores = ( - await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) - ) - - return [result[0] for result in docs_and_scores] - - async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected using the maximal marginal relevance.""" - results = await self.__query_collection( - embedding=embedding, k=fetch_k, filter=filter, **kwargs - ) - - k = k if k else self.k - fetch_k = fetch_k if fetch_k else self.fetch_k - lambda_mult = lambda_mult if lambda_mult else self.lambda_mult - embedding_list = [json.loads(row[self.embedding_column]) for row in results] - mmr_selected = utils.maximal_marginal_relevance( - np.array(embedding, dtype=np.float32), - embedding_list, - k=k, - lambda_mult=lambda_mult, - ) - - documents_with_scores = [] - for row in results: - metadata = ( - row[self.metadata_json_column] - if self.metadata_json_column and row[self.metadata_json_column] - else {} - ) - for col in self.metadata_columns: - metadata[col] = row[col] - documents_with_scores.append( - ( - Document( - page_content=row[self.content_column], - metadata=metadata, - id=str(row[self.id_column]), - ), - row["distance"], - ) - ) - - return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected] - async def set_maintenance_work_mem(self, num_leaves: int, vector_size: int) -> None: """Set database maintenance work memory (for ScaNN index creation).""" # Required index memory in MB @@ -884,339 +146,6 @@ async def set_maintenance_work_mem(self, num_leaves: int, vector_size: int) -> N await conn.execute(text(query)) await conn.commit() - async def aapply_vector_index( - self, - index: BaseIndex, - name: Optional[str] = None, - concurrently: bool = False, - ) -> None: - """Create index in the vector store table.""" - if isinstance(index, ExactNearestNeighbor): - await self.adrop_vector_index() - return - - # if extension name is mentioned, create the extension - if index.extension_name: - async with self.engine.connect() as conn: - await conn.execute( - text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}") - ) - await conn.commit() - function = index.get_index_function() - - filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" - params = "WITH " + index.index_options() - if name is None: - if index.name == None: - index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX - name = index.name - stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON \"{self.schema_name}\".\"{self.table_name}\" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};" - if concurrently: - async with self.engine.connect() as conn: - await conn.execute(text("COMMIT")) - await conn.execute(text(stmt)) - else: - async with self.engine.connect() as conn: - await conn.execute(text(stmt)) - await conn.commit() - - async def areindex(self, index_name: Optional[str] = None) -> None: - """Re-index the vector store table.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f"REINDEX INDEX {index_name};" - async with self.engine.connect() as conn: - await conn.execute(text(query)) - await conn.commit() - - async def adrop_vector_index( - self, - index_name: Optional[str] = None, - ) -> None: - """Drop the vector index.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f"DROP INDEX IF EXISTS {index_name};" - async with self.engine.connect() as conn: - await conn.execute(text(query)) - await conn.commit() - - async def is_valid_index( - self, - index_name: Optional[str] = None, - ) -> bool: - """Check if index exists in the table.""" - index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX - query = f""" - SELECT tablename, indexname - FROM pg_indexes - WHERE tablename = '{self.table_name}' AND schemaname = '{self.schema_name}' AND indexname = '{index_name}'; - """ - async with self.engine.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - results = result_map.fetchall() - return bool(len(results) == 1) - - async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: - """Get documents by ids.""" - - quoted_ids = [f"'{id_val}'" for id_val in ids] - id_list_str = ", ".join(quoted_ids) - - columns = self.metadata_columns + [ - self.id_column, - self.content_column, - ] - if self.metadata_json_column: - columns.append(self.metadata_json_column) - - column_names = ", ".join(f'"{col}"' for col in columns) - - query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({id_list_str});' - - async with self.engine.connect() as conn: - result = await conn.execute(text(query)) - result_map = result.mappings() - results = result_map.fetchall() - - documents = [] - for row in results: - metadata = ( - row[self.metadata_json_column] - if self.metadata_json_column and row[self.metadata_json_column] - else {} - ) - for col in self.metadata_columns: - metadata[col] = row[col] - documents.append( - ( - Document( - page_content=row[self.content_column], - metadata=metadata, - id=str(row[self.id_column]), - ) - ) - ) - - return documents - - def _handle_field_filter( - self, - field: str, - value: Any, - ) -> str: - """Create a filter for a specific field. - - Args: - field: name of field - value: value to filter - If provided as is then this will be an equality filter - If provided as a dictionary then this will be a filter, the key - will be the operator and the value will be the value to filter by - - Returns: - sql where query as a string - """ - if not isinstance(field, str): - raise ValueError( - f"field should be a string but got: {type(field)} with value: {field}" - ) - - if field.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got an operator: " - f"{field}" - ) - - # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters - if not field.isidentifier(): - raise ValueError( - f"Invalid field name: {field}. Expected a valid identifier." - ) - - if isinstance(value, dict): - # This is a filter specification - if len(value) != 1: - raise ValueError( - "Invalid filter condition. Expected a value which " - "is a dictionary with a single key that corresponds to an operator " - f"but got a dictionary with {len(value)} keys. The first few " - f"keys are: {list(value.keys())[:3]}" - ) - operator, filter_value = list(value.items())[0] - # Verify that that operator is an operator - if operator not in SUPPORTED_OPERATORS: - raise ValueError( - f"Invalid operator: {operator}. " - f"Expected one of {SUPPORTED_OPERATORS}" - ) - else: # Then we assume an equality operator - operator = "$eq" - filter_value = value - - if operator in COMPARISONS_TO_NATIVE: - # Then we implement an equality filter - # native is trusted input - if isinstance(filter_value, str): - filter_value = f"'{filter_value}'" - native = COMPARISONS_TO_NATIVE[operator] - return f"({field} {native} {filter_value})" - elif operator == "$between": - # Use AND with two comparisons - low, high = filter_value - - return f"({field} BETWEEN {low} AND {high})" - elif operator in {"$in", "$nin", "$like", "$ilike"}: - # We'll do force coercion to text - if operator in {"$in", "$nin"}: - for val in filter_value: - if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) - - if isinstance(val, bool): # b/c bool is an instance of int - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) - - if operator in {"$in"}: - values = str(tuple(val for val in filter_value)) - return f"({field} IN {values})" - elif operator in {"$nin"}: - values = str(tuple(val for val in filter_value)) - return f"({field} NOT IN {values})" - elif operator in {"$like"}: - return f"({field} LIKE '{filter_value}')" - elif operator in {"$ilike"}: - return f"({field} ILIKE '{filter_value}')" - else: - raise NotImplementedError() - elif operator == "$exists": - if not isinstance(filter_value, bool): - raise ValueError( - "Expected a boolean value for $exists " - f"operator, but got: {filter_value}" - ) - else: - if filter_value: - return f"({field} IS NOT NULL)" - else: - return f"({field} IS NULL)" - else: - raise NotImplementedError() - - def _create_filter_clause(self, filters: Any) -> str: - """Create LangChain filter representation to matching SQL where clauses - - Args: - filters: Dictionary of filters to apply to the query. - - Returns: - String containing the sql where query. - """ - - if not isinstance(filters, dict): - raise ValueError( - f"Invalid type: Expected a dictionary but got type: {type(filters)}" - ) - if len(filters) == 1: - # The only operators allowed at the top level are $AND, $OR, and $NOT - # First check if an operator or a field - key, value = list(filters.items())[0] - if key.startswith("$"): - # Then it's an operator - if key.lower() not in ["$and", "$or", "$not"]: - raise ValueError( - f"Invalid filter condition. Expected $and, $or or $not " - f"but got: {key}" - ) - else: - # Then it's a field - return self._handle_field_filter(key, filters[key]) - - if key.lower() == "$and" or key.lower() == "$or": - if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) - op = key[1:].upper() # Extract the operator - filter_clause = [self._create_filter_clause(el) for el in value] - if len(filter_clause) > 1: - return f"({f' {op} '.join(filter_clause)})" - elif len(filter_clause) == 1: - return filter_clause[0] - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - elif key.lower() == "$not": - if isinstance(value, list): - not_conditions = [ - self._create_filter_clause(item) for item in value - ] - not_stmts = [f"NOT {condition}" for condition in not_conditions] - return f"({' AND '.join(not_stmts)})" - elif isinstance(value, dict): - not_ = self._create_filter_clause(value) - return f"(NOT {not_})" - else: - raise ValueError( - f"Invalid filter condition. Expected a dictionary " - f"or a list but got: {type(value)}" - ) - else: - raise ValueError( - f"Invalid filter condition. Expected $and, $or or $not " - f"but got: {key}" - ) - elif len(filters) > 1: - # Then all keys have to be fields (they cannot be operators) - for key in filters.keys(): - if key.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got: {key}" - ) - # These should all be fields and combined using an $and operator - and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] - if len(and_) > 1: - return f"({' AND '.join(and_)})" - elif len(and_) == 1: - return and_[0] - else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) - else: - return "" - - def get_by_ids(self, ids: Sequence[str]) -> list[Document]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def add_documents( - self, - documents: list[Document], - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - def add_images( self, uris: list[str], @@ -1228,146 +157,13 @@ def add_images( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." ) - def delete( - self, - ids: Optional[list] = None, - **kwargs: Any, - ) -> Optional[bool]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - @classmethod - def from_texts( # type: ignore[override] - cls: type[AsyncAlloyDBVectorStore], - texts: list[str], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - **kwargs: Any, - ) -> AsyncAlloyDBVectorStore: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - @classmethod - def from_documents( # type: ignore[override] - cls: type[AsyncAlloyDBVectorStore], - documents: list[Document], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - **kwargs: Any, - ) -> AsyncAlloyDBVectorStore: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def similarity_search( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - def similarity_search_image( self, image_uri: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def similarity_search_with_score( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def similarity_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def max_marginal_relevance_search( - self, - query: str, - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - raise NotImplementedError( - "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." - ) diff --git a/src/langchain_google_alloydb_pg/hybrid_search_config.py b/src/langchain_google_alloydb_pg/hybrid_search_config.py new file mode 100644 index 00000000..9e024cc1 --- /dev/null +++ b/src/langchain_google_alloydb_pg/hybrid_search_config.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index 39a25587..f018aaa9 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -15,15 +15,15 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -from typing import Any, Callable, Iterable, Optional, Sequence +from typing import Any, Optional from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore +from langchain_postgres import PGVectorStore +from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, - BaseIndex, DistanceStrategy, QueryOptions, ) @@ -32,40 +32,22 @@ from .engine import AlloyDBEngine -class AlloyDBVectorStore(VectorStore): +class AlloyDBVectorStore(PGVectorStore): """Google AlloyDB Vector Store class""" - __create_key = object() - - def __init__(self, key: object, engine: AlloyDBEngine, vs: AsyncAlloyDBVectorStore): - """AlloyDBVectorStore constructor. - Args: - key (object): Prevent direct constructor usage. - engine (AlloyDBEngine): Connection pool engine for managing connections to Postgres database. - vs (AsyncAlloyDBVectorstore): The async only VectorStore implementation - - - Raises: - Exception: If called directly by user. - """ - if key != AlloyDBVectorStore.__create_key: - raise Exception( - "Only create class through 'create' or 'create_sync' methods!" - ) - - self._engine = engine - self.__vs = vs + _engine: AlloyDBEngine + __vs: AsyncAlloyDBVectorStore @classmethod async def create( cls: type[AlloyDBVectorStore], - engine: AlloyDBEngine, + engine: AlloyDBEngine, # type: ignore embedding_service: Embeddings, table_name: str, schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", - metadata_columns: list[str] = [], + metadata_columns: Optional[list[str]] = None, ignore_metadata_columns: Optional[list[str]] = None, id_column: str = "langchain_id", metadata_json_column: Optional[str] = "langchain_metadata", @@ -74,15 +56,16 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> AlloyDBVectorStore: - """Create an AlloyDBVectorStore instance. + """Create an PGVectorStore instance. Args: - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. + engine (AlloyDBEngine): Connection pool engine for managing connections to postgres database. embedding_service (Embeddings): Text embedding model to use. table_name (str): Name of an existing table. schema_name (str, optional): Name of the database schema. Defaults to "public". - content_column (str): Column that represent a Document’s page_content. Defaults to "content". + content_column (str): Column that represent a Document's page_content. Defaults to "content". embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". metadata_columns (list[str]): Column(s) that represent a document's metadata. ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. @@ -93,9 +76,10 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: - AlloyDBVectorStore + PGVectorStore """ coro = AsyncAlloyDBVectorStore.create( engine, @@ -113,20 +97,21 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = await engine._run_as_async(coro) - return cls(cls.__create_key, engine, vs) + return cls(cls._PGVectorStore__create_key, engine, vs) # type: ignore @classmethod def create_sync( - cls, - engine: AlloyDBEngine, + cls: type[AlloyDBVectorStore], + engine: AlloyDBEngine, # type: ignore embedding_service: Embeddings, table_name: str, schema_name: str = "public", content_column: str = "content", embedding_column: str = "embedding", - metadata_columns: list[str] = [], + metadata_columns: Optional[list[str]] = None, ignore_metadata_columns: Optional[list[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", @@ -135,6 +120,7 @@ def create_sync( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> AlloyDBVectorStore: """Create an AlloyDBVectorStore instance. @@ -155,6 +141,7 @@ def create_sync( fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20. lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: AlloyDBVectorStore @@ -175,57 +162,10 @@ def create_sync( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = engine._run_as_sync(coro) - return cls(cls.__create_key, engine, vs) - - @property - def embeddings(self) -> Embeddings: - return self.__vs.embedding_service - - async def aadd_embeddings( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, - **kwargs: Any, - ) -> list[str]: - """Add data along with embeddings to the table.""" - return await self._engine._run_as_async( - self.__vs.aadd_embeddings(texts, embeddings, metadatas, ids, **kwargs) - ) - - async def aadd_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed texts and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return await self._engine._run_as_async( - self.__vs.aadd_texts(texts, metadatas, ids, **kwargs) - ) - - async def aadd_documents( - self, - documents: list[Document], - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed documents and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return await self._engine._run_as_async( - self.__vs.aadd_documents(documents, ids, **kwargs) - ) + return cls(cls._PGVectorStore__create_key, engine, vs) # type: ignore async def aadd_images( self, @@ -237,55 +177,11 @@ async def aadd_images( ) -> list[str]: """Embed images and add to the table.""" return await self._engine._run_as_async( - self.__vs.aadd_images( + self._PGVectorStore__vs.aadd_images( # type: ignore uris, metadatas, ids, store_uri_only=store_uri_only, **kwargs ) ) - def add_embeddings( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, - **kwargs: Any, - ) -> list[str]: - """Add data along with embeddings to the table.""" - return self._engine._run_as_sync( - self.__vs.aadd_embeddings(texts, embeddings, metadatas, ids, **kwargs) - ) - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed texts and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return self._engine._run_as_sync( - self.__vs.aadd_texts(texts, metadatas, ids, **kwargs) - ) - - def add_documents( - self, - documents: list[Document], - ids: Optional[list] = None, - **kwargs: Any, - ) -> list[str]: - """Embed documents and add to the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return self._engine._run_as_sync( - self.__vs.aadd_documents(documents, ids, **kwargs) - ) - def add_images( self, uris: list[str], @@ -296,547 +192,33 @@ def add_images( ) -> list[str]: """Embed images and add to the table.""" return self._engine._run_as_sync( - self.__vs.aadd_images( + self._PGVectorStore__vs.aadd_images( # type: ignore uris, metadatas, ids, store_uri_only=store_uri_only, **kwargs ) ) - async def adelete( - self, - ids: Optional[list] = None, - **kwargs: Any, - ) -> Optional[bool]: - """Delete records from the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return await self._engine._run_as_async(self.__vs.adelete(ids, **kwargs)) - - def delete( - self, - ids: Optional[list] = None, - **kwargs: Any, - ) -> Optional[bool]: - """Delete records from the table. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - """ - return self._engine._run_as_sync(self.__vs.adelete(ids, **kwargs)) - - @classmethod - async def afrom_texts( # type: ignore[override] - cls: type[AlloyDBVectorStore], - texts: list[str], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AlloyDBVectorStore: - """Create an AlloyDBVectorStore instance from texts. - - Args: - texts (list[str]): Texts to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - schema_name (str, optional): Name of the database schema. Defaults to "public". - metadatas (Optional[list[dict]], optional): List of metadatas to add to table records. Defaults to None. - ids: (Optional[list]): List of IDs to add to table records. Defaults to None. - content_column (str, optional): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str, optional): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str], optional): Column(s) that represent a document's metadata. Defaults to an empty list. - ignore_metadata_columns (Optional[list[str]], optional): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str, optional): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AlloyDBVectorStore - """ - vs = await cls.create( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - ) - await vs.aadd_texts(texts, metadatas=metadatas, ids=ids) - return vs - - @classmethod - async def afrom_documents( # type: ignore[override] - cls: type[AlloyDBVectorStore], - documents: list[Document], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AlloyDBVectorStore: - """Create an AlloyDBVectorStore instance from documents. - - Args: - documents (list[Document]): Documents to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - schema_name (str, optional): Name of the database schema. Defaults to "public". - ids: (Optional[list]): List of IDs to add to table records. Defaults to None. - content_column (str, optional): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str, optional): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str], optional): Column(s) that represent a document's metadata. Defaults to an empty list. - ignore_metadata_columns (Optional[list[str]], optional): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str, optional): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AlloyDBVectorStore - """ - - vs = await cls.create( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - ) - await vs.aadd_documents(documents, ids=ids) - return vs - - @classmethod - def from_texts( # type: ignore[override] - cls: type[AlloyDBVectorStore], - texts: list[str], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - metadatas: Optional[list[dict]] = None, - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AlloyDBVectorStore: - """Create an AlloyDBVectorStore instance from texts. - - Args: - texts (list[str]): Texts to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - schema_name (str, optional): Name of the database schema. Defaults to "public". - metadatas (Optional[list[dict]], optional): List of metadatas to add to table records. Defaults to None. - ids: (Optional[list]): List of IDs to add to table records. Defaults to None. - content_column (str, optional): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str, optional): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str], optional): Column(s) that represent a document's metadata. Defaults to empty list. - ignore_metadata_columns (Optional[list[str]], optional): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str, optional): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AlloyDBVectorStore - """ - vs = cls.create_sync( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - **kwargs, - ) - vs.add_texts(texts, metadatas=metadatas, ids=ids) - return vs - - @classmethod - def from_documents( # type: ignore[override] - cls: type[AlloyDBVectorStore], - documents: list[Document], - embedding: Embeddings, - engine: AlloyDBEngine, - table_name: str, - schema_name: str = "public", - ids: Optional[list] = None, - content_column: str = "content", - embedding_column: str = "embedding", - metadata_columns: list[str] = [], - ignore_metadata_columns: Optional[list[str]] = None, - id_column: str = "langchain_id", - metadata_json_column: str = "langchain_metadata", - distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, - k: int = 4, - fetch_k: int = 20, - lambda_mult: float = 0.5, - index_query_options: Optional[QueryOptions] = None, - **kwargs: Any, - ) -> AlloyDBVectorStore: - """Create an AlloyDBVectorStore instance from documents. - - Args: - documents (list[Document]): Documents to add to the vector store. - embedding (Embeddings): Text embedding model to use. - engine (AlloyDBEngine): Connection pool engine for managing connections to AlloyDB database. - table_name (str): Name of an existing table. - schema_name (str, optional): Name of the database schema. Defaults to "public". - ids: (Optional[list]): List of IDs to add to table records. Defaults to None. - content_column (str, optional): Column that represent a Document’s page_content. Defaults to "content". - embedding_column (str, optional): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding". - metadata_columns (list[str], optional): Column(s) that represent a document's metadata. Defaults to an empty list. - ignore_metadata_columns (Optional[list[str]], optional): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None. - id_column (str, optional): Column that represents the Document's id. Defaults to "langchain_id". - metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata". - distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. - k (int): Number of Documents to return from search. Defaults to 4. - fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. - lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. - index_query_options (QueryOptions): Index query option. - - Raises: - :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. - - Returns: - AlloyDBVectorStore - """ - vs = cls.create_sync( - engine, - embedding, - table_name, - schema_name=schema_name, - content_column=content_column, - embedding_column=embedding_column, - metadata_columns=metadata_columns, - ignore_metadata_columns=ignore_metadata_columns, - metadata_json_column=metadata_json_column, - id_column=id_column, - distance_strategy=distance_strategy, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - index_query_options=index_query_options, - **kwargs, - ) - vs.add_documents(documents, ids=ids) - return vs - - def similarity_search( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by similarity search on query.""" - return self._engine._run_as_sync( - self.__vs.asimilarity_search(query, k, filter, **kwargs) - ) - def similarity_search_image( self, image_uri: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on image.""" return self._engine._run_as_sync( - self.__vs.asimilarity_search_image(image_uri, k, filter, **kwargs) - ) - - async def asimilarity_search( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by similarity search on query.""" - return await self._engine._run_as_async( - self.__vs.asimilarity_search(query, k, filter, **kwargs) + self._PGVectorStore__vs.asimilarity_search_image(image_uri, k, filter, **kwargs) # type: ignore ) async def asimilarity_search_image( self, image_uri: str, k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, + filter: Optional[dict] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on image_uri.""" return await self._engine._run_as_async( - self.__vs.asimilarity_search_image(image_uri, k, filter, **kwargs) - ) - - # Required for (a)similarity_search_with_relevance_scores - def _select_relevance_score_fn(self) -> Callable[[float], float]: - """Select a relevance function based on distance strategy.""" - # Calculate distance strategy provided in vectorstore constructor - if self.__vs.distance_strategy == DistanceStrategy.COSINE_DISTANCE: - return self._cosine_relevance_score_fn - if self.__vs.distance_strategy == DistanceStrategy.INNER_PRODUCT: - return self._max_inner_product_relevance_score_fn - elif self.__vs.distance_strategy == DistanceStrategy.EUCLIDEAN: - return self._euclidean_relevance_score_fn - - async def asimilarity_search_with_score( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on query.""" - return await self._engine._run_as_async( - self.__vs.asimilarity_search_with_score(query, k, filter, **kwargs) - ) - - async def asimilarity_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by vector similarity search.""" - return await self._engine._run_as_async( - self.__vs.asimilarity_search_by_vector(embedding, k, filter, **kwargs) - ) - - async def asimilarity_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by vector similarity search.""" - return await self._engine._run_as_async( - self.__vs.asimilarity_search_with_score_by_vector( - embedding, k, filter, **kwargs - ) - ) - - async def amax_marginal_relevance_search( - self, - query: str, - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - return await self._engine._run_as_async( - self.__vs.amax_marginal_relevance_search( - query, k, fetch_k, lambda_mult, filter, **kwargs - ) - ) - - async def amax_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - return await self._engine._run_as_async( - self.__vs.amax_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mult, filter, **kwargs - ) - ) - - async def amax_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected using the maximal marginal relevance.""" - return await self._engine._run_as_async( - self.__vs.amax_marginal_relevance_search_with_score_by_vector( - embedding, k, fetch_k, lambda_mult, filter, **kwargs - ) - ) - - def similarity_search_with_score( - self, - query: str, - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on query.""" - return self._engine._run_as_sync( - self.__vs.asimilarity_search_with_score(query, k, filter, **kwargs) - ) - - def similarity_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected by vector similarity search.""" - return self._engine._run_as_sync( - self.__vs.asimilarity_search_by_vector(embedding, k, filter, **kwargs) - ) - - def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected by similarity search on vector.""" - return self._engine._run_as_sync( - self.__vs.asimilarity_search_with_score_by_vector( - embedding, k, filter, **kwargs - ) - ) - - def max_marginal_relevance_search( - self, - query: str, - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - return self._engine._run_as_sync( - self.__vs.amax_marginal_relevance_search( - query, k, fetch_k, lambda_mult, filter, **kwargs - ) - ) - - def max_marginal_relevance_search_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[Document]: - """Return docs selected using the maximal marginal relevance.""" - return self._engine._run_as_sync( - self.__vs.amax_marginal_relevance_search_by_vector( - embedding, k, fetch_k, lambda_mult, filter, **kwargs - ) - ) - - def max_marginal_relevance_search_with_score_by_vector( - self, - embedding: list[float], - k: Optional[int] = None, - fetch_k: Optional[int] = None, - lambda_mult: Optional[float] = None, - filter: Optional[dict] | Optional[str] = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: - """Return docs and distance scores selected using the maximal marginal relevance.""" - return self._engine._run_as_sync( - self.__vs.amax_marginal_relevance_search_with_score_by_vector( - embedding, k, fetch_k, lambda_mult, filter, **kwargs - ) + self._PGVectorStore__vs.asimilarity_search_image(image_uri, k, filter, **kwargs) # type: ignore ) async def aset_maintenance_work_mem( @@ -844,82 +226,11 @@ async def aset_maintenance_work_mem( ) -> None: """Set database maintenance work memory (for ScaNN index creation).""" await self._engine._run_as_async( - self.__vs.set_maintenance_work_mem(num_leaves, vector_size) + self._PGVectorStore__vs.set_maintenance_work_mem(num_leaves, vector_size) # type: ignore ) def set_maintenance_work_mem(self, num_leaves: int, vector_size: int) -> None: """Set database maintenance work memory (for ScaNN index creation).""" self._engine._run_as_sync( - self.__vs.set_maintenance_work_mem(num_leaves, vector_size) + self._PGVectorStore__vs.set_maintenance_work_mem(num_leaves, vector_size) # type: ignore ) - - async def aapply_vector_index( - self, - index: BaseIndex, - name: Optional[str] = None, - concurrently: bool = False, - ) -> None: - """Create an index on the vector store table.""" - return await self._engine._run_as_async( - self.__vs.aapply_vector_index(index, name, concurrently) - ) - - def apply_vector_index( - self, - index: BaseIndex, - name: Optional[str] = None, - concurrently: bool = False, - ) -> None: - """Create an index on the vector store table.""" - return self._engine._run_as_sync( - self.__vs.aapply_vector_index(index, name, concurrently) - ) - - async def areindex(self, index_name: Optional[str] = None) -> None: - """Re-index the vector store table.""" - return await self._engine._run_as_async(self.__vs.areindex(index_name)) - - def reindex(self, index_name: Optional[str] = None) -> None: - """Re-index the vector store table.""" - return self._engine._run_as_sync(self.__vs.areindex(index_name)) - - async def adrop_vector_index( - self, - index_name: Optional[str] = None, - ) -> None: - """Drop the vector index.""" - return await self._engine._run_as_async( - self.__vs.adrop_vector_index(index_name) - ) - - def drop_vector_index( - self, - index_name: Optional[str] = None, - ) -> None: - """Drop the vector index.""" - return self._engine._run_as_sync(self.__vs.adrop_vector_index(index_name)) - - async def ais_valid_index( - self, - index_name: Optional[str] = None, - ) -> bool: - """Check if index exists in the table.""" - return await self._engine._run_as_async(self.__vs.is_valid_index(index_name)) - - def is_valid_index( - self, - index_name: Optional[str] = None, - ) -> bool: - """Check if index exists in the table.""" - return self._engine._run_as_sync(self.__vs.is_valid_index(index_name)) - - async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: - """Get documents by ids.""" - return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids)) - - def get_by_ids(self, ids: Sequence[str]) -> list[Document]: - """Get documents by ids.""" - return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids)) - - def get_table_name(self) -> str: - return self.__vs.table_name diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 66858280..8dc93762 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -30,8 +30,8 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) -CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4()) -IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()) +CUSTOM_TABLE = "custom" + str(uuid.uuid4()) +IMAGE_TABLE = "image" + str(uuid.uuid4()) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -111,6 +111,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{IMAGE_TABLE}"') await engine.close() @pytest_asyncio.fixture(scope="class") diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index 317f3559..c5953335 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -25,15 +25,19 @@ from langchain_google_alloydb_pg import AlloyDBEngine from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore +from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, HNSWIndex, IVFFlatIndex, + IVFIndex, ) -DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX +UUID_STR = str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE = "table" + UUID_STR +DEFAULT_HYBRID_TABLE = "hybrid" + UUID_STR +DEFAULT_INDEX_NAME = DEFAULT_INDEX_NAME_SUFFIX + UUID_STR VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -94,6 +98,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_HYBRID_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -109,6 +114,22 @@ async def vs(self, engine): await vs.adrop_vector_index() yield vs + async def test_aapply_vector_index_ivf(self, vs): + index = IVFIndex( + name=DEFAULT_INDEX_NAME, + distance_strategy=DistanceStrategy.EUCLIDEAN, + ) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + index = IVFIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + await vs.aapply_vector_index(index) + assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + await vs.adrop_vector_index() + async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) @@ -119,18 +140,20 @@ async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex() await vs.aapply_vector_index(index) - await vs.areindex() + await vs.areindex(DEFAULT_INDEX_NAME) await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): - await vs.adrop_vector_index() + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) result = await vs.is_valid_index(DEFAULT_INDEX_NAME) assert not result async def test_aapply_vector_index_ivfflat(self, vs): - index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + index = IVFFlatIndex( + name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN + ) await vs.aapply_vector_index(index, concurrently=True) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) index = IVFFlatIndex( @@ -145,3 +168,61 @@ async def test_aapply_vector_index_ivfflat(self, vs): async def test_is_valid_index(self, vs): is_valid = await vs.is_valid_index("invalid_index") assert is_valid == False + + async def test_apply_default_name_vector_index(self, vs): + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + index = HNSWIndex() + await vs.aapply_vector_index(index) + assert await vs.is_valid_index() + await vs.adrop_vector_index() + + async def test_aapply_vector_index_non_hybrid_search_vs(self, vs): + with pytest.raises(ValueError): + await vs.aapply_hybrid_search_index() + + async def test_aapply_hybrid_search_index_table_without_tsv_column( + self, engine, vs + ): + # overwriting vs to get a hybrid vs + tsv_index_name = "index_without_tsv_column_" + UUID_STR + vs = await AsyncAlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), + ) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_hybrid_search_index() + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + + async def test_aapply_hybrid_search_index_table_with_tsv_column(self, engine): + tsv_index_name = "index_without_tsv_column_" + UUID_STR + config = HybridSearchConfig( + tsv_column="tsv_column", + tsv_lang="pg_catalog.english", + index_name=tsv_index_name, + ) + await engine._ainit_vectorstore_table( + DEFAULT_HYBRID_TABLE, + VECTOR_SIZE, + hybrid_search_config=config, + ) + vs = await AsyncAlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_HYBRID_TABLE, + hybrid_search_config=config, + ) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_hybrid_search_index() + assert await vs.is_valid_index(tsv_index_name) + await vs.areindex(tsv_index_name) + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 50946c79..74606abd 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -25,6 +25,11 @@ from langchain_google_alloydb_pg import AlloyDBEngine, Column from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore +from langchain_google_alloydb_pg.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_google_alloydb_pg.indexes import ( DistanceStrategy, HNSWQueryOptions, @@ -33,8 +38,10 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") -IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_") +IMAGE_TABLE = "image" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE1 = "hybrid1" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE2 = "hybrid2" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." @@ -57,6 +64,19 @@ Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] +# Documents designed for hybrid search testing +hybrid_docs_content = { + "hs_doc_apple_fruit": "An apple is a sweet and edible fruit produced by an apple tree. Apples are very common.", + "hs_doc_apple_tech": "Apple Inc. is a multinational technology company. Their latest tech is amazing.", + "hs_doc_orange_fruit": "The orange is the fruit of various citrus species. Oranges are tasty.", + "hs_doc_generic_tech": "Technology drives innovation in the modern world. Tech is evolving.", + "hs_doc_unrelated_cat": "A fluffy cat sat on a mat quietly observing a mouse.", +} +hybrid_docs = [ + Document(page_content=content, metadata={"doc_id_key": key}) + for key, content in hybrid_docs_content.items() +] + class FakeImageEmbedding(DeterministicFakeEmbedding): @@ -118,6 +138,9 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {IMAGE_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {HYBRID_SEARCH_TABLE1}") + await aexecute(engine, f"DROP TABLE IF EXISTS {HYBRID_SEARCH_TABLE2}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -244,11 +267,54 @@ async def vs_custom_filter(self, engine): await vs_custom_filter.aadd_documents(filter_docs, ids=ids) yield vs_custom_filter + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_with_tsv_column(self, engine): + hybrid_search_config = HybridSearchConfig( + tsv_column="mycontent_tsv", + tsv_lang="pg_catalog.english", + fts_query="my_fts_query", + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 10, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE1, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + metadata_json_column="mymetadata", # ignored + store_metadata=False, + hybrid_search_config=hybrid_search_config, + ) + + vs_custom = await AsyncAlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE1, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_json_column="mymetadata", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=hybrid_search_config, + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") + results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"}) assert results == [Document(page_content="bar", id=ids[1])] async def test_asimilarity_search_scann(self, vs_custom_scann_query_option): @@ -256,7 +322,7 @@ async def test_asimilarity_search_scann(self, vs_custom_scann_query_option): assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] results = await vs_custom_scann_query_option.asimilarity_search( - "foo", k=1, filter="mycontent = 'bar'" + "foo", k=1, filter={"mycontent": "bar"} ) assert results == [Document(page_content="bar", id=ids[1])] @@ -333,7 +399,7 @@ async def test_amax_marginal_relevance_search(self, vs): results = await vs.amax_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search( - "bar", filter="content = 'boo'" + "bar", filter={"content": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) @@ -359,7 +425,7 @@ async def test_similarity_search(self, vs_custom): assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] results = await vs_custom.asimilarity_search( - "foo", k=1, filter="mycontent = 'bar'" + "foo", k=1, filter={"mycontent": "bar"} ) assert results == [Document(page_content="bar", id=ids[1])] @@ -386,7 +452,7 @@ async def test_max_marginal_relevance_search(self, vs_custom): results = await vs_custom.amax_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs_custom.amax_marginal_relevance_search( - "bar", filter="mycontent = 'boo'" + "bar", filter={"mycontent": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) @@ -419,11 +485,6 @@ async def test_aget_by_ids_custom_vs(self, vs_custom): assert results[0] == Document(page_content="foo", id=ids[0]) - def test_get_by_ids(self, vs): - test_ids = [ids[0]] - with pytest.raises(Exception, match=sync_method_exception_str): - vs.get_by_ids(ids=test_ids) - @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) async def test_vectorstore_with_metadata_filters( self, @@ -436,3 +497,319 @@ async def test_vectorstore_with_metadata_filters( "meow", k=5, filter=test_filter ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + async def test_asimilarity_hybrid_search_rrk(self, vs): + results = await vs.asimilarity_search( + "foo", + k=1, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 100, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + async def test_hybrid_search_weighted_sum_default( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" + query = "apple" # Should match "apple" in FTS and vector + + # The vs_hybrid_search_with_tsv_column instance is already configured for hybrid search. + # Default fusion is weighted_sum_ranking with 0.5/0.5 weights. + # fts_query will default to the main query. + results_with_scores = ( + await vs_hybrid_search_with_tsv_column.asimilarity_search_with_score( + query, k=3 + ) + ) + + assert len(results_with_scores) > 1 + result_ids = [doc.metadata["doc_id_key"] for doc, score in results_with_scores] + + # Expect "hs_doc_apple_fruit" and "hs_doc_apple_tech" to be highly ranked. + assert "hs_doc_apple_fruit" in result_ids + + # Scores should be floats (fused scores) + for doc, score in results_with_scores: + assert isinstance(score, float) + + # Check if sorted by score (descending for weighted_sum_ranking with positive scores) + assert results_with_scores[0][1] >= results_with_scores[1][1] + + async def test_hybrid_search_weighted_sum_vector_bias( + self, vs_hybrid_search_with_tsv_column + ): + """Test weighted sum with higher weight for vector results.""" + query = "Apple Inc technology" # More specific for vector similarity + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", # Must match table setup + fusion_function_parameters={ + "primary_results_weight": 0.8, # Vector bias + "secondary_results_weight": 0.2, + }, + # fts_query will default to main query + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) > 0 + assert result_ids[0] == "hs_doc_orange_fruit" + + async def test_hybrid_search_weighted_sum_fts_bias( + self, vs_hybrid_search_with_tsv_column + ): + """Test weighted sum with higher weight for FTS results.""" + query = "fruit common tasty" # Strong FTS signal for fruit docs + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.01, + "secondary_results_weight": 0.99, # FTS bias + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + assert "hs_doc_apple_fruit" in result_ids + + async def test_hybrid_search_reciprocal_rank_fusion( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search with Reciprocal Rank Fusion.""" + query = "technology company" + + # Configure RRF. primary_top_k and secondary_top_k control inputs to fusion. + # fusion_function_parameters.fetch_top_k controls output count from RRF. + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=reciprocal_rank_fusion, + primary_top_k=3, # How many dense results to consider + secondary_top_k=3, # How many sparse results to consider + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 2, + }, # RRF specific params + ) + # The `k` in asimilarity_search here is the final desired number of results, + # which should align with fusion_function_parameters.fetch_top_k for RRF. + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + # "hs_doc_apple_tech" (FTS: technology, company; Vector: Apple Inc technology) + # "hs_doc_generic_tech" (FTS: technology; Vector: Technology drives innovation) + # RRF should combine these ranks. "hs_doc_apple_tech" is likely higher. + assert "hs_doc_apple_tech" in result_ids + assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal + + async def test_hybrid_search_explicit_fts_query( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" + main_vector_query = "Apple Inc." # For vector search + fts_specific_query = "fruit" # For FTS + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_specific_query, # Override FTS query + fusion_function_parameters={ # Using default weighted_sum_ranking + "primary_results_weight": 0.5, + "secondary_results_weight": 0.5, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + main_vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Vector search for "Apple Inc.": hs_doc_apple_tech + # FTS search for "fruit": hs_doc_apple_fruit, hs_doc_orange_fruit + # Combined: hs_doc_apple_fruit (strong FTS) and hs_doc_apple_tech (strong vector) are candidates. + # "hs_doc_apple_fruit" might get a boost if "Apple Inc." vector has some similarity to "apple fruit" doc. + assert len(result_ids) > 0 + assert ( + "hs_doc_apple_fruit" in result_ids + or "hs_doc_apple_tech" in result_ids + or "hs_doc_orange_fruit" in result_ids + ) + + async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column): + """Test hybrid search with a metadata filter applied.""" + query = "apple" + # Filter to only include "tech" related apple docs using metadata + # Assuming metadata_columns=["doc_id_key"] was set up for vs_hybrid_search_with_tsv_column + doc_filter = {"doc_id_key": {"$eq": "hs_doc_apple_tech"}} + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, filter=doc_filter, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(results) == 1 + assert result_ids[0] == "hs_doc_apple_tech" + + async def test_hybrid_search_fts_empty_results( + self, vs_hybrid_search_with_tsv_column + ): + """Test when FTS query yields no results, should fall back to vector search.""" + vector_query = "apple" + no_match_fts_query = "zzyyxx_gibberish_term_for_fts_nomatch" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=no_match_fts_query, + fusion_function_parameters={ + "primary_results_weight": 0.6, + "secondary_results_weight": 0.4, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on vector search for "apple" + assert len(result_ids) > 0 + assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids + # The top result should be one of the apple documents based on vector search + assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat") + + async def test_hybrid_search_vector_empty_results_effectively( + self, vs_hybrid_search_with_tsv_column + ): + """Test when vector query is very dissimilar to docs, should rely on FTS.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. + vector_query_far_off = "supercalifragilisticexpialidocious_vector_nomatch" + fts_query_match = "orange fruit" # Should match hs_doc_orange_fruit + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.4, + "secondary_results_weight": 0.6, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids) == 1 + assert result_ids[0] == "hs_doc_generic_tech" + + async def test_hybrid_search_without_tsv_column(self, engine): + """Test hybrid search without a TSV column.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. + vector_query_far_off = "apple iphone tech is better designed than macs" + fts_query_match = "apple fruit" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=config, + ) + + vs_with_tsv_column = await AsyncAlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + await vs_with_tsv_column.aadd_documents(hybrid_docs) + + config = HybridSearchConfig( + tsv_column="", # no TSV column + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.9, + "secondary_results_weight": 0.1, + }, + ) + vs_without_tsv_column = await AsyncAlloyDBVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=config, + ) + + results_with_tsv_column = await vs_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + results_without_tsv_column = await vs_without_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids_with_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_with_tsv_column + ] + result_ids_without_tsv_column = [ + doc.metadata["doc_id_key"] for doc in results_without_tsv_column + ] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids_with_tsv_column) == 1 + assert len(result_ids_without_tsv_column) == 1 + assert result_ids_with_tsv_column[0] == "hs_doc_apple_tech" + assert result_ids_without_tsv_column[0] == "hs_doc_apple_tech" diff --git a/tests/test_engine.py b/tests/test_engine.py index 3faa1afa..7f9ec509 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -28,15 +28,18 @@ from sqlalchemy.pool import NullPool from langchain_google_alloydb_pg import AlloyDBEngine, Column +from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE = "test_table_custom_int_id" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE = "hybrid" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE_SYNC = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") INT_ID_CUSTOM_TABLE_SYNC = "test_table_custom_int_id" + str(uuid.uuid4()).replace( "-", "_" ) +HYBRID_SEARCH_TABLE_SYNC = "hybrid_sync" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -125,6 +128,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE}"') + await aexecute(engine, f'DROP TABLE "{HYBRID_SEARCH_TABLE}"') await engine.close() async def test_init_table(self, engine): @@ -369,6 +373,31 @@ async def test_ainit_checkpoint_writes_table(self, engine): await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"') + async def test_init_table_hybrid_search(self, engine): + await engine.ainit_vectorstore_table( + HYBRID_SEARCH_TABLE, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -417,6 +446,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{INT_ID_CUSTOM_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{HYBRID_SEARCH_TABLE_SYNC}"') await engine.close() async def test_init_table(self, engine): @@ -563,3 +593,28 @@ async def test_init_checkpoints_table(self, engine): assert row in expected await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"') + + async def test_init_table_hybrid_search(self, engine): + engine.init_vectorstore_table( + HYBRID_SEARCH_TABLE_SYNC, + VECTOR_SIZE, + id_column="uuid", + content_column="my-content", + embedding_column="my_embedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=True, + hybrid_search_config=HybridSearchConfig(), + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{HYBRID_SEARCH_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "uuid", "data_type": "uuid"}, + {"column_name": "my_embedding", "data_type": "USER-DEFINED"}, + {"column_name": "langchain_metadata", "data_type": "json"}, + {"column_name": "my-content", "data_type": "text"}, + {"column_name": "my-content_tsv", "data_type": "tsvector"}, + {"column_name": "page", "data_type": "text"}, + {"column_name": "source", "data_type": "text"}, + ] + for row in results: + assert row in expected diff --git a/tests/test_standard_test_suite.py b/tests/test_standard_test_suite.py index 93942d66..42497ae9 100644 --- a/tests/test_standard_test_suite.py +++ b/tests/test_standard_test_suite.py @@ -23,8 +23,8 @@ from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column -DEFAULT_TABLE = "test_table_standard_test_suite" + str(uuid.uuid4()) -DEFAULT_TABLE_SYNC = "test_table_sync_standard_test_suite" + str(uuid.uuid4()) +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) +DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) def get_env_var(key: str, desc: str) -> str: diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index 43c873ed..0ec411ff 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -33,9 +33,9 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()) -CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4()) -IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()) -IMAGE_TABLE_SYNC = "test_image_table_sync" + str(uuid.uuid4()) +CUSTOM_TABLE = "custom" + str(uuid.uuid4()) +IMAGE_TABLE = "image" + str(uuid.uuid4()) +IMAGE_TABLE_SYNC = "image_sync" + str(uuid.uuid4()) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -387,20 +387,18 @@ async def test_aadd_images_store_uri_only(self, engine_sync, image_uris): results = await afetch(engine_sync, f'SELECT * FROM "{table_name}"') assert len(results) == len(image_uris) for i, result_row in enumerate(results): - assert ( - result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i] - ) + assert result_row[vs._PGVectorStore__vs.content_column] == image_uris[i] uri_embedding = embeddings_service.embed_query(image_uris[i]) image_embedding = image_embedding_service.embed_image([image_uris[i]])[0] actual_embedding = json.loads( - result_row[vs._AlloyDBVectorStore__vs.embedding_column] + result_row[vs._PGVectorStore__vs.embedding_column] ) assert actual_embedding != pytest.approx(uri_embedding) assert actual_embedding == pytest.approx(image_embedding) assert result_row["image_id"] == str(i) assert result_row["source"] == "google.com" assert ( - result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"] + result_row[vs._PGVectorStore__vs.metadata_json_column]["image_uri"] == image_uris[i] ) await aexecute(engine_sync, f'DROP TABLE IF EXISTS "{table_name}"') @@ -475,20 +473,18 @@ async def test_add_images_store_uri_only(self, engine_sync, image_uris): results = await afetch(engine_sync, (f'SELECT * FROM "{table_name}"')) assert len(results) == len(image_uris) for i, result_row in enumerate(results): - assert ( - result_row[vs._AlloyDBVectorStore__vs.content_column] == image_uris[i] - ) + assert result_row[vs._PGVectorStore__vs.content_column] == image_uris[i] uri_embedding = embeddings_service.embed_query(image_uris[i]) image_embedding = image_embedding_service.embed_image([image_uris[i]])[0] actual_embedding = json.loads( - result_row[vs._AlloyDBVectorStore__vs.embedding_column] + result_row[vs._PGVectorStore__vs.embedding_column] ) assert actual_embedding != pytest.approx(uri_embedding) assert actual_embedding == pytest.approx(image_embedding) assert result_row["image_id"] == str(i) assert result_row["source"] == "google.com" assert ( - result_row[vs._AlloyDBVectorStore__vs.metadata_json_column]["image_uri"] + result_row[vs._PGVectorStore__vs.metadata_json_column]["image_uri"] == image_uris[i] ) await vs.adelete(ids) diff --git a/tests/test_vectorstore_embeddings.py b/tests/test_vectorstore_embeddings.py index d65e0c14..4db39633 100644 --- a/tests/test_vectorstore_embeddings.py +++ b/tests/test_vectorstore_embeddings.py @@ -183,7 +183,7 @@ async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") + results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"}) assert results == [Document(page_content="bar", id=ids[1])] async def test_asimilarity_search_score(self, vs): @@ -242,7 +242,7 @@ async def test_amax_marginal_relevance_search(self, vs): results = await vs.amax_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search( - "bar", filter="content = 'boo'" + "bar", filter={"content": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) @@ -342,8 +342,8 @@ def test_similarity_search(self, vs_custom): results = vs_custom.similarity_search("foo", k=1) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'") - assert results == [Document(page_content="bar", id=ids[1])] + results = vs_custom.similarity_search("foo", k=1, filter={"mycontent": "boo"}) + assert results == [Document(page_content="boo", id=ids[3])] def test_similarity_search_score(self, vs_custom): results = vs_custom.similarity_search_with_score("foo") @@ -364,7 +364,7 @@ def test_max_marginal_relevance_search(self, vs_custom): results = vs_custom.max_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = vs_custom.max_marginal_relevance_search( - "bar", filter="mycontent = 'boo'" + "bar", filter={"mycontent": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) diff --git a/tests/test_vectorstore_index.py b/tests/test_vectorstore_index.py index c63c464f..310d3d21 100644 --- a/tests/test_vectorstore_index.py +++ b/tests/test_vectorstore_index.py @@ -34,10 +34,10 @@ ScaNNIndex, ) -DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_TABLE_OMNI = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE = "table" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_ASYNC = "table" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_OMNI = "table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_OMNI = DEFAULT_TABLE_OMNI + DEFAULT_INDEX_NAME_SUFFIX @@ -132,7 +132,7 @@ async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex() vs.apply_vector_index(index) - vs.reindex() + vs.reindex(DEFAULT_INDEX_NAME) vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) vs.drop_vector_index(DEFAULT_INDEX_NAME) @@ -159,6 +159,21 @@ async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") assert is_valid == False + async def test_aapply_vector_index_ivf(self, vs): + index = IVFIndex( + name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN + ) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + index = IVFIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) + @pytest.mark.asyncio(loop_scope="class") class TestAsyncIndex: @@ -256,7 +271,7 @@ async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): index = HNSWIndex() await vs.aapply_vector_index(index) - await vs.areindex() + await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) @@ -284,7 +299,9 @@ async def test_is_valid_index(self, vs): assert is_valid == False async def test_aapply_vector_index_ivf(self, vs): - index = IVFIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + index = IVFIndex( + name=DEFAULT_INDEX_NAME_ASYNC, distance_strategy=DistanceStrategy.EUCLIDEAN + ) await vs.aapply_vector_index(index, concurrently=True) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) index = IVFIndex( @@ -294,10 +311,12 @@ async def test_aapply_vector_index_ivf(self, vs): await vs.aapply_vector_index(index) assert await vs.ais_valid_index("secondindex") await vs.adrop_vector_index("secondindex") - await vs.adrop_vector_index() + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): - index = ScaNNIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + index = ScaNNIndex( + name=DEFAULT_INDEX_NAME_OMNI, distance_strategy=DistanceStrategy.EUCLIDEAN + ) await omni_vs.aset_maintenance_work_mem(index.num_leaves, VECTOR_SIZE) await omni_vs.aapply_vector_index(index, concurrently=True) assert await omni_vs.ais_valid_index(DEFAULT_INDEX_NAME_OMNI) @@ -307,4 +326,4 @@ async def test_aapply_alloydb_scann_index_ScaNN(self, omni_vs): await omni_vs.aapply_vector_index(index) assert await omni_vs.ais_valid_index("secondindex") await omni_vs.adrop_vector_index("secondindex") - await omni_vs.adrop_vector_index() + await omni_vs.adrop_vector_index(DEFAULT_INDEX_NAME_OMNI) diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 875f5840..53b45b3b 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -24,17 +24,20 @@ from sqlalchemy import text from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column +from langchain_google_alloydb_pg.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_google_alloydb_pg.indexes import DistanceStrategy, HNSWQueryOptions -DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") -DEFAULT_TABLE_SYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") -IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()).replace("-", "_") -IMAGE_TABLE_SYNC = "test_image_table_sync" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_FILTER_TABLE_SYNC = "test_table_custom_filter_sync" + str(uuid.uuid4()).replace( - "-", "_" -) +DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +IMAGE_TABLE = "image" + str(uuid.uuid4()).replace("-", "_") +IMAGE_TABLE_SYNC = "image_sync" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE_SYNC = "custom_filter_sync" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -118,6 +121,7 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {IMAGE_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -250,7 +254,7 @@ async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") + results = await vs.asimilarity_search("foo", k=1, filter={"content": "bar"}) assert results == [Document(page_content="bar", id=ids[1])] async def test_asimilarity_search_image(self, image_vs, image_uris): @@ -319,7 +323,7 @@ async def test_amax_marginal_relevance_search(self, vs): results = await vs.amax_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search( - "bar", filter="content = 'boo'" + "bar", filter={"content": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) @@ -365,6 +369,37 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + async def test_asimilarity_hybrid_search(self, vs): + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + class TestVectorStoreSearchSync: @pytest.fixture(scope="module") @@ -401,6 +436,7 @@ async def engine_sync( yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}") + await aexecute(engine, f"DROP TABLE IF EXISTS {IMAGE_TABLE_SYNC}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -501,7 +537,7 @@ def test_similarity_search(self, vs_custom): results = vs_custom.similarity_search("foo", k=1) assert len(results) == 1 assert results == [Document(page_content="foo", id=ids[0])] - results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'") + results = vs_custom.similarity_search("foo", k=1, filter={"mycontent": "bar"}) assert results == [Document(page_content="bar", id=ids[1])] def test_similarity_search_image(self, image_vs, image_uris): @@ -528,7 +564,7 @@ def test_max_marginal_relevance_search(self, vs_custom): results = vs_custom.max_marginal_relevance_search("bar") assert results[0] == Document(page_content="bar", id=ids[1]) results = vs_custom.max_marginal_relevance_search( - "bar", filter="mycontent = 'boo'" + "bar", filter={"mycontent": "boo"} ) assert results[0] == Document(page_content="boo", id=ids[3]) @@ -573,3 +609,27 @@ def test_metadata_filter_negative_tests(self, vs_custom_filter_sync, test_filter docs = vs_custom_filter_sync.similarity_search( "meow", k=5, filter=test_filter ) + + def test_similarity_hybrid_search(self, vs_custom): + results = vs_custom.similarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = vs_custom.similarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + results = vs_custom.similarity_search( + "foo", + k=1, + filter={"mycontent": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] From 449d7a1c70f247184da083fa3330946a91320d41 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 13 Aug 2025 15:26:53 +0000 Subject: [PATCH 11/12] re-expose hybrid search configs from init --- src/langchain_google_alloydb_pg/__init__.py | 8 ++++++++ .../hybrid_search_config.py | 19 ------------------- .../vectorstore.py | 2 +- tests/test_async_vectorstore_index.py | 3 +-- tests/test_async_vectorstore_search.py | 7 ++++--- 5 files changed, 14 insertions(+), 25 deletions(-) delete mode 100644 src/langchain_google_alloydb_pg/hybrid_search_config.py diff --git a/src/langchain_google_alloydb_pg/__init__.py b/src/langchain_google_alloydb_pg/__init__.py index 20ceb71b..1d0ddf44 100644 --- a/src/langchain_google_alloydb_pg/__init__.py +++ b/src/langchain_google_alloydb_pg/__init__.py @@ -13,6 +13,11 @@ # limitations under the License. from langchain_postgres import Column +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from .chat_message_history import AlloyDBChatMessageHistory from .checkpoint import AlloyDBSaver @@ -34,5 +39,8 @@ "AlloyDBModelManager", "AlloyDBModel", "AlloyDBSaver", + "HybridSearchConfig", + "reciprocal_rank_fusion", + "weighted_sum_ranking", "__version__", ] diff --git a/src/langchain_google_alloydb_pg/hybrid_search_config.py b/src/langchain_google_alloydb_pg/hybrid_search_config.py deleted file mode 100644 index 9e024cc1..00000000 --- a/src/langchain_google_alloydb_pg/hybrid_search_config.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from langchain_postgres.v2.hybrid_search_config import ( - HybridSearchConfig, - reciprocal_rank_fusion, - weighted_sum_ranking, -) diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index f018aaa9..09fba583 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -21,7 +21,7 @@ from langchain_core.embeddings import Embeddings from langchain_postgres import PGVectorStore -from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig +from langchain_google_alloydb_pg import HybridSearchConfig from langchain_google_alloydb_pg.indexes import ( DEFAULT_DISTANCE_STRATEGY, DistanceStrategy, diff --git a/tests/test_async_vectorstore_index.py b/tests/test_async_vectorstore_index.py index c5953335..d1befd05 100644 --- a/tests/test_async_vectorstore_index.py +++ b/tests/test_async_vectorstore_index.py @@ -23,9 +23,8 @@ from langchain_core.embeddings import DeterministicFakeEmbedding from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine +from langchain_google_alloydb_pg import AlloyDBEngine, HybridSearchConfig from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore -from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig from langchain_google_alloydb_pg.indexes import ( DEFAULT_INDEX_NAME_SUFFIX, DistanceStrategy, diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index 74606abd..aa90a4b2 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -23,13 +23,14 @@ from PIL import Image from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, Column -from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore -from langchain_google_alloydb_pg.hybrid_search_config import ( +from langchain_google_alloydb_pg import ( + AlloyDBEngine, + Column, HybridSearchConfig, reciprocal_rank_fusion, weighted_sum_ranking, ) +from langchain_google_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore from langchain_google_alloydb_pg.indexes import ( DistanceStrategy, HNSWQueryOptions, From 23887e640d1370eaf78e72110b0bc3a1dc903416 Mon Sep 17 00:00:00 2001 From: Disha Prakash Date: Wed, 13 Aug 2025 15:32:02 +0000 Subject: [PATCH 12/12] re-expose hybrid search configs from init --- tests/test_engine.py | 3 +-- tests/test_vectorstore_search.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 7f9ec509..8d22a7ef 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -27,8 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool -from langchain_google_alloydb_pg import AlloyDBEngine, Column -from langchain_google_alloydb_pg.hybrid_search_config import HybridSearchConfig +from langchain_google_alloydb_pg import AlloyDBEngine, Column, HybridSearchConfig DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 53b45b3b..ac7a2867 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -23,8 +23,10 @@ from PIL import Image from sqlalchemy import text -from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBVectorStore, Column -from langchain_google_alloydb_pg.hybrid_search_config import ( +from langchain_google_alloydb_pg import ( + AlloyDBEngine, + AlloyDBVectorStore, + Column, HybridSearchConfig, reciprocal_rank_fusion, weighted_sum_ranking,