Skip to content

Commit 8639800

Browse files
committed
add ekg_construct test
1 parent 53a50d9 commit 8639800

File tree

7 files changed

+449
-104
lines changed

7 files changed

+449
-104
lines changed

muagent/connector/configs/generate_prompt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,5 @@ def createMKGPrompt(conversation, schemas, language="en", **kwargs) -> str:
6060
def createText2EKGPrompt(text, language="en", **kwargs) -> str:
    """Build the text-to-EKG prompt for *text*.

    Selects the Chinese template when ``language == "zh"`` and the English
    template otherwise, passes it through ``replacePrompt`` for the ``text``
    key, fills in *text*, and returns the output of ``cleanPrompt``.

    Args:
        text: Source text inserted into the template's ``text`` placeholder.
        language: ``"zh"`` selects the Chinese template; any other value
            selects the English one.
        **kwargs: Accepted but not used by this function.

    Returns:
        The formatted prompt after ``cleanPrompt``.
    """
    template = text2EKG_prompt_zh if language == "zh" else text2EKG_prompt_en
    template = replacePrompt(template, keys=["text"])
    # No per-call import or debug dump of the full prompt: the commit removes
    # both, and the prompt can be large.
    return cleanPrompt(template.format(text=text))

muagent/db_handler/graph_db_handler/base_gb_handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def search_nodes_by_attr(self, attributes: dict) -> List[GNode]:
5858
def search_edges_by_attr(self, attributes: dict, edge_type: str = None) -> List[GEdge]:
    """Placeholder for searching edges by attribute values.

    This base implementation does nothing and returns ``None``.
    NOTE(review): presumably concrete graph handlers override it — confirm.
    """
6060

61+
def get_nodes_by_ids(self, ids: List[int]) -> List[GNode]:
    """Placeholder for fetching nodes by their ids.

    This base implementation does nothing and returns ``None``.
    NOTE(review): presumably concrete graph handlers override it — confirm.
    """
63+
6164
def get_current_node(self, attributes: dict, node_type: str = None, return_keys: list = []) -> GNode:
    """Placeholder for fetching a single node selected by *attributes*.

    This base implementation does nothing and returns ``None``.
    NOTE(review): presumably concrete graph handlers override it — confirm.
    """
6366

@@ -73,5 +76,5 @@ def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_key
7376
def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
    """Placeholder for fetching the edges around the node selected by *attributes*.

    This base implementation does nothing and returns ``None``.
    NOTE(review): presumably concrete graph handlers override it — confirm.
    """
7578

76-
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse: bool =False) -> Graph:
    """Placeholder for a multi-hop neighbourhood query.

    This base implementation does nothing and returns ``None``.
    NOTE(review): presumably ``reverse`` selects traversal against edge
    direction in concrete handlers — confirm against the implementations.
    """

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,26 @@ def get_nodeIDs(self, attributes: dict, node_type: str) -> List[int]:
132132
def get_current_nodeID(self, attributes: dict, node_type: str) -> int:
    """Return the ``"ID"`` attribute of the node selected by *attributes*.

    Args:
        attributes: Attribute filter forwarded to ``get_current_node``.
        node_type: Node type forwarded to ``get_current_node``.

    Returns:
        The value stored under ``"ID"`` in the node's attributes, or ``None``
        when that key is absent.
    """
    node = self.get_current_node(attributes, node_type)
    # The node is an object carrying an ``attributes`` mapping; the former
    # dict-style ``result.get("ID")`` fallback was unreachable and is gone.
    return node.attributes.get("ID")
136135

137136
def get_current_edgeID(self, src_id, dst_id, edeg_type:str = None):
    """Resolve an edge to a ``(srcId, dstId, timestamp)`` triple.

    Args:
        src_id: Source node id.
        dst_id: Destination node id.
        edeg_type: Edge type forwarded to ``get_current_edge`` (parameter name
            kept as-is for caller compatibility).

    Returns:
        ``(src_id, dst_id, 1)`` unchanged when both ids are already ``int``;
        otherwise the ``srcId``, ``dstId`` and ``timestamp`` attributes of the
        edge returned by ``get_current_edge``.
    """
    # Integer ids need no lookup; the third slot is the constant 1.
    if isinstance(src_id, int) and isinstance(dst_id, int):
        return src_id, dst_id, 1

    result = self.get_current_edge(src_id, dst_id, edeg_type)
    logger.debug(f"{result}")
    edge_attrs = result.attributes
    return edge_attrs.get("srcId"), edge_attrs.get("dstId"), edge_attrs.get("timestamp")
145143

146144
def get_current_node(self, attributes: dict, node_type: str = None, return_keys: list = []) -> GNode:
    """Return the first node reported by ``get_current_nodes``.

    All three arguments are forwarded unchanged. Element ``[0]`` of the
    result is returned, so an empty list result raises ``IndexError``.
    """
    matches = self.get_current_nodes(attributes, node_type, return_keys)
    return matches[0]
146+
147+
def get_nodes_by_ids(self, ids: List[int] = None) -> List[GNode]:
    """Fetch the nodes whose graph id is contained in *ids*.

    Args:
        ids: Node ids to look up. ``None`` or an empty list yields ``[]``
            without issuing a query.

    Returns:
        One ``GNode`` per entry decoded from the ``n0`` column (or from
        ``n0.attr`` when ``n0`` is empty).
    """
    # A ``None`` default avoids sharing one mutable list object across calls.
    if not ids:
        return []

    # NOTE(review): the id list is interpolated through its Python repr, so
    # string ids would be rendered with Python quoting — confirm the query
    # language accepts that form.
    where_str = f'@id in {ids}'
    gql = f"MATCH (n0 WHERE {where_str}) RETURN n0"
    result = self.decode_result(self.execute(gql, return_keys=[]), gql)
    nodes = result.get("n0", []) or result.get("n0.attr", [])
    return [GNode(id=node["id"], type=node["type"], attributes=node) for node in nodes]
148155

149156
def get_current_nodes(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GNode]:
150157
#
@@ -170,8 +177,6 @@ def get_current_edge(self, src_id, dst_id, edge_type:str = None, return_keys: li
170177

171178
edges = result.get("e", []) or result.get("e.attr", [])
172179
return [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in edges][0]
173-
174-
return result[0]
175180

176181
def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GNode]:
177182
#
@@ -184,7 +189,6 @@ def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_key
184189
result = self.decode_result(result, gql)
185190
nodes = result.get("n1", []) or result.get("n1.attr", [])
186191
return [GNode(id=node["id"], type=node["type"], attributes=node) for node in nodes]
187-
return result.get("n1", []) or result.get("n1.attr", [])
188192

189193
def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
190194
#
@@ -198,22 +202,23 @@ def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_key
198202

199203
edges = result.get("e", []) or result.get("e.attr", [])
200204
return [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in edges]
201-
return result.get("e", []) or result.get("e.attr", [])
202205

203206
def check_neighbor_exist(self, attributes: dict, node_type: str = None, check_attributes: dict = {}) -> bool:
    """Report whether any neighbour node carries every pair in *check_attributes*.

    Neighbours come from ``get_neighbor_nodes(attributes, node_type)``. A
    neighbour matches when each ``(key, value)`` pair of *check_attributes*
    is present in its ``attributes`` mapping.
    """
    neighbors = self.get_neighbor_nodes(attributes, node_type,)
    matching = []
    for neighbor in neighbors:
        present = neighbor.attributes.items()
        if all(pair in present for pair in check_attributes.items()):
            matching.append(neighbor)
    return bool(matching)
207210

208-
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}) -> Graph:
211+
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse=False) -> Graph:
209212
'''
210213
hop >= 2, 表面需要至少两跳
211214
'''
212215
hop_max = 10
213-
hop_list = []
214216
#
215217
where_str = ' and '.join([f"n0.{k}='{v}'" for k, v in attributes.items()])
216-
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})-[e]->{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
218+
if reverse:
219+
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})<-[e]-{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
220+
else:
221+
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})-[e]->{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
217222
last_node_ids, last_node_types = [], []
218223

219224
result = {}
@@ -238,19 +243,16 @@ def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, b
238243
nodes = [GNode(id=node["id"], type=node["type"], attributes=node) for node in result.get("n1", [])]
239244
edges = [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in result.get("e", [])]
240245
return Graph(nodes=nodes, edges=edges, paths=result.get("p", []))
241-
return result
242-
246+
243247
def get_hop_nodes(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[GNode]:
    """Return the ``nodes`` of the graph produced by ``get_hop_infos``.

    The four arguments are forwarded positionally and unchanged.
    """
    # NOTE(review): block_attributes defaults to a list although it is
    # annotated as dict — left untouched; confirm the intended default.
    # The former dict-style fallback return was unreachable and is dropped.
    return self.get_hop_infos(attributes, node_type, hop, block_attributes).nodes
248251

249252
def get_hop_edges(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[GEdge]:
    """Return the ``edges`` of the graph produced by ``get_hop_infos``.

    The four arguments are forwarded positionally and unchanged.
    """
    # NOTE(review): block_attributes defaults to a list although it is
    # annotated as dict — left untouched; confirm the intended default.
    # The former dict-style fallback return was unreachable and is dropped.
    return self.get_hop_infos(attributes, node_type, hop, block_attributes).edges
254256

255257
def get_hop_paths(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[str]:
256258
#
@@ -353,14 +355,28 @@ def decode_result(self, geabase_result, gql: str) -> Dict:
353355
return output
354356

355357
def decode_path(self, col_data, k) -> List:
    """Rebuild the ordered node-id sequence of one path column.

    Every step under ``col_data["pathVal"]["steps"]`` contributes one
    ``source -> destination`` link, read from its ``original_src_id1__`` and
    ``original_dst_id2__`` props. The links are then chained starting from
    the node that never appears as a destination, so the result does not
    depend on the order in which the steps are listed.

    Args:
        col_data: Decoded column value. A missing ``pathVal`` or ``steps``
            entry is treated as an empty path.
        k: Column name; not used by this method, kept for the interface.

    Returns:
        Node ids from head to tail, or ``[]`` when there are no steps. If
        the links form a cycle, the walk stops the first time a node
        repeats (the repeated node is included once more as the closing id).
    """
    def _node_id(prop):
        # An id is read from "strVal" first, falling back to "intVal".
        return prop.get("strVal", "") or prop.get("intVal", -1)

    steps = col_data.get("pathVal", {}).get("steps", [])
    connections = {}
    for step in steps:
        props = step["props"]
        connections[_node_id(props["original_src_id1__"])] = _node_id(props["original_dst_id2__"])

    if not connections:
        return []

    # Head = a source that is never a destination. With several candidates
    # the last one wins, as before; with none (pure cycle) start at the
    # first recorded source instead of failing on an unbound name.
    destinations = set(connections.values())
    heads = [src for src in connections if src not in destinations]
    head = heads[-1] if heads else next(iter(connections))

    # Follow the links; the visited set keeps a cyclic chain from looping forever.
    path = [head]
    visited = {head}
    while head in connections:
        head = connections[head]
        path.append(head)
        if head in visited:
            break
        visited.add(head)

    return path
366382

Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from .sqlalchemy_handler import SqlalchemyHandler
2-
3-
__all__ = [
4-
"SqlalchemyHandler"
1+
from .sqlalchemy_handler import SqlalchemyHandler
2+
3+
__all__ = [
4+
"SqlalchemyHandler"
55
]

0 commit comments

Comments
 (0)