@@ -151,6 +151,10 @@ def __init__(self, config: PolarDBGraphDBConfig):
151151 user = user ,
152152 password = password ,
153153 dbname = self .db_name ,
154+ connect_timeout = 60 , # Connection timeout in seconds
155+ keepalives_idle = 40 , # Seconds of inactivity before sending keepalive (should be < server idle timeout)
156+ keepalives_interval = 15 , # Seconds between keepalive retries
157+ keepalives_count = 5 , # Number of keepalive retries before considering connection dead
154158 )
155159
156160 # Keep a reference to the pool for cleanup
@@ -179,7 +183,7 @@ def _get_config_value(self, key: str, default=None):
179183 else :
180184 return getattr (self .config , key , default )
181185
182- def _get_connection (self ):
186+ def _get_connection_old (self ):
183187 """Get a connection from the pool."""
184188 if self ._pool_closed :
185189 raise RuntimeError ("Connection pool has been closed" )
@@ -188,7 +192,60 @@ def _get_connection(self):
188192 conn .autocommit = True
189193 return conn
190194
195+ def _get_connection (self ):
196+ """Get a connection from the pool."""
197+ if self ._pool_closed :
198+ raise RuntimeError ("Connection pool has been closed" )
199+
200+ max_retries = 3
201+ for attempt in range (max_retries ):
202+ try :
203+ conn = self .connection_pool .getconn ()
204+
205+ # Check if connection is closed
206+ if conn .closed != 0 :
207+ # Connection is closed, close it explicitly and try again
208+ try :
209+ conn .close ()
210+ except Exception as e :
211+ logger .warning (f"Failed to close connection: { e } " )
212+ if attempt < max_retries - 1 :
213+ continue
214+ else :
215+ raise RuntimeError ("Pool returned a closed connection" )
216+
217+ # Set autocommit for PolarDB compatibility
218+ conn .autocommit = True
219+ return conn
220+ except Exception as e :
221+ if attempt >= max_retries - 1 :
222+ raise RuntimeError (f"Failed to get a valid connection from pool: { e } " ) from e
223+ continue
224+
191225 def _return_connection (self , connection ):
226+ """Return a connection to the pool."""
227+ if not self ._pool_closed and connection :
228+ try :
229+ # Check if connection is closed
230+ if hasattr (connection , "closed" ) and connection .closed != 0 :
231+ # Connection is closed, just close it and don't return to pool
232+ try :
233+ connection .close ()
234+ except Exception as e :
235+ logger .warning (f"Failed to close connection: { e } " )
236+ return
237+
238+ # Connection is valid, return to pool
239+ self .connection_pool .putconn (connection )
240+ except Exception as e :
241+ # If putconn fails, close the connection
242+ logger .warning (f"Failed to return connection to pool: { e } " )
243+ try :
244+ connection .close ()
245+ except Exception as e :
246+ logger .warning (f"Failed to close connection: { e } " )
247+
248+ def _return_connection_old (self , connection ):
192249 """Return a connection to the pool."""
193250 if not self ._pool_closed and connection :
194251 self .connection_pool .putconn (connection )
@@ -306,7 +363,7 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in
306363 WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype
307364 """
308365 query += "\n AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
309- params = [f'" { memory_type } "' , f'" { user_name } "' ]
366+ params = [self . format_param_value ( memory_type ), self . format_param_value ( user_name ) ]
310367
311368 # Get a connection from the pool
312369 conn = self ._get_connection ()
@@ -332,7 +389,7 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
332389 """
333390 query += "\n AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
334391 query += "\n LIMIT 1"
335- params = [f'" { scope } "' , f'" { user_name } "' ]
392+ params = [self . format_param_value ( scope ), self . format_param_value ( user_name ) ]
336393
337394 # Get a connection from the pool
338395 conn = self ._get_connection ()
@@ -370,7 +427,11 @@ def remove_oldest_memory(
370427 ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC
371428 OFFSET %s
372429 """
373- select_params = [f'"{ memory_type } "' , f'"{ user_name } "' , keep_latest ]
430+ select_params = [
431+ self .format_param_value (memory_type ),
432+ self .format_param_value (user_name ),
433+ keep_latest ,
434+ ]
374435 conn = self ._get_connection ()
375436 try :
376437 with conn .cursor () as cursor :
@@ -444,19 +505,23 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N
444505 SET properties = %s, embedding = %s
445506 WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
446507 """
447- params = [json .dumps (properties ), json .dumps (embedding_vector ), f'"{ id } "' ]
508+ params = [
509+ json .dumps (properties ),
510+ json .dumps (embedding_vector ),
511+ self .format_param_value (id ),
512+ ]
448513 else :
449514 query = f"""
450515 UPDATE "{ self .db_name } _graph"."Memory"
451516 SET properties = %s
452517 WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
453518 """
454- params = [json .dumps (properties ), f'" { id } "' ]
519+ params = [json .dumps (properties ), self . format_param_value ( id ) ]
455520
456521 # Only add user filter when user_name is provided
457522 if user_name is not None :
458523 query += "\n AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
459- params .append (f'" { user_name } "' )
524+ params .append (self . format_param_value ( user_name ) )
460525
461526 # Get a connection from the pool
462527 conn = self ._get_connection ()
@@ -481,12 +546,12 @@ def delete_node(self, id: str, user_name: str | None = None) -> None:
481546 DELETE FROM "{ self .db_name } _graph"."Memory"
482547 WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
483548 """
484- params = [f'" { id } "' ]
549+ params = [self . format_param_value ( id ) ]
485550
486551 # Only add user filter when user_name is provided
487552 if user_name is not None :
488553 query += "\n AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
489- params .append (f'" { user_name } "' )
554+ params .append (self . format_param_value ( user_name ) )
490555
491556 # Get a connection from the pool
492557 conn = self ._get_connection ()
@@ -774,28 +839,17 @@ def get_node(
774839
775840 select_fields = "id, properties, embedding" if include_embedding else "id, properties"
776841
777- # Helper function to format parameter value
778- def format_param_value (value : str ) -> str :
779- """Format parameter value to handle both quoted and unquoted formats"""
780- # Remove outer quotes if they exist
781- if value .startswith ('"' ) and value .endswith ('"' ):
782- # Already has double quotes, return as is
783- return value
784- else :
785- # Add double quotes
786- return f'"{ value } "'
787-
788842 query = f"""
789843 SELECT { select_fields }
790844 FROM "{ self .db_name } _graph"."Memory"
791845 WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype
792846 """
793- params = [format_param_value (id )]
847+ params = [self . format_param_value (id )]
794848
795849 # Only add user filter when user_name is provided
796850 if user_name is not None :
797851 query += "\n AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
798- params .append (format_param_value (user_name ))
852+ params .append (self . format_param_value (user_name ))
799853
800854 conn = self ._get_connection ()
801855 try :
@@ -873,7 +927,7 @@ def get_nodes(
873927 where_conditions .append (
874928 "ag_catalog.agtype_access_operator(properties, '\" id\" '::agtype) = %s::agtype"
875929 )
876- params .append (f" { id_val } " )
930+ params .append (self . format_param_value ( id_val ) )
877931
878932 where_clause = " OR " .join (where_conditions )
879933
@@ -885,7 +939,7 @@ def get_nodes(
885939
886940 user_name = user_name if user_name else self .config .user_name
887941 query += " AND ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
888- params .append (f'" { user_name } "' )
942+ params .append (self . format_param_value ( user_name ) )
889943
890944 conn = self ._get_connection ()
891945 try :
@@ -1834,7 +1888,7 @@ def export_graph(
18341888 if include_embedding and embedding_json is not None :
18351889 properties ["embedding" ] = embedding_json
18361890
1837- nodes .append (self ._parse_node (properties ))
1891+ nodes .append (self ._parse_node (json . loads ( properties [ 1 ]) ))
18381892
18391893 except Exception as e :
18401894 logger .error (f"[EXPORT GRAPH - NODES] Exception: { e } " , exc_info = True )
@@ -2559,7 +2613,7 @@ def get_neighbors_by_tag(
25592613 exclude_conditions .append (
25602614 "ag_catalog.agtype_access_operator(properties, '\" id\" '::agtype) != %s::agtype"
25612615 )
2562- params .append (f'" { exclude_id } "' )
2616+ params .append (self . format_param_value ( exclude_id ) )
25632617 where_clauses .append (f"({ ' AND ' .join (exclude_conditions )} )" )
25642618
25652619 # Status filter - keep only 'activated'
@@ -2576,7 +2630,7 @@ def get_neighbors_by_tag(
25762630 where_clauses .append (
25772631 "ag_catalog.agtype_access_operator(properties, '\" user_name\" '::agtype) = %s::agtype"
25782632 )
2579- params .append (f'" { user_name } "' )
2633+ params .append (self . format_param_value ( user_name ) )
25802634
25812635 # Testing showed no data; annotate.
25822636 where_clauses .append (
@@ -2965,3 +3019,18 @@ def _convert_graph_edges(self, core_node: dict) -> dict:
29653019 if tgt in id_map :
29663020 edge ["target" ] = id_map [tgt ]
29673021 return data
3022+
3023+ def format_param_value (self , value : str | None ) -> str :
3024+ """Format parameter value to handle both quoted and unquoted formats"""
3025+ # Handle None value
3026+ if value is None :
3027+ logger .warning (f"format_param_value: value is None" )
3028+ return "null"
3029+
3030+ # Remove outer quotes if they exist
3031+ if value .startswith ('"' ) and value .endswith ('"' ):
3032+ # Already has double quotes, return as is
3033+ return value
3034+ else :
3035+ # Add double quotes
3036+ return f'"{ value } "'
0 commit comments