Skip to content

refactor: Refactor AlloyDBEngine to depend on PGEngine #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions docs/vector_store.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@
"source": [
"from langchain_google_alloydb_pg import Column\n",
"\n",
"\n",
"# Set table name\n",
"TABLE_NAME = \"vectorstore_custom\"\n",
"# SCHEMA_NAME = \"my_schema\"\n",
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ dependencies = [
"numpy>=1.24.4, <3.0.0; python_version >= '3.11'",
"numpy>=1.24.4, <=2.2.6; python_version == '3.10'",
"numpy>=1.24.4, <=2.0.2; python_version <= '3.9'",
"pgvector>=0.2.5, <1.0.0",
"SQLAlchemy[asyncio]>=2.0.25, <3.0.0"
"pgvector>=0.2.5, <0.4.0",
"SQLAlchemy[asyncio]>=2.0.25, <3.0.0",
"langchain-postgres>=0.0.15",
]

classifiers = [
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ langchain-core==0.3.67
numpy==2.3.1; python_version >= "3.11"
numpy==2.2.6; python_version == "3.10"
numpy==2.0.2; python_version <= "3.9"
pgvector==0.4.1
SQLAlchemy[asyncio]==2.0.41
langgraph==0.6.0
langchain-postgres==0.0.15
2 changes: 1 addition & 1 deletion samples/langchain_quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@
},
"outputs": [],
"source": [
"from langchain_google_alloydb_pg import AlloyDBEngine, Column, AlloyDBLoader\n",
"from langchain_google_alloydb_pg import AlloyDBEngine, AlloyDBLoader, Column\n",
"\n",
"engine = AlloyDBEngine.from_instance(\n",
" project_id=project_id,\n",
Expand Down
4 changes: 3 additions & 1 deletion src/langchain_google_alloydb_pg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_postgres import Column

from .chat_message_history import AlloyDBChatMessageHistory
from .checkpoint import AlloyDBSaver
from .embeddings import AlloyDBEmbeddings
from .engine import AlloyDBEngine, Column
from .engine import AlloyDBEngine
from .loader import AlloyDBDocumentSaver, AlloyDBLoader
from .model_manager import AlloyDBModel, AlloyDBModelManager
from .vectorstore import AlloyDBVectorStore
Expand Down
6 changes: 3 additions & 3 deletions src/langchain_google_alloydb_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@
from sqlalchemy import RowMapping, text
from sqlalchemy.ext.asyncio import AsyncEngine

from .engine import AlloyDBEngine
from .indexes import (
from langchain_google_alloydb_pg.indexes import (
DEFAULT_DISTANCE_STRATEGY,
DEFAULT_INDEX_NAME_SUFFIX,
BaseIndex,
DistanceStrategy,
ExactNearestNeighbor,
QueryOptions,
ScaNNIndex,
)

from .engine import AlloyDBEngine

COMPARISONS_TO_NATIVE = {
"$eq": "=",
"$ne": "!=",
Expand Down
273 changes: 20 additions & 253 deletions src/langchain_google_alloydb_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@

import asyncio
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Mapping,
Optional,
TypeVar,
Expand All @@ -35,10 +33,11 @@
IPTypes,
RefreshStrategy,
)
from sqlalchemy import MetaData, RowMapping, Table, text
from langchain_postgres import Column, PGEngine
from sqlalchemy import MetaData, Table, text
from sqlalchemy.engine import URL
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio import create_async_engine

from .version import __version__

Expand Down Expand Up @@ -90,60 +89,10 @@ async def _get_iam_principal_email(
return email.replace(".gserviceaccount.com", "")


@dataclass
class Column:
name: str
data_type: str
nullable: bool = True

def __post_init__(self) -> None:
"""Check if initialization parameters are valid.

Raises:
ValueError: If Column name is not string.
ValueError: If data_type is not type string.
"""

if not isinstance(self.name, str):
raise ValueError("Column name must be type string")
if not isinstance(self.data_type, str):
raise ValueError("Column data_type must be type string")


class AlloyDBEngine:
class AlloyDBEngine(PGEngine):
"""A class for managing connections to an AlloyDB database."""

_connector: Optional[AsyncConnector] = None
_default_loop: Optional[asyncio.AbstractEventLoop] = None
_default_thread: Optional[Thread] = None
__create_key = object()

def __init__(
self,
key: object,
pool: AsyncEngine,
loop: Optional[asyncio.AbstractEventLoop],
thread: Optional[Thread],
) -> None:
"""AlloyDBEngine constructor.

Args:
key (object): Prevent direct constructor usage.
pool (AsyncEngine): Async engine connection pool.
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread]): Thread used to create the engine async.

Raises:
Exception: If the constructor is called directly by the user.
"""

if key != AlloyDBEngine.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods!"
)
self._pool = pool
self._loop = loop
self._thread = thread

@classmethod
def __start_background_loop(
Expand Down Expand Up @@ -317,7 +266,7 @@ async def getconn() -> asyncpg.Connection:
async_creator=getconn,
**engine_args,
)
return cls(cls.__create_key, engine, loop, thread)
return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore

@classmethod
async def afrom_instance(
Expand Down Expand Up @@ -367,13 +316,21 @@ async def afrom_instance(
return await asyncio.wrap_future(future)

@classmethod
def from_engine(
cls: type[AlloyDBEngine],
engine: AsyncEngine,
loop: Optional[asyncio.AbstractEventLoop] = None,
def from_connection_string(
cls,
url: str | URL,
**kwargs: Any,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine instance from an AsyncEngine."""
return cls(cls.__create_key, engine, loop, None)
"""Create an AlloyDBEngine instance from arguments
Args:
url (str | URL): The URL used to connect to a database. Use url or set other arguments.
Raises:
ValueError: If not all database url arguments are specified
Returns:
AlloyDBEngine
"""

return AlloyDBEngine.from_engine_args(url=url, **kwargs)

@classmethod
def from_engine_args(
Expand Down Expand Up @@ -408,197 +365,7 @@ def from_engine_args(
raise ValueError("Driver must be type 'postgresql+asyncpg'")

engine = create_async_engine(url, **kwargs)
return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)

async def _run_as_async(self, coro: Awaitable[T]) -> T:
"""Run an async coroutine asynchronously"""
# If a loop has not been provided, attempt to run in current thread
if not self._loop:
return await coro
# Otherwise, run in the background thread
return await asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(coro, self._loop)
)

def _run_as_sync(self, coro: Awaitable[T]) -> T:
"""Run an async coroutine synchronously"""
if not self._loop:
raise Exception(
"Engine was initialized without a background loop and cannot call sync methods."
)
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

async def close(self) -> None:
"""Dispose of connection pool"""
await self._run_as_async(self._pool.dispose())

async def _ainit_vectorstore_table(
self,
table_name: str,
vector_size: int,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with AlloyDBVectorStore.

Args:
table_name (str): The Postgres database table name.
vector_size (int): Vector size for the embedding model to be used.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Default: "langchain_id" column name with data type UUID. Optional.
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.

Raises:
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
:class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a PostgreSQL data type.
"""
async with self._pool.connect() as conn:
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
await conn.commit()

if overwrite_existing:
async with self._pool.connect() as conn:
await conn.execute(
text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
)
await conn.commit()

id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type
id_column_name = id_column if isinstance(id_column, str) else id_column.name

query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
"{id_column_name}" {id_data_type} PRIMARY KEY,
"{content_column}" TEXT NOT NULL,
"{embedding_column}" vector({vector_size}) NOT NULL"""
for column in metadata_columns:
nullable = "NOT NULL" if not column.nullable else ""
query += f',\n"{column.name}" {column.data_type} {nullable}'
if store_metadata:
query += f""",\n"{metadata_json_column}" JSON"""
query += "\n);"

async with self._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()

async def ainit_vectorstore_table(
self,
table_name: str,
vector_size: int,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with AlloyDBVectorStore.

Args:
table_name (str): The database table name.
vector_size (int): Vector size for the embedding model to be used.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Default: "langchain_id" column name with data type UUID. Optional.
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.
"""
await self._run_as_async(
self._ainit_vectorstore_table(
table_name,
vector_size,
schema_name,
content_column,
embedding_column,
metadata_columns,
metadata_json_column,
id_column,
overwrite_existing,
store_metadata,
)
)

def init_vectorstore_table(
self,
table_name: str,
vector_size: int,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: list[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: Union[str, Column] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with AlloyDBVectorStore.

Args:
table_name (str): The database table name.
vector_size (int): Vector size for the embedding model to be used.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Default: "content".
embedding_column (str) : Name of the column to store vector embeddings.
Default: "embedding".
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: "langchain_metadata". Optional.
id_column (Union[str, Column]) : Column to store ids.
Default: "langchain_id" column name with data type UUID. Optional.
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.
"""
self._run_as_sync(
self._ainit_vectorstore_table(
table_name,
vector_size,
schema_name,
content_column,
embedding_column,
metadata_columns,
metadata_json_column,
id_column,
overwrite_existing,
store_metadata,
)
)
return cls(PGEngine._PGEngine__create_key, engine, cls._default_loop, cls._default_thread) # type: ignore

async def _ainit_chat_history_table(
self, table_name: str, schema_name: str = "public"
Expand Down
Loading