Skip to content

Commit c3141b9

Browse files
authored
Merge branch 'dev' into feat/add-request-log
2 parents 11d6ae3 + e79a9ab commit c3141b9

File tree

3 files changed

+137
-36
lines changed

3 files changed

+137
-36
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ def __init__(self, config: PolarDBGraphDBConfig):
151151
user=user,
152152
password=password,
153153
dbname=self.db_name,
154+
connect_timeout=60, # Connection timeout in seconds
155+
keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout)
156+
keepalives_interval=15, # Seconds between keepalive retries
157+
keepalives_count=5, # Number of keepalive retries before considering connection dead
154158
)
155159

156160
# Keep a reference to the pool for cleanup
@@ -179,7 +183,7 @@ def _get_config_value(self, key: str, default=None):
179183
else:
180184
return getattr(self.config, key, default)
181185

182-
def _get_connection(self):
186+
def _get_connection_old(self):
183187
"""Get a connection from the pool."""
184188
if self._pool_closed:
185189
raise RuntimeError("Connection pool has been closed")
@@ -188,7 +192,60 @@ def _get_connection(self):
188192
conn.autocommit = True
189193
return conn
190194

195+
def _get_connection(self):
196+
"""Get a connection from the pool."""
197+
if self._pool_closed:
198+
raise RuntimeError("Connection pool has been closed")
199+
200+
max_retries = 3
201+
for attempt in range(max_retries):
202+
try:
203+
conn = self.connection_pool.getconn()
204+
205+
# Check if connection is closed
206+
if conn.closed != 0:
207+
# Connection is closed, close it explicitly and try again
208+
try:
209+
conn.close()
210+
except Exception as e:
211+
logger.warning(f"Failed to close connection: {e}")
212+
if attempt < max_retries - 1:
213+
continue
214+
else:
215+
raise RuntimeError("Pool returned a closed connection")
216+
217+
# Set autocommit for PolarDB compatibility
218+
conn.autocommit = True
219+
return conn
220+
except Exception as e:
221+
if attempt >= max_retries - 1:
222+
raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e
223+
continue
224+
191225
def _return_connection(self, connection):
226+
"""Return a connection to the pool."""
227+
if not self._pool_closed and connection:
228+
try:
229+
# Check if connection is closed
230+
if hasattr(connection, "closed") and connection.closed != 0:
231+
# Connection is closed, just close it and don't return to pool
232+
try:
233+
connection.close()
234+
except Exception as e:
235+
logger.warning(f"Failed to close connection: {e}")
236+
return
237+
238+
# Connection is valid, return to pool
239+
self.connection_pool.putconn(connection)
240+
except Exception as e:
241+
# If putconn fails, close the connection
242+
logger.warning(f"Failed to return connection to pool: {e}")
243+
try:
244+
connection.close()
245+
except Exception as e:
246+
logger.warning(f"Failed to close connection: {e}")
247+
248+
def _return_connection_old(self, connection):
192249
"""Return a connection to the pool."""
193250
if not self._pool_closed and connection:
194251
self.connection_pool.putconn(connection)
@@ -306,7 +363,7 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in
306363
WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype
307364
"""
308365
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
309-
params = [f'"{memory_type}"', f'"{user_name}"']
366+
params = [self.format_param_value(memory_type), self.format_param_value(user_name)]
310367

311368
# Get a connection from the pool
312369
conn = self._get_connection()
@@ -332,7 +389,7 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
332389
"""
333390
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
334391
query += "\nLIMIT 1"
335-
params = [f'"{scope}"', f'"{user_name}"']
392+
params = [self.format_param_value(scope), self.format_param_value(user_name)]
336393

337394
# Get a connection from the pool
338395
conn = self._get_connection()
@@ -370,7 +427,11 @@ def remove_oldest_memory(
370427
ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC
371428
OFFSET %s
372429
"""
373-
select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest]
430+
select_params = [
431+
self.format_param_value(memory_type),
432+
self.format_param_value(user_name),
433+
keep_latest,
434+
]
374435
conn = self._get_connection()
375436
try:
376437
with conn.cursor() as cursor:
@@ -444,19 +505,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N
444505
SET properties = %s, embedding = %s
445506
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
446507
"""
447-
params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"']
508+
params = [
509+
json.dumps(properties),
510+
json.dumps(embedding_vector),
511+
self.format_param_value(id),
512+
]
448513
else:
449514
query = f"""
450515
UPDATE "{self.db_name}_graph"."Memory"
451516
SET properties = %s
452517
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
453518
"""
454-
params = [json.dumps(properties), f'"{id}"']
519+
params = [json.dumps(properties), self.format_param_value(id)]
455520

456521
# Only add user filter when user_name is provided
457522
if user_name is not None:
458523
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
459-
params.append(f'"{user_name}"')
524+
params.append(self.format_param_value(user_name))
460525

461526
# Get a connection from the pool
462527
conn = self._get_connection()
@@ -481,12 +546,12 @@ def delete_node(self, id: str, user_name: str | None = None) -> None:
481546
DELETE FROM "{self.db_name}_graph"."Memory"
482547
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
483548
"""
484-
params = [f'"{id}"']
549+
params = [self.format_param_value(id)]
485550

486551
# Only add user filter when user_name is provided
487552
if user_name is not None:
488553
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
489-
params.append(f'"{user_name}"')
554+
params.append(self.format_param_value(user_name))
490555

491556
# Get a connection from the pool
492557
conn = self._get_connection()
@@ -774,28 +839,17 @@ def get_node(
774839

775840
select_fields = "id, properties, embedding" if include_embedding else "id, properties"
776841

777-
# Helper function to format parameter value
778-
def format_param_value(value: str) -> str:
779-
"""Format parameter value to handle both quoted and unquoted formats"""
780-
# Remove outer quotes if they exist
781-
if value.startswith('"') and value.endswith('"'):
782-
# Already has double quotes, return as is
783-
return value
784-
else:
785-
# Add double quotes
786-
return f'"{value}"'
787-
788842
query = f"""
789843
SELECT {select_fields}
790844
FROM "{self.db_name}_graph"."Memory"
791845
WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
792846
"""
793-
params = [format_param_value(id)]
847+
params = [self.format_param_value(id)]
794848

795849
# Only add user filter when user_name is provided
796850
if user_name is not None:
797851
query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
798-
params.append(format_param_value(user_name))
852+
params.append(self.format_param_value(user_name))
799853

800854
conn = self._get_connection()
801855
try:
@@ -873,7 +927,7 @@ def get_nodes(
873927
where_conditions.append(
874928
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = %s::agtype"
875929
)
876-
params.append(f"{id_val}")
930+
params.append(self.format_param_value(id_val))
877931

878932
where_clause = " OR ".join(where_conditions)
879933

@@ -885,7 +939,7 @@ def get_nodes(
885939

886940
user_name = user_name if user_name else self.config.user_name
887941
query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
888-
params.append(f'"{user_name}"')
942+
params.append(self.format_param_value(user_name))
889943

890944
conn = self._get_connection()
891945
try:
@@ -1834,7 +1888,7 @@ def export_graph(
18341888
if include_embedding and embedding_json is not None:
18351889
properties["embedding"] = embedding_json
18361890

1837-
nodes.append(self._parse_node(properties))
1891+
nodes.append(self._parse_node(json.loads(properties[1])))
18381892

18391893
except Exception as e:
18401894
logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
@@ -2559,7 +2613,7 @@ def get_neighbors_by_tag(
25592613
exclude_conditions.append(
25602614
"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) != %s::agtype"
25612615
)
2562-
params.append(f'"{exclude_id}"')
2616+
params.append(self.format_param_value(exclude_id))
25632617
where_clauses.append(f"({' AND '.join(exclude_conditions)})")
25642618

25652619
# Status filter - keep only 'activated'
@@ -2576,7 +2630,7 @@ def get_neighbors_by_tag(
25762630
where_clauses.append(
25772631
"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
25782632
)
2579-
params.append(f'"{user_name}"')
2633+
params.append(self.format_param_value(user_name))
25802634

25812635
# Testing showed no data; annotate.
25822636
where_clauses.append(
@@ -2965,3 +3019,18 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
29653019
if tgt in id_map:
29663020
edge["target"] = id_map[tgt]
29673021
return data
3022+
3023+
def format_param_value(self, value: str | None) -> str:
3024+
"""Format parameter value to handle both quoted and unquoted formats"""
3025+
# Handle None value
3026+
if value is None:
3027+
logger.warning(f"format_param_value: value is None")
3028+
return "null"
3029+
3030+
# Remove outer quotes if they exist
3031+
if value.startswith('"') and value.endswith('"'):
3032+
# Already has double quotes, return as is
3033+
return value
3034+
else:
3035+
# Add double quotes
3036+
return f'"{value}"'

src/memos/mem_os/product_server.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,19 +71,15 @@ def chat(
7171
m.metadata.embedding = []
7272
new_memories_list.append(m)
7373
memories_list = new_memories_list
74-
system_prompt = self._build_base_system_prompt(base_prompt, mode="base")
75-
76-
memory_context = self._build_memory_context(memories_list, mode="base")
77-
78-
user_content = memory_context + query if memory_context else query
74+
system_prompt = self._build_system_prompt(memories_list, base_prompt)
7975

8076
history_info = []
8177
if history:
8278
history_info = history[-20:]
8379
current_messages = [
8480
{"role": "system", "content": system_prompt},
8581
*history_info,
86-
{"role": "user", "content": user_content},
82+
{"role": "user", "content": query},
8783
]
8884
response = self.chat_llm.generate(current_messages)
8985
time_end = time.time()
@@ -187,6 +183,42 @@ def _build_base_system_prompt(
187183
prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
188184
return prefix + sys_body
189185

186+
def _build_system_prompt(
187+
self,
188+
memories: list[TextualMemoryItem] | list[str] | None = None,
189+
base_prompt: str | None = None,
190+
**kwargs,
191+
) -> str:
192+
"""Build system prompt with optional memories context."""
193+
if base_prompt is None:
194+
base_prompt = (
195+
"You are a knowledgeable and helpful AI assistant. "
196+
"You have access to conversation memories that help you provide more personalized responses. "
197+
"Use the memories to understand the user's context, preferences, and past interactions. "
198+
"If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
199+
)
200+
201+
memory_context = ""
202+
if memories:
203+
memory_list = []
204+
for i, memory in enumerate(memories, 1):
205+
if isinstance(memory, TextualMemoryItem):
206+
text_memory = memory.memory
207+
else:
208+
if not isinstance(memory, str):
209+
logger.error("Unexpected memory type.")
210+
text_memory = memory
211+
memory_list.append(f"{i}. {text_memory}")
212+
memory_context = "\n".join(memory_list)
213+
214+
if "{memories}" in base_prompt:
215+
return base_prompt.format(memories=memory_context)
216+
elif base_prompt and memories:
217+
# For backward compatibility, append memories if no placeholder is found
218+
memory_context_with_header = "\n\n## Memories:\n" + memory_context
219+
return base_prompt + memory_context_with_header
220+
return base_prompt
221+
190222
def _build_memory_context(
191223
self,
192224
memories_all: list[TextualMemoryItem],

src/memos/memories/textual/tree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def get_relevant_subgraph(
260260
center_id=core_id, depth=depth, center_status=center_status
261261
)
262262

263-
if not subgraph["core_node"]:
263+
if subgraph is None or not subgraph["core_node"]:
264264
logger.info(f"Skipping node {core_id} (inactive or not found).")
265265
continue
266266

@@ -281,9 +281,9 @@ def get_relevant_subgraph(
281281
{"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors}
282282
)
283283

284-
top_core = cores[0]
284+
top_core = cores[0] if cores else None
285285
return {
286-
"core_id": top_core["id"],
286+
"core_id": top_core["id"] if top_core else None,
287287
"nodes": list(all_nodes.values()),
288288
"edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges],
289289
}

0 commit comments

Comments
 (0)