Skip to content

Commit c9bbcb6

Browse files
committed
bugfix deduplicate_paths, and update leaf node attributes
1 parent 1a03304 commit c9bbcb6

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,22 @@ def deduplicate_paths(self, result, block_attributes: dict = {}, select_attribut
285285
for i in n0+n1
286286
if select_attributes and not all(item not in i.items() for item in select_attributes.items())
287287
]
288-
# 路径去重
288+
p = [_p for _p in p if all([_pid not in block_node_ids for _pid in _p])]
289+
# deduplicate the paths
289290
path_strs = ["&&".join(_p) for _p in p]
290291
new_p = []
291-
new_path_strs_set = set()
292292
for path_str, _p in zip(path_strs, p):
293293
if not any(path_str in other for other in path_strs if path_str != other):
294-
if path_str not in new_path_strs_set and all([_pid not in block_node_ids for _pid in _p]):
295-
new_p.append(_p)
296-
new_path_strs_set.add(path_str)
294+
new_p.append(_p)
295+
# # 路径去重
296+
# path_strs = ["&&".join(_p) for _p in p]
297+
# new_p = []
298+
# new_path_strs_set = set()
299+
# for path_str, _p in zip(path_strs, p):
300+
# if not any(path_str in other for other in path_strs if path_str != other):
301+
# if path_str not in new_path_strs_set and all([_pid not in block_node_ids for _pid in _p]):
302+
# new_p.append(_p)
303+
# new_path_strs_set.add(path_str)
297304

298305
# 根据保留路径进行合并
299306
nodeid2type = {i["id"]: i["type"] for i in n0+n1}

muagent/service/ekg_construct/ekg_construct_base.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,14 +412,35 @@ def get_node_by_id(self, nodeid: str, node_type:str = None) -> GNode:
412412
node.attributes.update(extra_attrs)
413413
return node
414414

415-
def get_graph_by_nodeid(self, nodeid: str, node_type: str, teamid: str=None, hop: int = 10) -> Graph:
416-
if hop > 14:
415+
def get_graph_by_nodeid(
416+
self,
417+
nodeid: str,
418+
node_type: str,
419+
hop: int = 10,
420+
block_attributes: dict = {}
421+
) -> Graph:
422+
if hop<2:
423+
raise Exception(f"hop must be smaller than 2, now hop is {hop}")
424+
if hop >= 14:
417425
raise Exception(f"hop can't be larger than 14, now hop is {hop}")
418426
# filter the node which dont match teamid
419-
result = self.gb.get_hop_infos({'id': nodeid}, node_type=node_type, hop=hop)
427+
result = self.gb.get_hop_infos(
428+
{'id': nodeid}, node_type=node_type,
429+
hop=hop, block_attributes=block_attributes
430+
)
431+
432+
if block_attributes:
433+
leaf_nodeids = [node.id for node in result.nodes if node.type=="opsgptkg_schedule"]
434+
else:
435+
leaf_nodeids = [path[-1] for path in result.paths if len(path)==hop+1]
436+
420437
for node in result.nodes:
421438
extra_attrs = json.loads(node.attributes.pop("extra", "{}") or "{}")
422439
node.attributes.update(extra_attrs)
440+
if node.id in leaf_nodeids:
441+
neighbor_nodes = self.gb.get_neighbor_nodes({"id": node.id}, node_type=node.type)
442+
node.attributes["cnode_nums"] = len(neighbor_nodes)
443+
423444
for edge in result.edges:
424445
extra_attrs = json.loads(edge.attributes.pop("extra", "{}") or "{}")
425446
edge.attributes.update(extra_attrs)

0 commit comments

Comments
 (0)