Skip to content

Commit cd80093

Browse files
committed
chore: Refactor AlloyDBEngine to depend on PGEngine
1 parent 9f1715c commit cd80093

File tree

8 files changed

+52
-416
lines changed

8 files changed

+52
-416
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ dependencies = [
1616
"numpy>=1.24.4, <=2.2.6; python_version == '3.10'",
1717
"numpy>=1.24.4, <=2.0.2; python_version <= '3.9'",
1818
"pgvector>=0.2.5, <1.0.0",
19-
"SQLAlchemy[asyncio]>=2.0.25, <3.0.0"
19+
"SQLAlchemy[asyncio]>=2.0.25, <3.0.0",
20+
"langchain-postgres>=0.0.15",
2021
]
2122

2223
classifiers = [

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ numpy==2.0.2; python_version <= "3.9"
77
pgvector==0.4.1
88
SQLAlchemy[asyncio]==2.0.41
99
langgraph==0.5.0
10+
langchain-postgres==0.0.15

src/langchain_google_alloydb_pg/async_vectorstore.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,18 @@
2828
from langchain_core.documents import Document
2929
from langchain_core.embeddings import Embeddings
3030
from langchain_core.vectorstores import VectorStore, utils
31-
from sqlalchemy import RowMapping, text
32-
from sqlalchemy.ext.asyncio import AsyncEngine
33-
34-
from .engine import AlloyDBEngine
35-
from .indexes import (
31+
from langchain_postgres.v2.indexes import (
3632
DEFAULT_DISTANCE_STRATEGY,
3733
DEFAULT_INDEX_NAME_SUFFIX,
3834
BaseIndex,
3935
DistanceStrategy,
4036
ExactNearestNeighbor,
4137
QueryOptions,
42-
ScaNNIndex,
4338
)
39+
from sqlalchemy import RowMapping, text
40+
from sqlalchemy.ext.asyncio import AsyncEngine
41+
42+
from .engine import AlloyDBEngine
4443

4544
COMPARISONS_TO_NATIVE = {
4645
"$eq": "=",

src/langchain_google_alloydb_pg/engine.py

Lines changed: 20 additions & 253 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515

1616
import asyncio
1717
from concurrent.futures import Future
18-
from dataclasses import dataclass
1918
from threading import Thread
2019
from typing import (
2120
TYPE_CHECKING,
2221
Any,
23-
Awaitable,
2422
Mapping,
2523
Optional,
2624
TypeVar,
@@ -35,10 +33,11 @@
3533
IPTypes,
3634
RefreshStrategy,
3735
)
38-
from sqlalchemy import MetaData, RowMapping, Table, text
36+
from langchain_postgres import Column, PGEngine
37+
from sqlalchemy import MetaData, Table, text
3938
from sqlalchemy.engine import URL
4039
from sqlalchemy.exc import InvalidRequestError
41-
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
40+
from sqlalchemy.ext.asyncio import create_async_engine
4241

4342
from .version import __version__
4443

@@ -90,60 +89,10 @@ async def _get_iam_principal_email(
9089
return email.replace(".gserviceaccount.com", "")
9190

9291

93-
@dataclass
94-
class Column:
95-
name: str
96-
data_type: str
97-
nullable: bool = True
98-
99-
def __post_init__(self) -> None:
100-
"""Check if initialization parameters are valid.
101-
102-
Raises:
103-
ValueError: If Column name is not string.
104-
ValueError: If data_type is not type string.
105-
"""
106-
107-
if not isinstance(self.name, str):
108-
raise ValueError("Column name must be type string")
109-
if not isinstance(self.data_type, str):
110-
raise ValueError("Column data_type must be type string")
111-
112-
113-
class AlloyDBEngine:
92+
class AlloyDBEngine(PGEngine):
11493
"""A class for managing connections to an AlloyDB database."""
11594

11695
_connector: Optional[AsyncConnector] = None
117-
_default_loop: Optional[asyncio.AbstractEventLoop] = None
118-
_default_thread: Optional[Thread] = None
119-
__create_key = object()
120-
121-
def __init__(
122-
self,
123-
key: object,
124-
pool: AsyncEngine,
125-
loop: Optional[asyncio.AbstractEventLoop],
126-
thread: Optional[Thread],
127-
) -> None:
128-
"""AlloyDBEngine constructor.
129-
130-
Args:
131-
key (object): Prevent direct constructor usage.
132-
engine (AsyncEngine): Async engine connection pool.
133-
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
134-
thread (Optional[Thread]): Thread used to create the engine async.
135-
136-
Raises:
137-
Exception: If the constructor is called directly by the user.
138-
"""
139-
140-
if key != AlloyDBEngine.__create_key:
141-
raise Exception(
142-
"Only create class through 'create' or 'create_sync' methods!"
143-
)
144-
self._pool = pool
145-
self._loop = loop
146-
self._thread = thread
14796

14897
@classmethod
14998
def __start_background_loop(
@@ -317,7 +266,7 @@ async def getconn() -> asyncpg.Connection:
317266
async_creator=getconn,
318267
**engine_args,
319268
)
320-
return cls(cls.__create_key, engine, loop, thread)
269+
return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore
321270

322271
@classmethod
323272
async def afrom_instance(
@@ -367,13 +316,21 @@ async def afrom_instance(
367316
return await asyncio.wrap_future(future)
368317

369318
@classmethod
370-
def from_engine(
371-
cls: type[AlloyDBEngine],
372-
engine: AsyncEngine,
373-
loop: Optional[asyncio.AbstractEventLoop] = None,
319+
def from_connection_string(
320+
cls,
321+
url: str | URL,
322+
**kwargs: Any,
374323
) -> AlloyDBEngine:
375-
"""Create an AlloyDBEngine instance from an AsyncEngine."""
376-
return cls(cls.__create_key, engine, loop, None)
324+
"""Create an AlloyDBEngine instance from a connection string.
325+
Args:
326+
url (str | URL): The URL used to connect to the database. Any extra keyword arguments are forwarded to from_engine_args.
327+
Raises:
328+
ValueError: If not all database url arguments are specified
329+
Returns:
330+
AlloyDBEngine
331+
"""
332+
333+
return AlloyDBEngine.from_engine_args(url=url, **kwargs)
377334

378335
@classmethod
379336
def from_engine_args(
@@ -408,197 +365,7 @@ def from_engine_args(
408365
raise ValueError("Driver must be type 'postgresql+asyncpg'")
409366

410367
engine = create_async_engine(url, **kwargs)
411-
return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread)
412-
413-
async def _run_as_async(self, coro: Awaitable[T]) -> T:
414-
"""Run an async coroutine asynchronously"""
415-
# If a loop has not been provided, attempt to run in current thread
416-
if not self._loop:
417-
return await coro
418-
# Otherwise, run in the background thread
419-
return await asyncio.wrap_future(
420-
asyncio.run_coroutine_threadsafe(coro, self._loop)
421-
)
422-
423-
def _run_as_sync(self, coro: Awaitable[T]) -> T:
424-
"""Run an async coroutine synchronously"""
425-
if not self._loop:
426-
raise Exception(
427-
"Engine was initialized without a background loop and cannot call sync methods."
428-
)
429-
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
430-
431-
async def close(self) -> None:
432-
"""Dispose of connection pool"""
433-
await self._run_as_async(self._pool.dispose())
434-
435-
async def _ainit_vectorstore_table(
436-
self,
437-
table_name: str,
438-
vector_size: int,
439-
schema_name: str = "public",
440-
content_column: str = "content",
441-
embedding_column: str = "embedding",
442-
metadata_columns: list[Column] = [],
443-
metadata_json_column: str = "langchain_metadata",
444-
id_column: Union[str, Column] = "langchain_id",
445-
overwrite_existing: bool = False,
446-
store_metadata: bool = True,
447-
) -> None:
448-
"""
449-
Create a table for saving of vectors to be used with AlloyDBVectorStore.
450-
451-
Args:
452-
table_name (str): The Postgres database table name.
453-
vector_size (int): Vector size for the embedding model to be used.
454-
schema_name (str): The schema name.
455-
Default: "public".
456-
content_column (str): Name of the column to store document content.
457-
Default: "page_content".
458-
embedding_column (str) : Name of the column to store vector embeddings.
459-
Default: "embedding".
460-
metadata_columns (list[Column]): A list of Columns to create for custom
461-
metadata. Default: []. Optional.
462-
metadata_json_column (str): The column to store extra metadata in JSON format.
463-
Default: "langchain_metadata". Optional.
464-
id_column (Union[str, Column]) : Column to store ids.
465-
Default: "langchain_id" column name with data type UUID. Optional.
466-
overwrite_existing (bool): Whether to drop existing table. Default: False.
467-
store_metadata (bool): Whether to store metadata in the table.
468-
Default: True.
469-
470-
Raises:
471-
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
472-
:class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a postgreSQL data type.
473-
"""
474-
async with self._pool.connect() as conn:
475-
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
476-
await conn.commit()
477-
478-
if overwrite_existing:
479-
async with self._pool.connect() as conn:
480-
await conn.execute(
481-
text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
482-
)
483-
await conn.commit()
484-
485-
id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type
486-
id_column_name = id_column if isinstance(id_column, str) else id_column.name
487-
488-
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
489-
"{id_column_name}" {id_data_type} PRIMARY KEY,
490-
"{content_column}" TEXT NOT NULL,
491-
"{embedding_column}" vector({vector_size}) NOT NULL"""
492-
for column in metadata_columns:
493-
nullable = "NOT NULL" if not column.nullable else ""
494-
query += f',\n"{column.name}" {column.data_type} {nullable}'
495-
if store_metadata:
496-
query += f""",\n"{metadata_json_column}" JSON"""
497-
query += "\n);"
498-
499-
async with self._pool.connect() as conn:
500-
await conn.execute(text(query))
501-
await conn.commit()
502-
503-
async def ainit_vectorstore_table(
504-
self,
505-
table_name: str,
506-
vector_size: int,
507-
schema_name: str = "public",
508-
content_column: str = "content",
509-
embedding_column: str = "embedding",
510-
metadata_columns: list[Column] = [],
511-
metadata_json_column: str = "langchain_metadata",
512-
id_column: Union[str, Column] = "langchain_id",
513-
overwrite_existing: bool = False,
514-
store_metadata: bool = True,
515-
) -> None:
516-
"""
517-
Create a table for saving of vectors to be used with AlloyDBVectorStore.
518-
519-
Args:
520-
table_name (str): The database table name.
521-
vector_size (int): Vector size for the embedding model to be used.
522-
schema_name (str): The schema name.
523-
Default: "public".
524-
content_column (str): Name of the column to store document content.
525-
Default: "page_content".
526-
embedding_column (str) : Name of the column to store vector embeddings.
527-
Default: "embedding".
528-
metadata_columns (list[Column]): A list of Columns to create for custom
529-
metadata. Default: []. Optional.
530-
metadata_json_column (str): The column to store extra metadata in JSON format.
531-
Default: "langchain_metadata". Optional.
532-
id_column (Union[str, Column]) : Column to store ids.
533-
Default: "langchain_id" column name with data type UUID. Optional.
534-
overwrite_existing (bool): Whether to drop existing table. Default: False.
535-
store_metadata (bool): Whether to store metadata in the table.
536-
Default: True.
537-
"""
538-
await self._run_as_async(
539-
self._ainit_vectorstore_table(
540-
table_name,
541-
vector_size,
542-
schema_name,
543-
content_column,
544-
embedding_column,
545-
metadata_columns,
546-
metadata_json_column,
547-
id_column,
548-
overwrite_existing,
549-
store_metadata,
550-
)
551-
)
552-
553-
def init_vectorstore_table(
554-
self,
555-
table_name: str,
556-
vector_size: int,
557-
schema_name: str = "public",
558-
content_column: str = "content",
559-
embedding_column: str = "embedding",
560-
metadata_columns: list[Column] = [],
561-
metadata_json_column: str = "langchain_metadata",
562-
id_column: Union[str, Column] = "langchain_id",
563-
overwrite_existing: bool = False,
564-
store_metadata: bool = True,
565-
) -> None:
566-
"""
567-
Create a table for saving of vectors to be used with AlloyDBVectorStore.
568-
569-
Args:
570-
table_name (str): The database table name.
571-
vector_size (int): Vector size for the embedding model to be used.
572-
schema_name (str): The schema name.
573-
Default: "public".
574-
content_column (str): Name of the column to store document content.
575-
Default: "page_content".
576-
embedding_column (str) : Name of the column to store vector embeddings.
577-
Default: "embedding".
578-
metadata_columns (list[Column]): A list of Columns to create for custom
579-
metadata. Default: []. Optional.
580-
metadata_json_column (str): The column to store extra metadata in JSON format.
581-
Default: "langchain_metadata". Optional.
582-
id_column (Union[str, Column]) : Column to store ids.
583-
Default: "langchain_id" column name with data type UUID. Optional.
584-
overwrite_existing (bool): Whether to drop existing table. Default: False.
585-
store_metadata (bool): Whether to store metadata in the table.
586-
Default: True.
587-
"""
588-
self._run_as_sync(
589-
self._ainit_vectorstore_table(
590-
table_name,
591-
vector_size,
592-
schema_name,
593-
content_column,
594-
embedding_column,
595-
metadata_columns,
596-
metadata_json_column,
597-
id_column,
598-
overwrite_existing,
599-
store_metadata,
600-
)
601-
)
368+
return cls(PGEngine._PGEngine__create_key, engine, cls._default_loop, cls._default_thread) # type: ignore
602369

603370
async def _ainit_chat_history_table(
604371
self, table_name: str, schema_name: str = "public"

0 commit comments

Comments
 (0)