11from datetime import datetime
22from typing import Any , Literal
33
4+ from nebulagraph_python .value_wrapper import ValueWrapper
5+
46from memos .configs .graph_db import NebulaGraphDBConfig
57from memos .dependency import require_python_package
68from memos .graph_dbs .base import BaseGraphDB
@@ -170,7 +172,35 @@ def edge_exists(
170172
171173 # Graph Query & Reasoning
172174 def get_node (self , id : str ) -> dict [str , Any ] | None :
173- raise NotImplementedError
175+ """
176+ Retrieve a Memory node by its unique ID.
177+
178+ Args:
179+ id (str): Node ID (Memory.id)
180+
181+ Returns:
182+ dict: Node properties as key-value pairs, or None if not found.
183+ """
184+ gql = f"""
185+ USE memory_graph
186+ MATCH (v {{id: '{ id } '}})
187+ RETURN v
188+ """
189+
190+ try :
191+ result = self .client .execute (gql )
192+ record = result .one_or_none ()
193+ if record is None :
194+ return None
195+
196+ node_wrapper = record ["v" ].as_node ()
197+ props = node_wrapper .get_properties ()
198+
199+ return {key : self ._parse_node (val ) for key , val in props .items ()}
200+
201+ except Exception as e :
202+ logger .error (f"[get_node] Failed to retrieve node '{ id } ': { e } " )
203+ return None
174204
175205 def get_nodes (self , ids : list [str ]) -> list [dict [str , Any ]]:
176206 raise NotImplementedError
@@ -237,8 +267,49 @@ def search_by_embedding(
237267 - Typical use case: restrict to 'status = activated' to avoid
238268 matching archived or merged nodes.
239269 """
270+ dim = len (vector )
271+ vector_str = "," .join (f"{ float (x )} " for x in vector )
272+ gql_vector = f"VECTOR<{ dim } , FLOAT>([{ vector_str } ])"
273+
274+ where_clauses = []
275+ if scope :
276+ where_clauses .append (f'n.memory_type == "{ scope } "' )
277+ if status :
278+ where_clauses .append (f'n.status == "{ status } "' )
279+ if not self .config .use_multi_db and self .config .user_name :
280+ where_clauses .append (f'n.user_name == "{ self .config .user_name } "' )
240281
241- raise NotImplementedError
282+ where_clause = f"WHERE { ' AND ' .join (where_clauses )} " if where_clauses else ""
283+
284+ gql = f"""
285+ USE memory_graph
286+ MATCH (n@Memory)
287+ { where_clause }
288+ ORDER BY euclidean(n.embedding, { gql_vector } ) ASC
289+ APPROXIMATE
290+ LIMIT { top_k }
291+ OPTIONS {{ METRIC: L2, TYPE: IVF, NPROBE: 8 }}
292+ RETURN n.id AS id, euclidean(n.embedding, { gql_vector } ) AS score
293+ """
294+
295+ try :
296+ result = self .client .execute (gql )
297+ except Exception as e :
298+ logger .error (f"[search_by_embedding] Query failed: { e } " )
299+ return []
300+
301+ try :
302+ output = []
303+ for row in result :
304+ values = row .values ()
305+ id_val = values [0 ].as_string ()
306+ score_val = values [1 ].as_double ()
307+ if threshold is None or score_val <= threshold :
308+ output .append ({"id" : id_val , "score" : score_val })
309+ return output
310+ except Exception as e :
311+ logger .error (f"[search_by_embedding] Result parse failed: { e } " )
312+ return []
242313
243314 def get_by_metadata (self , filters : list [dict [str , Any ]]) -> list [str ]:
244315 raise NotImplementedError
@@ -356,10 +427,23 @@ def _ensure_database_exists(self):
356427 """
357428 create_graph = "CREATE GRAPH IF NOT EXISTS memory_graph TYPED MemoryGraphType"
358429 set_graph_working = "SESSION SET GRAPH memory_graph"
430+ create_vector_index = """
431+ CREATE VECTOR INDEX IF NOT EXISTS memory_vector_index
432+ ON NODE Memory::embedding
433+ OPTIONS {
434+ DIM: 3072,
435+ METRIC: L2,
436+ TYPE: IVF,
437+ NLIST: 100,
438+ TRAINSIZE: 1000
439+ }
440+ FOR memory_graph
441+ """
359442 try :
360443 self .client .execute (create_tag )
361444 self .client .execute (create_graph )
362445 self .client .execute (set_graph_working )
446+ self .client .execute (create_vector_index )
363447 logger .info ("✅ Graph `memory_graph` is now the working graph." )
364448 except Exception as e :
365449 logger .error (f"❌ Failed to create tag: { e } " )
@@ -378,5 +462,25 @@ def _create_basic_property_indexes(self) -> None:
378462 def _index_exists (self , index_name : str ) -> bool :
379463 """raise NotImplementedError"""
380464
381- def _parse_node (self , node_data : dict [str , Any ]) -> dict [str , Any ]:
382- return node_data
465+ def _parse_node (self , value : ValueWrapper ) -> Any :
466+ if value is None or value .is_null ():
467+ return None
468+ try :
469+ primitive_value = value .cast_primitive ()
470+ except Exception as e :
471+ logger .warning (f"cast_primitive failed for value: { value } , error: { e } " )
472+ try :
473+ primitive_value = value .cast ()
474+ except Exception as e2 :
475+ logger .warning (f"cast failed for value: { value } , error: { e2 } " )
476+ return str (value )
477+
478+ if isinstance (primitive_value , ValueWrapper ):
479+ return self ._parse_node (primitive_value )
480+
481+ if isinstance (primitive_value , list ):
482+ return [
483+ self ._parse_node (v ) if isinstance (v , ValueWrapper ) else v for v in primitive_value
484+ ]
485+
486+ return primitive_value
0 commit comments