Skip to content

Commit f5eae2f

Browse files
authored
fix get_subgraph (#633)
1 parent c6cabf5 commit f5eae2f

File tree

1 file changed

+134
-66
lines changed

1 file changed

+134
-66
lines changed

src/memos/graph_dbs/polardb.py

Lines changed: 134 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,89 +1361,157 @@ def get_subgraph(
13611361
r)
13621362
$$ ) as (centers agtype, neighbors agtype, rels agtype);
13631363
"""
1364-
query = f"""
1365-
SELECT * FROM cypher('{self.db_name}_graph', $$
1366-
MATCH(center: Memory)-[r * 1..{depth}]->(neighbor:Memory)
1367-
WHERE
1368-
center.id = '{center_id}'
1369-
AND center.status = '{center_status}'
1370-
AND center.user_name = '{user_name}'
1371-
RETURN
1372-
collect(DISTINCT
1373-
center), collect(DISTINCT
1374-
neighbor), collect(DISTINCT
1375-
r)
1376-
$$ ) as (centers agtype, neighbors agtype, rels agtype);
1377-
"""
1364+
# Use UNION ALL for better performance: separate queries for depth 1 and depth 2
1365+
if depth == 1:
1366+
query = f"""
1367+
SELECT * FROM cypher('{self.db_name}_graph', $$
1368+
MATCH(center: Memory)-[r]->(neighbor:Memory)
1369+
WHERE
1370+
center.id = '{center_id}'
1371+
AND center.status = '{center_status}'
1372+
AND center.user_name = '{user_name}'
1373+
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
1374+
$$ ) as (centers agtype, neighbors agtype, rels agtype);
1375+
"""
1376+
else:
1377+
# For depth >= 2, use UNION ALL to combine depth 1 and depth 2 queries
1378+
query = f"""
1379+
SELECT * FROM cypher('{self.db_name}_graph', $$
1380+
MATCH(center: Memory)-[r]->(neighbor:Memory)
1381+
WHERE
1382+
center.id = '{center_id}'
1383+
AND center.status = '{center_status}'
1384+
AND center.user_name = '{user_name}'
1385+
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r)
1386+
UNION ALL
1387+
MATCH(center: Memory)-[r]->(n:Memory)-[r1]->(neighbor:Memory)
1388+
WHERE
1389+
center.id = '{center_id}'
1390+
AND center.status = '{center_status}'
1391+
AND center.user_name = '{user_name}'
1392+
RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1)
1393+
$$ ) as (centers agtype, neighbors agtype, rels agtype);
1394+
"""
13781395
conn = self._get_connection()
13791396
logger.info(f"[get_subgraph] Query: {query}")
13801397
try:
13811398
with conn.cursor() as cursor:
13821399
cursor.execute(query)
1383-
result = cursor.fetchone()
1400+
results = cursor.fetchall()
13841401

1385-
if not result or not result[0]:
1402+
if not results:
13861403
return {"core_node": None, "neighbors": [], "edges": []}
13871404

1388-
# Parse center node
1389-
centers_data = result[0] if result[0] else "[]"
1390-
neighbors_data = result[1] if result[1] else "[]"
1391-
edges_data = result[2] if result[2] else "[]"
1405+
# Merge results from all UNION ALL rows
1406+
all_centers_list = []
1407+
all_neighbors_list = []
1408+
all_edges_list = []
13921409

1393-
# Parse JSON data
1394-
try:
1395-
# Clean ::vertex and ::edge suffixes in data
1396-
if isinstance(centers_data, str):
1397-
centers_data = centers_data.replace("::vertex", "")
1398-
if isinstance(neighbors_data, str):
1399-
neighbors_data = neighbors_data.replace("::vertex", "")
1400-
if isinstance(edges_data, str):
1401-
edges_data = edges_data.replace("::edge", "")
1402-
1403-
centers_list = (
1404-
json.loads(centers_data) if isinstance(centers_data, str) else centers_data
1405-
)
1406-
neighbors_list = (
1407-
json.loads(neighbors_data)
1408-
if isinstance(neighbors_data, str)
1409-
else neighbors_data
1410-
)
1411-
edges_list = (
1412-
json.loads(edges_data) if isinstance(edges_data, str) else edges_data
1413-
)
1414-
except json.JSONDecodeError as e:
1415-
logger.error(f"Failed to parse JSON data: {e}")
1416-
return {"core_node": None, "neighbors": [], "edges": []}
1410+
for result in results:
1411+
if not result or not result[0]:
1412+
continue
1413+
1414+
centers_data = result[0] if result[0] else "[]"
1415+
neighbors_data = result[1] if result[1] else "[]"
1416+
edges_data = result[2] if result[2] else "[]"
1417+
1418+
# Parse JSON data
1419+
try:
1420+
# Clean ::vertex and ::edge suffixes in data
1421+
if isinstance(centers_data, str):
1422+
centers_data = centers_data.replace("::vertex", "")
1423+
if isinstance(neighbors_data, str):
1424+
neighbors_data = neighbors_data.replace("::vertex", "")
1425+
if isinstance(edges_data, str):
1426+
edges_data = edges_data.replace("::edge", "")
1427+
1428+
centers_list = (
1429+
json.loads(centers_data)
1430+
if isinstance(centers_data, str)
1431+
else centers_data
1432+
)
1433+
neighbors_list = (
1434+
json.loads(neighbors_data)
1435+
if isinstance(neighbors_data, str)
1436+
else neighbors_data
1437+
)
1438+
edges_list = (
1439+
json.loads(edges_data) if isinstance(edges_data, str) else edges_data
1440+
)
1441+
1442+
# Collect data from this row
1443+
if isinstance(centers_list, list):
1444+
all_centers_list.extend(centers_list)
1445+
if isinstance(neighbors_list, list):
1446+
all_neighbors_list.extend(neighbors_list)
1447+
if isinstance(edges_list, list):
1448+
all_edges_list.extend(edges_list)
1449+
except json.JSONDecodeError as e:
1450+
logger.error(f"Failed to parse JSON data: {e}")
1451+
continue
14171452

1418-
# Parse center node
1453+
# Deduplicate centers by ID
1454+
centers_dict = {}
1455+
for center_data in all_centers_list:
1456+
if isinstance(center_data, dict) and "properties" in center_data:
1457+
center_id_key = center_data["properties"].get("id")
1458+
if center_id_key and center_id_key not in centers_dict:
1459+
centers_dict[center_id_key] = center_data
1460+
1461+
# Parse center node (use first center)
14191462
core_node = None
1420-
if centers_list and len(centers_list) > 0:
1421-
center_data = centers_list[0]
1463+
if centers_dict:
1464+
center_data = next(iter(centers_dict.values()))
14221465
if isinstance(center_data, dict) and "properties" in center_data:
14231466
core_node = self._parse_node(center_data["properties"])
14241467

1468+
# Deduplicate neighbors by ID
1469+
neighbors_dict = {}
1470+
for neighbor_data in all_neighbors_list:
1471+
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
1472+
neighbor_id = neighbor_data["properties"].get("id")
1473+
if neighbor_id and neighbor_id not in neighbors_dict:
1474+
neighbors_dict[neighbor_id] = neighbor_data
1475+
14251476
# Parse neighbor nodes
14261477
neighbors = []
1427-
if isinstance(neighbors_list, list):
1428-
for neighbor_data in neighbors_list:
1429-
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
1430-
neighbor_parsed = self._parse_node(neighbor_data["properties"])
1431-
neighbors.append(neighbor_parsed)
1478+
for neighbor_data in neighbors_dict.values():
1479+
if isinstance(neighbor_data, dict) and "properties" in neighbor_data:
1480+
neighbor_parsed = self._parse_node(neighbor_data["properties"])
1481+
neighbors.append(neighbor_parsed)
1482+
1483+
# Deduplicate edges by (source, target, type)
1484+
edges_dict = {}
1485+
for edge_group in all_edges_list:
1486+
if isinstance(edge_group, list):
1487+
for edge_data in edge_group:
1488+
if isinstance(edge_data, dict):
1489+
edge_key = (
1490+
edge_data.get("start_id", ""),
1491+
edge_data.get("end_id", ""),
1492+
edge_data.get("label", ""),
1493+
)
1494+
if edge_key not in edges_dict:
1495+
edges_dict[edge_key] = {
1496+
"type": edge_data.get("label", ""),
1497+
"source": edge_data.get("start_id", ""),
1498+
"target": edge_data.get("end_id", ""),
1499+
}
1500+
elif isinstance(edge_group, dict):
1501+
# Handle single edge (not in a list)
1502+
edge_key = (
1503+
edge_group.get("start_id", ""),
1504+
edge_group.get("end_id", ""),
1505+
edge_group.get("label", ""),
1506+
)
1507+
if edge_key not in edges_dict:
1508+
edges_dict[edge_key] = {
1509+
"type": edge_group.get("label", ""),
1510+
"source": edge_group.get("start_id", ""),
1511+
"target": edge_group.get("end_id", ""),
1512+
}
14321513

1433-
# Parse edges
1434-
edges = []
1435-
if isinstance(edges_list, list):
1436-
for edge_group in edges_list:
1437-
if isinstance(edge_group, list):
1438-
for edge_data in edge_group:
1439-
if isinstance(edge_data, dict):
1440-
edges.append(
1441-
{
1442-
"type": edge_data.get("label", ""),
1443-
"source": edge_data.get("start_id", ""),
1444-
"target": edge_data.get("end_id", ""),
1445-
}
1446-
)
1514+
edges = list(edges_dict.values())
14471515

14481516
return self._convert_graph_edges(
14491517
{"core_node": core_node, "neighbors": neighbors, "edges": edges}

0 commit comments

Comments
 (0)