Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
# Create connection pool
self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
minconn=5,
maxconn=2000,
maxconn=100,
host=host,
port=port,
user=user,
Expand Down Expand Up @@ -1338,6 +1338,7 @@ def get_subgraph(
"edges": [...]
}
"""
logger.info(f"[get_subgraph] center_id: {center_id}")
if not 1 <= depth <= 5:
raise ValueError("depth must be 1-5")

Expand Down Expand Up @@ -1375,6 +1376,7 @@ def get_subgraph(
$$ ) as (centers agtype, neighbors agtype, rels agtype);
"""
conn = self._get_connection()
logger.info(f"[get_subgraph] Query: {query}")
try:
with conn.cursor() as cursor:
cursor.execute(query)
Expand Down Expand Up @@ -1746,6 +1748,7 @@ def search_by_embedding(

# Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}")
where_clauses.extend(filter_conditions)

where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
Expand Down Expand Up @@ -1918,7 +1921,7 @@ def get_by_metadata(
knowledgebase_ids=knowledgebase_ids,
default_user_name=self._get_config_value("user_name"),
)
print(f"[111get_by_metadata] user_name_conditions: {user_name_conditions}")
logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")

# Add user_name WHERE clause
if user_name_conditions:
Expand All @@ -1929,6 +1932,7 @@ def get_by_metadata(

# Build filter conditions using common method
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}")

where_str = " AND ".join(where_conditions) + filter_where_clause

Expand Down Expand Up @@ -2393,6 +2397,7 @@ def get_all_memory_items(

# Build filter conditions using common method
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[get_all_memory_items] filter_where_clause: {filter_where_clause}")

# Use cypher query to retrieve memory items
if include_embedding:
Expand Down Expand Up @@ -2426,6 +2431,7 @@ def get_all_memory_items(
nodes = []
node_ids = set()
conn = self._get_connection()
logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}")
try:
with conn.cursor() as cursor:
cursor.execute(cypher_query)
Expand Down Expand Up @@ -3456,7 +3462,11 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
id_map = {}
core_node = data.get("core_node", {})
if not core_node:
return core_node
return {
"core_node": None,
"neighbors": data.get("neighbors", []),
"edges": data.get("edges", []),
}
core_meta = core_node.get("metadata", {})
if "graph_id" in core_meta and "id" in core_node:
id_map[core_meta["graph_id"]] = core_node["id"]
Expand Down Expand Up @@ -3507,7 +3517,6 @@ def _build_user_name_and_kb_ids_conditions_cypher(
"""
user_name_conditions = []
effective_user_name = user_name if user_name else default_user_name
print(f"[delete_node_by_prams] effective_user_name: {effective_user_name}")

if effective_user_name:
escaped_user_name = effective_user_name.replace("'", "''")
Expand Down
29 changes: 4 additions & 25 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,16 @@ def get_relevant_subgraph(
)

if subgraph is None or not subgraph["core_node"]:
logger.info(f"Skipping node {core_id} (inactive or not found).")
continue
node = self.graph_store.get_node(core_id, user_name=user_name)
subgraph["neighbors"] = [node]

core_node = subgraph["core_node"]
neighbors = subgraph["neighbors"]
edges = subgraph["edges"]

# Collect nodes
all_nodes[core_node["id"]] = core_node
if core_node:
all_nodes[core_node["id"]] = core_node
for n in neighbors:
all_nodes[n["id"]] = n

Expand Down Expand Up @@ -339,28 +340,6 @@ def delete_all(self) -> None:
logger.error(f"An error occurred while deleting all memories: {e}")
raise

def delete_by_filter(
    self,
    writable_cube_ids: list[str],
    memory_ids: list[str] | None = None,
    file_ids: list[str] | None = None,
    filter: dict | None = None,
) -> int:
    """Delete memories that match the supplied filter criteria.

    Thin delegation wrapper: forwards every argument unchanged to the
    graph store's ``delete_node_by_prams`` and surfaces its result.

    Args:
        writable_cube_ids: Cube IDs the caller is allowed to delete from.
        memory_ids: Optional explicit memory IDs to delete.
        file_ids: Optional file IDs whose memories should be deleted.
        filter: Optional extra metadata filter dict.

    Returns:
        int: Number of nodes deleted.

    Raises:
        Exception: Re-raises whatever the graph store raises, after logging.
    """
    try:
        deleted_count = self.graph_store.delete_node_by_prams(
            writable_cube_ids=writable_cube_ids,
            memory_ids=memory_ids,
            file_ids=file_ids,
            filter=filter,
        )
    except Exception as e:
        # Log for observability, then propagate so callers can handle it.
        logger.error(f"An error occurred while deleting memories by filter: {e}")
        raise
    return deleted_count

def load(self, dir: str) -> None:
try:
memory_file = os.path.join(dir, self.config.memory_filename)
Expand Down
Loading