@@ -109,8 +109,8 @@ def __init__(self, config: NebulaGraphDBConfig):
109109
110110 self .config = config
111111 self .client = NebulaClient (
112- hosts = config .get ("hosts " ),
113- username = config .get ("user_name " ),
112+ hosts = config .get ("uri " ),
113+ username = config .get ("user " ),
114114 password = config .get ("password" ),
115115 )
116116 self .db_name = config .space
@@ -132,19 +132,11 @@ def create_index(
132132 dimensions : int = 3072 ,
133133 index_name : str = "memory_vector_index" ,
134134 ) -> None :
135- create_vector_index = f"""
136- CREATE VECTOR INDEX IF NOT EXISTS { index_name }
137- ON NODE Memory::{ vector_property }
138- OPTIONS {{
139- DIM: { dimensions } ,
140- METRIC: L2,
141- TYPE: IVF,
142- NLIST: 100,
143- TRAINSIZE: 1000
144- }}
145- FOR memory_graph
146- """
147- self .client .execute (create_vector_index )
135+ # Create vector index if it doesn't exist
136+ if not self ._vector_index_exists (index_name ):
137+ self ._create_vector_index (label , vector_property , dimensions , index_name )
138+ # Create indexes
139+ self ._create_basic_property_indexes ()
148140
149141 def remove_oldest_memory (self , memory_type : str , keep_latest : int ) -> None :
150142 """
@@ -468,27 +460,54 @@ def get_neighbors_by_tag(
468460 Returns:
469461 List of dicts with node details and overlap count.
470462 """
471- where_user = ""
463+ if not tags :
464+ return []
465+
466+ where_clauses = [
467+ 'n.status = "activated"' ,
468+ 'NOT (n.node_type = "reasoning")' ,
469+ 'NOT (n.memory_type = "WorkingMemory")' ,
470+ ]
471+ if exclude_ids :
472+ where_clauses .append (f"NOT (n.id IN { exclude_ids } )" )
473+
472474 if not self .config .use_multi_db and self .config .user_name :
473- user_name = self .config .user_name
474- where_user = f"AND n.user_name = { user_name } "
475+ where_clauses .append (f'n.user_name = "{ self .config .user_name } "' )
476+
477+ where_clause = " AND " .join (where_clauses )
478+ tag_list_literal = "[" + ", " .join (f'"{ _escape_str (t )} "' for t in tags ) + "]"
475479
476480 query = f"""
477- MATCH (n@Memory)
478- LET overlap_tags = [tag IN n.tags WHERE tag IN { tags } ]
479- WHERE NOT n.id IN { exclude_ids }
480- AND n.status = 'activated'
481- AND n.node_type <> 'reasoning'
482- AND n.memory_type <> 'WorkingMemory'
483- { where_user }
484- AND size(overlap_tags) >= { min_overlap }
485- RETURN n, size(overlap_tags) AS overlap_count
486- ORDER BY overlap_count DESC
487- LIMIT { top_k }
488- """
489- print (query )
481+ LET tag_list = { tag_list_literal }
482+
483+ MATCH (n@Memory)
484+ WHERE { where_clause }
485+ RETURN n.id AS id,
486+ n.tags AS tags,
487+ n.user_name AS user_name,
488+ n.memory AS memory,
489+ n.status AS status,
490+ n.node_type AS node_type,
491+ n.memory_type AS memory_type,
492+ size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
493+ ORDER BY overlap_count DESC
494+ LIMIT { top_k }
495+ """
496+
490497 result = self .client .execute (query )
491- return [self ._parse_node (dict (record )) for record in result ]
498+ neighbors = []
499+ for row in result :
500+ props = {col : self ._parse_node (row [col ]) for col in result .column_names }
501+
502+ node_tags = props .get ("tags" , [])
503+ overlap_tags = list (set (node_tags ) & set (tags ))
504+
505+ if len (overlap_tags ) >= min_overlap :
506+ props ["overlap_count" ] = len (overlap_tags )
507+ neighbors .append (props )
508+
509+ neighbors .sort (key = lambda x : x ["overlap_count" ], reverse = True )
510+ return neighbors [:top_k ]
492511
493512 def get_children_with_embeddings (self , id : str ) -> list [dict [str , Any ]]:
494513 where_user = ""
@@ -503,10 +522,16 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
503522 RETURN c.id AS id, c.embedding AS embedding, c.memory AS memory
504523 """
505524 result = self .client .execute (query )
506- return [
507- {"id" : r ["id" ].value , "embedding" : r ["embedding" ].value , "memory" : r ["memory" ].value }
508- for r in result
509- ]
525+ children = []
526+ for r in result :
527+ children .append (
528+ {
529+ "id" : self ._parse_node (r ["id" ]),
530+ "embedding" : self ._parse_node (r ["embedding" ]),
531+ "memory" : self ._parse_node (r ["memory" ]),
532+ }
533+ )
534+ return children
510535
511536 def get_subgraph (
512537 self , center_id : str , depth : int = 2 , center_status : str = "activated"
@@ -1038,26 +1063,50 @@ def _ensure_database_exists(self):
10381063
10391064 # TODO
10401065 def _vector_index_exists (self , index_name : str = "memory_vector_index" ) -> bool :
1041- raise NotImplementedError
1066+ return False
10421067
1043- # TODO
10441068 def _create_vector_index (
10451069 self , label : str , vector_property : str , dimensions : int , index_name : str
10461070 ) -> None :
10471071 """
10481072 Create a vector index for the specified property in the label.
10491073 """
1050- raise NotImplementedError
1074+ create_vector_index = f"""
1075+ CREATE VECTOR INDEX IF NOT EXISTS { index_name }
1076+ ON NODE Memory::{ vector_property }
1077+ OPTIONS {{
1078+ DIM: { dimensions } ,
1079+ METRIC: L2,
1080+ TYPE: IVF,
1081+ NLIST: 100,
1082+ TRAINSIZE: 1000
1083+ }}
1084+ FOR memory_graph
1085+ """
1086+ self .client .execute (create_vector_index )
10511087
1052- # TODO
10531088 def _create_basic_property_indexes (self ) -> None :
10541089 """
1055- Create standard B-tree indexes on memory_type, created_at,
1090+ Create standard B-tree indexes on status, memory_type, created_at
10561091 and updated_at fields.
10571092 Create standard B-tree indexes on user_name when use Shared Database
1058- Multi-Tenant Mode
1093+ Multi-Tenant Mode.
10591094 """
1060- raise NotImplementedError
1095+ fields = ["status" , "memory_type" , "created_at" , "updated_at" ]
1096+ if not self .config .use_multi_db :
1097+ fields .append ("user_name" )
1098+
1099+ for field in fields :
1100+ index_name = f"idx_memory_{ field } "
1101+ gql = f"""
1102+ CREATE INDEX IF NOT EXISTS { index_name } ON NODE Memory({ field } )
1103+ FOR memory_graph
1104+ """
1105+ try :
1106+ self .client .execute (gql )
1107+ logger .info (f"✅ Created index: { index_name } on field { field } " )
1108+ except Exception as e :
1109+ logger .error (f"❌ Failed to create index { index_name } : { e } " )
10611110
10621111 def _index_exists (self , index_name : str ) -> bool :
10631112 """
@@ -1086,4 +1135,11 @@ def _parse_node(self, value: ValueWrapper) -> Any:
10861135 self ._parse_node (v ) if isinstance (v , ValueWrapper ) else v for v in primitive_value
10871136 ]
10881137
1138+ if type (primitive_value ).__name__ == "NVector" :
1139+ try :
1140+ return list (primitive_value .values )
1141+ except Exception as e3 :
1142+ logger .warning (f"Failed to convert NVector: { primitive_value } , error: { e3 } " )
1143+ return str (primitive_value )
1144+
10891145 return primitive_value
0 commit comments