Skip to content
This repository was archived by the owner on Mar 19, 2026. It is now read-only.

Commit 59d6c90

Browse files
authored
Merge pull request #397 from webcoderz/main
adding pgvector memory module
2 parents 4c02725 + 2793154 commit 59d6c90

File tree

3 files changed

+256
-0
lines changed

3 files changed

+256
-0
lines changed

examples/pg-memory.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import controlflow as cf
2+
from controlflow.memory.memory import Memory
3+
from controlflow.memory.providers.postgres import PostgresMemory
4+
5+
# Postgres-backed provider. All memories using it are stored in the
# "vector_db" table of the database below. The provider also accepts
# `embedding_dimension` (default 1536) and `embedding_fn` (an
# OpenAIEmbeddings instance by default) when the defaults do not fit.
provider = PostgresMemory(
    table_name="vector_db",
    database_url="postgresql://postgres:postgres@localhost:5432/your_database",
)

# Memory module dedicated to user preferences, persisted through the provider
user_preferences = cf.Memory(
    provider=provider,
    key="user_preferences",
    instructions="Store and retrieve user preferences.",
)

# Agent that is allowed to read from and write to that memory
agent = cf.Agent(memories=[user_preferences])
20+
21+
22+
# Flow 1: interactively ask for the favorite color and persist it
@cf.flow
def remember_color():
    """Run an interactive task in which the agent asks for and stores the user's favorite color."""
    objective = "Ask the user for their favorite color and store it in memory"
    return cf.run(objective, agents=[agent], interactive=True)
30+
31+
32+
# Flow 2: answer from memory, without user interaction
@cf.flow
def recall_color():
    """Run a task asking the agent for the user's favorite color and return its result."""
    question = "What is the user's favorite color?"
    return cf.run(question, agents=[agent])
39+
40+
41+
if __name__ == "__main__":
    # Store the preference first ...
    print("First flow:")
    remember_color()

    # ... then show that a separate flow can read it back from Postgres.
    print("\nSecond flow:")
    print(recall_color())

src/controlflow/memory/memory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,16 @@ def get_memory_provider(provider: str) -> MemoryProvider:
166166

167167
return lance_providers.LanceMemory()
168168

169+
# --- Postgres ---
170+
elif provider.startswith("postgres"):
171+
try:
172+
import sqlalchemy
173+
except ImportError:
174+
raise ImportError(
175+
"To use Postgres as a memory provider, please install the `sqlalchemy` package."
176+
)
177+
178+
import controlflow.memory.providers.postgres as postgres_providers
179+
180+
return postgres_providers.PostgresMemory()
169181
raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import uuid
2+
from typing import Callable, Dict, Optional
3+
4+
import sqlalchemy
5+
from pgvector.sqlalchemy import Vector
6+
from pydantic import Field
7+
from sqlalchemy import Column, String, select, text
8+
from sqlalchemy.dialects.postgresql import ARRAY
9+
from sqlalchemy.exc import ProgrammingError
10+
from sqlalchemy.orm import Session, declarative_base, sessionmaker
11+
from sqlalchemy_utils import create_database, database_exists
12+
13+
import controlflow
14+
from controlflow.memory.memory import MemoryProvider
15+
16+
try:
    # OpenAIEmbeddings backs the default `embedding_fn` of PostgresMemory.
    # Any other embedding library can be supplied per instance instead.
    from langchain_openai import OpenAIEmbeddings
except ImportError as exc:
    # Name the package that is actually missing (the previous message pointed
    # at lancedb) and keep the original failure as the cause.
    raise ImportError(
        "To use the default OpenAI embedding function, "
        "please install langchain-openai with: pip install langchain-openai"
    ) from exc
24+
25+
# Declarative base shared by every generated memory-table model in this module
Base = declarative_base()


class SQLMemoryTable(Base):
    """
    Abstract template for one memory record.

    The class is flagged as abstract, so it is never mapped to a table itself.
    Concrete subclasses are generated at runtime and provide `__tablename__`
    together with the pgvector `vector` column.
    """

    __abstract__ = True

    # Primary key of the record
    id = Column(String, primary_key=True)
    # Raw text content of the memory
    text = Column(String)
    # The embedding column is intentionally not declared here: its dimension
    # depends on the provider's configuration, so subclasses add it.
41+
42+
43+
class PostgresMemory(MemoryProvider):
    """
    A ControlFlow MemoryProvider that stores text + embeddings in PostgreSQL
    using SQLAlchemy and pgvector. Each Memory module gets its own table.

    `configure()` must be called before `add()`, `delete()` or `search()`,
    because it creates the session factory those methods rely on.
    """

    # Default database URL. Point this to your actual Postgres instance; the
    # pgvector extension must be available there.
    database_url: str = Field(
        default="postgresql://user:password@localhost:5432/your_database",
        description="SQLAlchemy-compatible database URL to a Postgres instance with pgvector.",
    )
    table_name: str = Field(
        "memory_{key}",
        description="""
            Name of the table to store this memory partition. "{key}" will be replaced
            by the memory’s key attribute.
            """,
    )

    embedding_dimension: int = Field(
        default=1536,
        description="Dimension of the embedding vectors. Match your model's output.",
    )

    embedding_fn: Callable = Field(
        default_factory=lambda: OpenAIEmbeddings(
            model="text-embedding-ada-002",
        ),
        description="A function that turns a string into a vector.",
    )

    # Internal: session factory bound to the engine created in configure()
    _SessionLocal: Optional[sessionmaker] = None

    # Internal: maps a formatted table name to its generated model class
    _table_class_cache: Dict[str, Base] = {}

    def configure(self, memory_key: str) -> None:
        """
        Prepare the database for the given memory partition.

        Creates the database if it is missing, enables the `vector`
        extension, stores a session factory on the provider and creates the
        partition's table when it does not exist yet.

        Args:
            memory_key: Key of the memory module; substituted for `{key}`
                in `table_name`.

        Raises:
            RuntimeError: If the table could not be created.
        """
        engine = sqlalchemy.create_engine(self.database_url)

        # Create the database itself when the URL points at a missing one
        if not database_exists(engine.url):
            create_database(engine.url)

        with engine.connect() as conn:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            conn.commit()

        self._SessionLocal = sessionmaker(bind=engine)

        table_name = self.table_name.format(key=memory_key)
        memory_model = self._get_table(memory_key)

        try:
            # Always issue create_all(): it checks for an existing table
            # first, and the table may already be registered in the shared
            # metadata (e.g. by another provider instance) without existing
            # in *this* database yet.
            Base.metadata.create_all(engine, tables=[memory_model.__table__])
        except ProgrammingError as e:
            raise RuntimeError(f"Failed to create table {table_name}: {e}") from e

    def _get_session(self) -> Session:
        """
        Open a new session from the configured session factory.

        Raises:
            RuntimeError: If `configure()` has not been called yet.
        """
        if not self._SessionLocal:
            raise RuntimeError(
                "Session is not initialized. Make sure to call configure() first."
            )
        return self._SessionLocal()

    def _find_mapped_model(self, table_name: str) -> Optional[Base]:
        """
        Return the model class already mapped to `table_name` on the shared
        declarative base, or None when no such class exists.
        """
        for mapper in Base.registry.mappers:
            table = mapper.local_table
            if table is not None and table.name == table_name:
                return mapper.class_
        return None

    def _get_table(self, memory_key: str) -> Base:
        """
        Return the declarative model class mapped to this partition's table.

        The class is taken from the provider's cache when possible. If the
        table is already registered on the shared metadata (for instance by
        another provider instance), the existing class is reused, because
        declaring a second class for the same table raises an error in
        SQLAlchemy. Otherwise a new subclass of SQLMemoryTable is generated
        with a `vector` column of `embedding_dimension` entries.

        Args:
            memory_key: Key of the memory module; substituted for `{key}`
                in `table_name`.
        """
        table_name = self.table_name.format(key=memory_key)

        memory_model = self._table_class_cache.get(table_name)
        if memory_model is not None:
            return memory_model

        if table_name in Base.metadata.tables:
            memory_model = self._find_mapped_model(table_name)

        if memory_model is None:
            memory_model = type(
                f"SQLMemoryTable_{memory_key}",
                (SQLMemoryTable,),
                {
                    "__tablename__": table_name,
                    "vector": Column(Vector(dim=self.embedding_dimension)),
                },
            )

        self._table_class_cache[table_name] = memory_model
        return memory_model

    def _embed(self, content: str):
        """
        Turn `content` into an embedding vector.

        Objects exposing `embed_query` (such as the default
        OpenAIEmbeddings) are used through that method; anything else is
        called directly, matching the field's declared `Callable` type.
        """
        embed_query = getattr(self.embedding_fn, "embed_query", None)
        if callable(embed_query):
            return embed_query(content)
        return self.embedding_fn(content)

    def add(self, memory_key: str, content: str) -> str:
        """
        Insert a new memory record into the partition's table.

        An embedding is generated for `content` and stored in the vector
        column next to the text.

        Args:
            memory_key: Key of the memory module to write to.
            content: Text of the memory.

        Returns:
            The generated ID (a UUID4 string) of the new record.
        """
        memory_id = str(uuid.uuid4())
        model_cls = self._get_table(memory_key)

        embedding = self._embed(content)

        with self._get_session() as session:
            session.add(model_cls(id=memory_id, text=content, vector=embedding))
            session.commit()

        return memory_id

    def delete(self, memory_key: str, memory_id: str) -> None:
        """
        Delete the memory record with the given ID, if present.

        Args:
            memory_key: Key of the memory module to delete from.
            memory_id: ID returned by `add()` for the record.
        """
        model_cls = self._get_table(memory_key)

        with self._get_session() as session:
            session.query(model_cls).filter(model_cls.id == memory_id).delete()
            session.commit()

    def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]:
        """
        Find the records closest to `query`.

        The query is embedded and records are ordered by pgvector's L2
        distance between their stored vector and the query embedding.

        Args:
            memory_key: Key of the memory module to search.
            query: Text to search for.
            n: Maximum number of records to return.

        Returns:
            A dict mapping record ID to record text for the top matches.
        """
        model_cls = self._get_table(memory_key)
        query_embedding = self._embed(query)

        nearest_first = (
            select(model_cls.id, model_cls.text)
            .order_by(model_cls.vector.l2_distance(query_embedding))
            .limit(n)
        )

        with self._get_session() as session:
            results = session.execute(nearest_first).all()

        return {row.id: row.text for row in results}

0 commit comments

Comments
 (0)