11import uuid
22from typing import Callable , Dict , Optional
33
4+ # async pg
5+ import anyio
46import sqlalchemy
57from pgvector .sqlalchemy import Vector
68from pydantic import Field
79from sqlalchemy import Column , String , select , text
810from sqlalchemy .dialects .postgresql import ARRAY
911from sqlalchemy .exc import ProgrammingError
10- from sqlalchemy .orm import Session , declarative_base , sessionmaker
11- from sqlalchemy_utils import create_database , database_exists
12-
13- #async pg
14- import anyio
1512from sqlalchemy .ext .asyncio import (
16- create_async_engine ,
1713 AsyncEngine ,
1814 AsyncSession ,
1915 async_sessionmaker ,
16+ create_async_engine ,
2017)
18+ from sqlalchemy .orm import Session , declarative_base , sessionmaker
19+ from sqlalchemy_utils import create_database , database_exists
2120
2221import controlflow
23- from controlflow .memory .memory import MemoryProvider
2422from controlflow .memory .async_memory import AsyncMemoryProvider
23+ from controlflow .memory .memory import MemoryProvider
24+
2525try :
2626 # For embeddings, we can use langchain_openai or any other library:
2727 from langchain_openai import OpenAIEmbeddings
@@ -125,7 +125,7 @@ def configure(self, memory_key: str) -> None:
125125 pool_timeout = self .pool_timeout ,
126126 pool_recycle = self .pool_recycle ,
127127 pool_pre_ping = self .pool_pre_ping ,
128- )
128+ )
129129
130130 # 2) If DB doesn't exist, create it!
131131 if not database_exists (engine .url ):
@@ -239,7 +239,6 @@ def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
239239 return {row .id : row .text for row in results }
240240
241241
242-
243242class AsyncPostgresMemory (AsyncMemoryProvider ):
244243 """
245244 An async MemoryProvider storing text + embeddings in PostgreSQL
@@ -249,7 +248,7 @@ class AsyncPostgresMemory(AsyncMemoryProvider):
249248 database_url : str = Field (
250249 default = "postgresql+asyncpg://user:password@localhost:5432/your_database" ,
251250 description = "Async Postgres URL with the asyncpg driver, e.g. "
252- "'postgresql+asyncpg://user:pass@host:5432/dbname'."
251+ "'postgresql+asyncpg://user:pass@host:5432/dbname'." ,
253252 )
254253
255254 table_name : str = Field (
@@ -261,34 +260,30 @@ class AsyncPostgresMemory(AsyncMemoryProvider):
261260
262261 embedding_dimension : int = Field (
263262 default = 1536 ,
264- description = "Dimension of the embedding vectors. Must match your model output size."
263+ description = "Dimension of the embedding vectors. Must match your model output size." ,
265264 )
266265
267266 embedding_fn : Callable = Field (
268267 default_factory = lambda : OpenAIEmbeddings (model = "text-embedding-ada-002" ),
269- description = "Function that turns a string into a numeric vector."
268+ description = "Function that turns a string into a numeric vector." ,
270269 )
271270
272271 # -- Pool / Engine settings (SQLAlchemy will do the pooling)
273272 pool_size : int = Field (
274- 5 ,
275- description = "Number of permanent connections in the async pool."
273+ 5 , description = "Number of permanent connections in the async pool."
276274 )
277275 max_overflow : int = Field (
278- 10 ,
279- description = "Number of 'overflow' connections if the pool is full."
276+ 10 , description = "Number of 'overflow' connections if the pool is full."
280277 )
281278 pool_timeout : int = Field (
282- 30 ,
283- description = "Seconds to wait for a connection before raising an error."
279+ 30 , description = "Seconds to wait for a connection before raising an error."
284280 )
285281 pool_recycle : int = Field (
286282 1800 ,
287- description = "Recycle connections after N seconds to avoid stale connections."
283+ description = "Recycle connections after N seconds to avoid stale connections." ,
288284 )
289285 pool_pre_ping : bool = Field (
290- True ,
291- description = "Check connection health before using from the pool."
286+ True , description = "Check connection health before using from the pool."
292287 )
293288
294289 # We'll store an async engine + session factory:
@@ -299,7 +294,7 @@ class AsyncPostgresMemory(AsyncMemoryProvider):
299294 _table_class_cache : Dict [str , Base ] = {}
300295
301296 _configured : bool = False
302-
297+
303298 async def configure (self , memory_key : str ) -> None :
304299 """
305300 1) Create an async engine.
@@ -347,12 +342,16 @@ async def configure(self, memory_key: str) -> None:
347342 # (2) Actually create it (async):
348343 def _sync_create (connection ):
349344 """Helper function to run table creation in sync context."""
350- Base .metadata .create_all (connection , tables = [memory_model .__table__ ])
345+ Base .metadata .create_all (
346+ connection , tables = [memory_model .__table__ ]
347+ )
351348
352349 try :
353350 await conn .run_sync (_sync_create )
354351 except ProgrammingError as e :
355- raise RuntimeError (f"Failed to create table '{ table_name } ': { e } " )
352+ raise RuntimeError (
353+ f"Failed to create table '{ table_name } ': { e } "
354+ )
356355
357356 # 4) Now that the DB and table are ready, create a session factory
358357 self ._SessionLocal = async_sessionmaker (
@@ -362,7 +361,6 @@ def _sync_create(connection):
362361
363362 self ._configured = True
364363
365-
366364 def _get_table (self , memory_key : str ) -> Base :
367365 """
368366 Return or create the dynamic model class for 'memory_{key}' table.
@@ -456,4 +454,4 @@ async def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, st
456454 rows = results .all ()
457455
458456 # Convert list of Row objects -> dict
459- return {row .id : row .text for row in rows }
457+ return {row .id : row .text for row in rows }
0 commit comments