|
15 | 15 |
|
16 | 16 | import asyncio
|
17 | 17 | from concurrent.futures import Future
|
18 |
| -from dataclasses import dataclass |
19 | 18 | from threading import Thread
|
20 | 19 | from typing import (
|
21 | 20 | TYPE_CHECKING,
|
22 | 21 | Any,
|
23 |
| - Awaitable, |
24 | 22 | Mapping,
|
25 | 23 | Optional,
|
26 | 24 | TypeVar,
|
|
35 | 33 | IPTypes,
|
36 | 34 | RefreshStrategy,
|
37 | 35 | )
|
38 |
| -from sqlalchemy import MetaData, RowMapping, Table, text |
| 36 | +from langchain_postgres import Column, PGEngine |
| 37 | +from sqlalchemy import MetaData, Table, text |
39 | 38 | from sqlalchemy.engine import URL
|
40 | 39 | from sqlalchemy.exc import InvalidRequestError
|
41 |
| -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine |
| 40 | +from sqlalchemy.ext.asyncio import create_async_engine |
42 | 41 |
|
43 | 42 | from .version import __version__
|
44 | 43 |
|
@@ -90,60 +89,10 @@ async def _get_iam_principal_email(
|
90 | 89 | return email.replace(".gserviceaccount.com", "")
|
91 | 90 |
|
92 | 91 |
|
93 |
| -@dataclass |
94 |
| -class Column: |
95 |
| - name: str |
96 |
| - data_type: str |
97 |
| - nullable: bool = True |
98 |
| - |
99 |
| - def __post_init__(self) -> None: |
100 |
| - """Check if initialization parameters are valid. |
101 |
| -
|
102 |
| - Raises: |
103 |
| - ValueError: If Column name is not string. |
104 |
| - ValueError: If data_type is not type string. |
105 |
| - """ |
106 |
| - |
107 |
| - if not isinstance(self.name, str): |
108 |
| - raise ValueError("Column name must be type string") |
109 |
| - if not isinstance(self.data_type, str): |
110 |
| - raise ValueError("Column data_type must be type string") |
111 |
| - |
112 |
| - |
113 |
| -class AlloyDBEngine: |
| 92 | +class AlloyDBEngine(PGEngine): |
114 | 93 | """A class for managing connections to a AlloyDB database."""
|
115 | 94 |
|
116 | 95 | _connector: Optional[AsyncConnector] = None
|
117 |
| - _default_loop: Optional[asyncio.AbstractEventLoop] = None |
118 |
| - _default_thread: Optional[Thread] = None |
119 |
| - __create_key = object() |
120 |
| - |
121 |
| - def __init__( |
122 |
| - self, |
123 |
| - key: object, |
124 |
| - pool: AsyncEngine, |
125 |
| - loop: Optional[asyncio.AbstractEventLoop], |
126 |
| - thread: Optional[Thread], |
127 |
| - ) -> None: |
128 |
| - """AlloyDBEngine constructor. |
129 |
| -
|
130 |
| - Args: |
131 |
| - key (object): Prevent direct constructor usage. |
132 |
| - engine (AsyncEngine): Async engine connection pool. |
133 |
| - loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine. |
134 |
| - thread (Optional[Thread]): Thread used to create the engine async. |
135 |
| -
|
136 |
| - Raises: |
137 |
| - Exception: If the constructor is called directly by the user. |
138 |
| - """ |
139 |
| - |
140 |
| - if key != AlloyDBEngine.__create_key: |
141 |
| - raise Exception( |
142 |
| - "Only create class through 'create' or 'create_sync' methods!" |
143 |
| - ) |
144 |
| - self._pool = pool |
145 |
| - self._loop = loop |
146 |
| - self._thread = thread |
147 | 96 |
|
148 | 97 | @classmethod
|
149 | 98 | def __start_background_loop(
|
@@ -317,7 +266,7 @@ async def getconn() -> asyncpg.Connection:
|
317 | 266 | async_creator=getconn,
|
318 | 267 | **engine_args,
|
319 | 268 | )
|
320 |
| - return cls(cls.__create_key, engine, loop, thread) |
| 269 | + return cls(PGEngine._PGEngine__create_key, engine, loop, thread) # type: ignore |
321 | 270 |
|
322 | 271 | @classmethod
|
323 | 272 | async def afrom_instance(
|
@@ -367,13 +316,21 @@ async def afrom_instance(
|
367 | 316 | return await asyncio.wrap_future(future)
|
368 | 317 |
|
369 | 318 | @classmethod
|
370 |
| - def from_engine( |
371 |
| - cls: type[AlloyDBEngine], |
372 |
| - engine: AsyncEngine, |
373 |
| - loop: Optional[asyncio.AbstractEventLoop] = None, |
| 319 | + def from_connection_string( |
| 320 | + cls, |
| 321 | + url: str | URL, |
| 322 | + **kwargs: Any, |
374 | 323 | ) -> AlloyDBEngine:
|
375 |
| - """Create an AlloyDBEngine instance from an AsyncEngine.""" |
376 |
| - return cls(cls.__create_key, engine, loop, None) |
| 324 | + """Create an AlloyDBEngine instance from arguments |
| 325 | + Args: |
| 326 | + url (Optional[str]): the URL used to connect to a database. Use url or set other arguments. |
| 327 | + Raises: |
| 328 | + ValueError: If not all database url arguments are specified |
| 329 | + Returns: |
| 330 | + AlloyDBEngine |
| 331 | + """ |
| 332 | + |
| 333 | + return AlloyDBEngine.from_engine_args(url=url, **kwargs) |
377 | 334 |
|
378 | 335 | @classmethod
|
379 | 336 | def from_engine_args(
|
@@ -408,197 +365,7 @@ def from_engine_args(
|
408 | 365 | raise ValueError("Driver must be type 'postgresql+asyncpg'")
|
409 | 366 |
|
410 | 367 | engine = create_async_engine(url, **kwargs)
|
411 |
| - return cls(cls.__create_key, engine, cls._default_loop, cls._default_thread) |
412 |
| - |
413 |
| - async def _run_as_async(self, coro: Awaitable[T]) -> T: |
414 |
| - """Run an async coroutine asynchronously""" |
415 |
| - # If a loop has not been provided, attempt to run in current thread |
416 |
| - if not self._loop: |
417 |
| - return await coro |
418 |
| - # Otherwise, run in the background thread |
419 |
| - return await asyncio.wrap_future( |
420 |
| - asyncio.run_coroutine_threadsafe(coro, self._loop) |
421 |
| - ) |
422 |
| - |
423 |
| - def _run_as_sync(self, coro: Awaitable[T]) -> T: |
424 |
| - """Run an async coroutine synchronously""" |
425 |
| - if not self._loop: |
426 |
| - raise Exception( |
427 |
| - "Engine was initialized without a background loop and cannot call sync methods." |
428 |
| - ) |
429 |
| - return asyncio.run_coroutine_threadsafe(coro, self._loop).result() |
430 |
| - |
431 |
| - async def close(self) -> None: |
432 |
| - """Dispose of connection pool""" |
433 |
| - await self._run_as_async(self._pool.dispose()) |
434 |
| - |
435 |
| - async def _ainit_vectorstore_table( |
436 |
| - self, |
437 |
| - table_name: str, |
438 |
| - vector_size: int, |
439 |
| - schema_name: str = "public", |
440 |
| - content_column: str = "content", |
441 |
| - embedding_column: str = "embedding", |
442 |
| - metadata_columns: list[Column] = [], |
443 |
| - metadata_json_column: str = "langchain_metadata", |
444 |
| - id_column: Union[str, Column] = "langchain_id", |
445 |
| - overwrite_existing: bool = False, |
446 |
| - store_metadata: bool = True, |
447 |
| - ) -> None: |
448 |
| - """ |
449 |
| - Create a table for saving of vectors to be used with AlloyDBVectorStore. |
450 |
| -
|
451 |
| - Args: |
452 |
| - table_name (str): The Postgres database table name. |
453 |
| - vector_size (int): Vector size for the embedding model to be used. |
454 |
| - schema_name (str): The schema name. |
455 |
| - Default: "public". |
456 |
| - content_column (str): Name of the column to store document content. |
457 |
| - Default: "page_content". |
458 |
| - embedding_column (str) : Name of the column to store vector embeddings. |
459 |
| - Default: "embedding". |
460 |
| - metadata_columns (list[Column]): A list of Columns to create for custom |
461 |
| - metadata. Default: []. Optional. |
462 |
| - metadata_json_column (str): The column to store extra metadata in JSON format. |
463 |
| - Default: "langchain_metadata". Optional. |
464 |
| - id_column (Union[str, Column]) : Column to store ids. |
465 |
| - Default: "langchain_id" column name with data type UUID. Optional. |
466 |
| - overwrite_existing (bool): Whether to drop existing table. Default: False. |
467 |
| - store_metadata (bool): Whether to store metadata in the table. |
468 |
| - Default: True. |
469 |
| -
|
470 |
| - Raises: |
471 |
| - :class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists. |
472 |
| - :class:`UndefinedObjectError <asyncpg.exceptions.UndefinedObjectError>`: if the data type of the id column is not a postgreSQL data type. |
473 |
| - """ |
474 |
| - async with self._pool.connect() as conn: |
475 |
| - await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) |
476 |
| - await conn.commit() |
477 |
| - |
478 |
| - if overwrite_existing: |
479 |
| - async with self._pool.connect() as conn: |
480 |
| - await conn.execute( |
481 |
| - text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') |
482 |
| - ) |
483 |
| - await conn.commit() |
484 |
| - |
485 |
| - id_data_type = "UUID" if isinstance(id_column, str) else id_column.data_type |
486 |
| - id_column_name = id_column if isinstance(id_column, str) else id_column.name |
487 |
| - |
488 |
| - query = f"""CREATE TABLE "{schema_name}"."{table_name}"( |
489 |
| - "{id_column_name}" {id_data_type} PRIMARY KEY, |
490 |
| - "{content_column}" TEXT NOT NULL, |
491 |
| - "{embedding_column}" vector({vector_size}) NOT NULL""" |
492 |
| - for column in metadata_columns: |
493 |
| - nullable = "NOT NULL" if not column.nullable else "" |
494 |
| - query += f',\n"{column.name}" {column.data_type} {nullable}' |
495 |
| - if store_metadata: |
496 |
| - query += f""",\n"{metadata_json_column}" JSON""" |
497 |
| - query += "\n);" |
498 |
| - |
499 |
| - async with self._pool.connect() as conn: |
500 |
| - await conn.execute(text(query)) |
501 |
| - await conn.commit() |
502 |
| - |
503 |
| - async def ainit_vectorstore_table( |
504 |
| - self, |
505 |
| - table_name: str, |
506 |
| - vector_size: int, |
507 |
| - schema_name: str = "public", |
508 |
| - content_column: str = "content", |
509 |
| - embedding_column: str = "embedding", |
510 |
| - metadata_columns: list[Column] = [], |
511 |
| - metadata_json_column: str = "langchain_metadata", |
512 |
| - id_column: Union[str, Column] = "langchain_id", |
513 |
| - overwrite_existing: bool = False, |
514 |
| - store_metadata: bool = True, |
515 |
| - ) -> None: |
516 |
| - """ |
517 |
| - Create a table for saving of vectors to be used with AlloyDBVectorStore. |
518 |
| -
|
519 |
| - Args: |
520 |
| - table_name (str): The database table name. |
521 |
| - vector_size (int): Vector size for the embedding model to be used. |
522 |
| - schema_name (str): The schema name. |
523 |
| - Default: "public". |
524 |
| - content_column (str): Name of the column to store document content. |
525 |
| - Default: "page_content". |
526 |
| - embedding_column (str) : Name of the column to store vector embeddings. |
527 |
| - Default: "embedding". |
528 |
| - metadata_columns (list[Column]): A list of Columns to create for custom |
529 |
| - metadata. Default: []. Optional. |
530 |
| - metadata_json_column (str): The column to store extra metadata in JSON format. |
531 |
| - Default: "langchain_metadata". Optional. |
532 |
| - id_column (Union[str, Column]) : Column to store ids. |
533 |
| - Default: "langchain_id" column name with data type UUID. Optional. |
534 |
| - overwrite_existing (bool): Whether to drop existing table. Default: False. |
535 |
| - store_metadata (bool): Whether to store metadata in the table. |
536 |
| - Default: True. |
537 |
| - """ |
538 |
| - await self._run_as_async( |
539 |
| - self._ainit_vectorstore_table( |
540 |
| - table_name, |
541 |
| - vector_size, |
542 |
| - schema_name, |
543 |
| - content_column, |
544 |
| - embedding_column, |
545 |
| - metadata_columns, |
546 |
| - metadata_json_column, |
547 |
| - id_column, |
548 |
| - overwrite_existing, |
549 |
| - store_metadata, |
550 |
| - ) |
551 |
| - ) |
552 |
| - |
553 |
| - def init_vectorstore_table( |
554 |
| - self, |
555 |
| - table_name: str, |
556 |
| - vector_size: int, |
557 |
| - schema_name: str = "public", |
558 |
| - content_column: str = "content", |
559 |
| - embedding_column: str = "embedding", |
560 |
| - metadata_columns: list[Column] = [], |
561 |
| - metadata_json_column: str = "langchain_metadata", |
562 |
| - id_column: Union[str, Column] = "langchain_id", |
563 |
| - overwrite_existing: bool = False, |
564 |
| - store_metadata: bool = True, |
565 |
| - ) -> None: |
566 |
| - """ |
567 |
| - Create a table for saving of vectors to be used with AlloyDBVectorStore. |
568 |
| -
|
569 |
| - Args: |
570 |
| - table_name (str): The database table name. |
571 |
| - vector_size (int): Vector size for the embedding model to be used. |
572 |
| - schema_name (str): The schema name. |
573 |
| - Default: "public". |
574 |
| - content_column (str): Name of the column to store document content. |
575 |
| - Default: "page_content". |
576 |
| - embedding_column (str) : Name of the column to store vector embeddings. |
577 |
| - Default: "embedding". |
578 |
| - metadata_columns (list[Column]): A list of Columns to create for custom |
579 |
| - metadata. Default: []. Optional. |
580 |
| - metadata_json_column (str): The column to store extra metadata in JSON format. |
581 |
| - Default: "langchain_metadata". Optional. |
582 |
| - id_column (Union[str, Column]) : Column to store ids. |
583 |
| - Default: "langchain_id" column name with data type UUID. Optional. |
584 |
| - overwrite_existing (bool): Whether to drop existing table. Default: False. |
585 |
| - store_metadata (bool): Whether to store metadata in the table. |
586 |
| - Default: True. |
587 |
| - """ |
588 |
| - self._run_as_sync( |
589 |
| - self._ainit_vectorstore_table( |
590 |
| - table_name, |
591 |
| - vector_size, |
592 |
| - schema_name, |
593 |
| - content_column, |
594 |
| - embedding_column, |
595 |
| - metadata_columns, |
596 |
| - metadata_json_column, |
597 |
| - id_column, |
598 |
| - overwrite_existing, |
599 |
| - store_metadata, |
600 |
| - ) |
601 |
| - ) |
| 368 | + return cls(PGEngine._PGEngine__create_key, engine, cls._default_loop, cls._default_thread) # type: ignore |
602 | 369 |
|
603 | 370 | async def _ainit_chat_history_table(
|
604 | 371 | self, table_name: str, schema_name: str = "public"
|
|
0 commit comments