@@ -1361,89 +1361,157 @@ def get_subgraph(
13611361 r)
13621362 $$ ) as (centers agtype, neighbors agtype, rels agtype);
13631363 """
1364- query = f"""
1365- SELECT * FROM cypher('{ self .db_name } _graph', $$
1366- MATCH(center: Memory)-[r * 1..{ depth } ]->(neighbor:Memory)
1367- WHERE
1368- center.id = '{ center_id } '
1369- AND center.status = '{ center_status } '
1370- AND center.user_name = '{ user_name } '
1371- RETURN
1372- collect(DISTINCT
1373- center), collect(DISTINCT
1374- neighbor), collect(DISTINCT
1375- r)
1376- $$ ) as (centers agtype, neighbors agtype, rels agtype);
1377- """
1364+ # Use UNION ALL for better performance: separate queries for depth 1 and depth 2
1365+ if depth == 1 :
1366+ query = f"""
1367+ SELECT * FROM cypher('{ self .db_name } _graph', $$
1368+ MATCH(center: Memory)-[r]->(neighbor:Memory)
1369+ WHERE
1370+ center.id = '{ center_id } '
1371+ AND center.status = '{ center_status } '
1372+ AND center.user_name = '{ user_name } '
1373+ RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
1374+ $$ ) as (centers agtype, neighbors agtype, rels agtype);
1375+ """
1376+ else :
1377+ # For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries
1378+ query = f"""
1379+ SELECT * FROM cypher('{ self .db_name } _graph', $$
1380+ MATCH(center: Memory)-[r]->(neighbor:Memory)
1381+ WHERE
1382+ center.id = '{ center_id } '
1383+ AND center.status = '{ center_status } '
1384+ AND center.user_name = '{ user_name } '
1385+ RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
1386+ UNION ALL
1387+ MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory)
1388+ WHERE
1389+ center.id = '{ center_id } '
1390+ AND center.status = '{ center_status } '
1391+ AND center.user_name = '{ user_name } '
1392+ RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1)
1393+ $$ ) as (centers agtype, neighbors agtype, rels agtype);
1394+ """
13781395 conn = self ._get_connection ()
13791396 logger .info (f"[get_subgraph] Query: { query } " )
13801397 try :
13811398 with conn .cursor () as cursor :
13821399 cursor .execute (query )
1383- result = cursor .fetchone ()
1400+ results = cursor .fetchall ()
13841401
1385- if not result or not result [ 0 ] :
1402+ if not results :
13861403 return {"core_node" : None , "neighbors" : [], "edges" : []}
13871404
1388- # Parse center node
1389- centers_data = result [ 0 ] if result [ 0 ] else "[]"
1390- neighbors_data = result [ 1 ] if result [ 1 ] else "[]"
1391- edges_data = result [ 2 ] if result [ 2 ] else "[]"
1405+ # Merge results from all UNION ALL rows
1406+ all_centers_list = []
1407+ all_neighbors_list = []
1408+ all_edges_list = []
13921409
1393- # Parse JSON data
1394- try :
1395- # Clean ::vertex and ::edge suffixes in data
1396- if isinstance (centers_data , str ):
1397- centers_data = centers_data .replace ("::vertex" , "" )
1398- if isinstance (neighbors_data , str ):
1399- neighbors_data = neighbors_data .replace ("::vertex" , "" )
1400- if isinstance (edges_data , str ):
1401- edges_data = edges_data .replace ("::edge" , "" )
1402-
1403- centers_list = (
1404- json .loads (centers_data ) if isinstance (centers_data , str ) else centers_data
1405- )
1406- neighbors_list = (
1407- json .loads (neighbors_data )
1408- if isinstance (neighbors_data , str )
1409- else neighbors_data
1410- )
1411- edges_list = (
1412- json .loads (edges_data ) if isinstance (edges_data , str ) else edges_data
1413- )
1414- except json .JSONDecodeError as e :
1415- logger .error (f"Failed to parse JSON data: { e } " )
1416- return {"core_node" : None , "neighbors" : [], "edges" : []}
1410+ for result in results :
1411+ if not result or not result [0 ]:
1412+ continue
1413+
1414+ centers_data = result [0 ] if result [0 ] else "[]"
1415+ neighbors_data = result [1 ] if result [1 ] else "[]"
1416+ edges_data = result [2 ] if result [2 ] else "[]"
1417+
1418+ # Parse JSON data
1419+ try :
1420+ # Clean ::vertex and ::edge suffixes in data
1421+ if isinstance (centers_data , str ):
1422+ centers_data = centers_data .replace ("::vertex" , "" )
1423+ if isinstance (neighbors_data , str ):
1424+ neighbors_data = neighbors_data .replace ("::vertex" , "" )
1425+ if isinstance (edges_data , str ):
1426+ edges_data = edges_data .replace ("::edge" , "" )
1427+
1428+ centers_list = (
1429+ json .loads (centers_data )
1430+ if isinstance (centers_data , str )
1431+ else centers_data
1432+ )
1433+ neighbors_list = (
1434+ json .loads (neighbors_data )
1435+ if isinstance (neighbors_data , str )
1436+ else neighbors_data
1437+ )
1438+ edges_list = (
1439+ json .loads (edges_data ) if isinstance (edges_data , str ) else edges_data
1440+ )
1441+
1442+ # Collect data from this row
1443+ if isinstance (centers_list , list ):
1444+ all_centers_list .extend (centers_list )
1445+ if isinstance (neighbors_list , list ):
1446+ all_neighbors_list .extend (neighbors_list )
1447+ if isinstance (edges_list , list ):
1448+ all_edges_list .extend (edges_list )
1449+ except json .JSONDecodeError as e :
1450+ logger .error (f"Failed to parse JSON data: { e } " )
1451+ continue
14171452
1418- # Parse center node
1453+ # Deduplicate centers by ID
1454+ centers_dict = {}
1455+ for center_data in all_centers_list :
1456+ if isinstance (center_data , dict ) and "properties" in center_data :
1457+ center_id_key = center_data ["properties" ].get ("id" )
1458+ if center_id_key and center_id_key not in centers_dict :
1459+ centers_dict [center_id_key ] = center_data
1460+
1461+ # Parse center node (use first center)
14191462 core_node = None
1420- if centers_list and len ( centers_list ) > 0 :
1421- center_data = centers_list [ 0 ]
1463+ if centers_dict :
1464+ center_data = next ( iter ( centers_dict . values ()))
14221465 if isinstance (center_data , dict ) and "properties" in center_data :
14231466 core_node = self ._parse_node (center_data ["properties" ])
14241467
1468+ # Deduplicate neighbors by ID
1469+ neighbors_dict = {}
1470+ for neighbor_data in all_neighbors_list :
1471+ if isinstance (neighbor_data , dict ) and "properties" in neighbor_data :
1472+ neighbor_id = neighbor_data ["properties" ].get ("id" )
1473+ if neighbor_id and neighbor_id not in neighbors_dict :
1474+ neighbors_dict [neighbor_id ] = neighbor_data
1475+
14251476 # Parse neighbor nodes
14261477 neighbors = []
1427- if isinstance (neighbors_list , list ):
1428- for neighbor_data in neighbors_list :
1429- if isinstance (neighbor_data , dict ) and "properties" in neighbor_data :
1430- neighbor_parsed = self ._parse_node (neighbor_data ["properties" ])
1431- neighbors .append (neighbor_parsed )
1478+ for neighbor_data in neighbors_dict .values ():
1479+ if isinstance (neighbor_data , dict ) and "properties" in neighbor_data :
1480+ neighbor_parsed = self ._parse_node (neighbor_data ["properties" ])
1481+ neighbors .append (neighbor_parsed )
1482+
1483+ # Deduplicate edges by (source, target, type)
1484+ edges_dict = {}
1485+ for edge_group in all_edges_list :
1486+ if isinstance (edge_group , list ):
1487+ for edge_data in edge_group :
1488+ if isinstance (edge_data , dict ):
1489+ edge_key = (
1490+ edge_data .get ("start_id" , "" ),
1491+ edge_data .get ("end_id" , "" ),
1492+ edge_data .get ("label" , "" ),
1493+ )
1494+ if edge_key not in edges_dict :
1495+ edges_dict [edge_key ] = {
1496+ "type" : edge_data .get ("label" , "" ),
1497+ "source" : edge_data .get ("start_id" , "" ),
1498+ "target" : edge_data .get ("end_id" , "" ),
1499+ }
1500+ elif isinstance (edge_group , dict ):
1501+ # Handle single edge (not in a list)
1502+ edge_key = (
1503+ edge_group .get ("start_id" , "" ),
1504+ edge_group .get ("end_id" , "" ),
1505+ edge_group .get ("label" , "" ),
1506+ )
1507+ if edge_key not in edges_dict :
1508+ edges_dict [edge_key ] = {
1509+ "type" : edge_group .get ("label" , "" ),
1510+ "source" : edge_group .get ("start_id" , "" ),
1511+ "target" : edge_group .get ("end_id" , "" ),
1512+ }
14321513
1433- # Parse edges
1434- edges = []
1435- if isinstance (edges_list , list ):
1436- for edge_group in edges_list :
1437- if isinstance (edge_group , list ):
1438- for edge_data in edge_group :
1439- if isinstance (edge_data , dict ):
1440- edges .append (
1441- {
1442- "type" : edge_data .get ("label" , "" ),
1443- "source" : edge_data .get ("start_id" , "" ),
1444- "target" : edge_data .get ("end_id" , "" ),
1445- }
1446- )
1514+ edges = list (edges_dict .values ())
14471515
14481516 return self ._convert_graph_edges (
14491517 {"core_node" : core_node , "neighbors" : neighbors , "edges" : edges }
0 commit comments