Skip to content

Commit a86ecc7

Browse files
committed
update ekg construct graph test
1 parent b7e6905 commit a86ecc7

File tree

5 files changed

+201
-98
lines changed

5 files changed

+201
-98
lines changed

muagent/db_handler/graph_db_handler/base_gb_handler.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,68 +13,56 @@ class GBHandler:
1313
def __init__(self) -> None:
1414
pass
1515

16-
def add_node(self, node: GNode):
16+
def add_node(self, node: GNode) -> dict:
1717
return self.add_nodes([node])
1818

19-
def add_nodes(self, nodes: List[GNode]):
19+
def add_nodes(self, nodes: List[GNode]) -> dict:
2020
pass
2121

22-
def add_edge(self, edge: GEdge):
22+
def add_edge(self, edge: GEdge) -> dict:
2323
return self.add_edges([edge])
2424

25-
def add_edges(self, edges: List[GEdge]):
25+
def add_edges(self, edges: List[GEdge]) -> dict:
2626
pass
2727

28-
def update_node(self, attributes: dict, set_attributes: dict, node_type: str = None, ID: int = None):
28+
def update_node(self, attributes: dict, set_attributes: dict, node_type: str = None, ID: int = None) -> dict:
2929
pass
3030

31-
def update_edge(self, src_id, dst_id, set_attributes: dict, edge_type: str = None):
31+
def update_edge(self, src_id, dst_id, set_attributes: dict, edge_type: str = None) -> dict:
3232
pass
3333

34-
def delete_node(self, attributes: dict, node_type: str = None, ID: int = None):
34+
def delete_node(self, attributes: dict, node_type: str = None, ID: int = None) -> dict:
3535
pass
3636

37-
def delete_nodes(self, attributes: dict, node_type: str = None, ID: int = None):
37+
def delete_nodes(self, attributes: dict, node_type: str = None, IDs: List[int] = []) -> dict:
3838
pass
3939

40-
def delete_edge(self, src_id, dst_id, edge_type: str = None):
40+
def delete_edge(self, src_id, dst_id, edge_type: str = None) -> dict:
4141
pass
42-
43-
def delete_edges(self, src_id, dst_id, edge_type: str = None):
44-
pass
45-
46-
def search_node_by_nodeid(self, nodeid: str, node_type: str = None) -> GNode:
47-
pass
48-
49-
def search_edges_by_nodeid(self, nodeid: str, node_type: str = None) -> List[GEdge]:
50-
pass
51-
52-
def search_edge_by_nodeids(self, start_id: str, end_id: str, edge_type: str = None) -> GEdge:
53-
pass
54-
55-
def search_nodes_by_attr(self, attributes: dict) -> List[GNode]:
56-
pass
57-
58-
def search_edges_by_attr(self, attributes: dict, edge_type: str = None) -> List[GEdge]:
42+
43+
def delete_edges(self, id_pairs: List, edge_type: str = None):
5944
pass
6045

61-
def get_nodes_by_ids(self, ids: List[int]) -> List[GNode]:
46+
def get_nodeIDs(self, attributes: dict, node_type: str) -> List[int]:
6247
pass
6348

6449
def get_current_node(self, attributes: dict, node_type: str = None, return_keys: list = []) -> GNode:
6550
pass
51+
52+
def get_nodes_by_ids(self, ids: List[int] = []) -> List[GNode]:
53+
pass
6654

6755
def get_current_nodes(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GNode]:
6856
pass
6957

7058
def get_current_edge(self, src_id, dst_id, edge_type:str = None, return_keys: list = []) -> GEdge:
7159
pass
7260

73-
def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GNode]:
61+
def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_keys: list = [], reverse=False) -> List[GNode]:
7462
pass
7563

7664
def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
7765
pass
7866

79-
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse: bool =False) -> Graph:
80-
pass
67+
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse=False) -> Graph:
68+
pass

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ def add_nodes(self, nodes: List[GNode]) -> dict:
4646
node_str_list = []
4747
for node in nodes:
4848
node_type = node.type
49-
node_attributes = {"@id": double_hashing(node.id), "id": node.id}
49+
node_attributes = {"id": node.id}
50+
node_attributes["@id"] = node.attributes.pop("ID", "") or double_hashing(node.id)
5051
node_attributes.update(node.attributes)
51-
# _ = node_attributes.pop("type")
52-
# logger.debug(f"{node_attributes}")
53-
node_str = ", ".join([f"{k}: '{v}'" if isinstance(v, str) else f"{k}: {v}" for k, v in node_attributes.items()])
52+
53+
node_str = ", ".join([f"{k}: '{v}'" if isinstance(v, (str, bool)) else f"{k}: {v}" for k, v in node_attributes.items()])
5454
node_str_list.append(f"(:{node_type} {{{node_str}}})")
5555

5656
gql = f"INSERT {','.join(node_str_list)}"
@@ -64,19 +64,21 @@ def add_edges(self, edges: List[GEdge]) -> dict:
6464
edge_str_list = []
6565
for edge in edges:
6666
edge_type = edge.type
67-
src_id, dst_id = double_hashing(edge.start_id,), double_hashing(edge.end_id,)
68-
edge_attributes = {"@src_id": src_id, "@dst_id": dst_id}
67+
edge_attributes = {
68+
"@src_id": edge.attributes.pop("SRCID", 0) or double_hashing(edge.start_id,),
69+
"@dst_id": edge.attributes.pop("DSTID", 0) or double_hashing(edge.end_id,)
70+
}
6971
edge_attributes.update(edge.attributes)
70-
# _ = edge_attributes.pop("type")
71-
edge_str = ", ".join([f"{k}: '{v}'" if isinstance(v, str) else f"{k}: {v}" for k, v in edge_attributes.items()])
72+
73+
edge_str = ", ".join([f"{k}: '{v}'" if isinstance(v, (str, bool)) else f"{k}: {v}" for k, v in edge_attributes.items()])
7274
edge_str_list.append(f"()-[:{edge_type} {{{edge_str}}}]->()")
7375

7476
gql = f"INSERT {','.join(edge_str_list)}"
7577
return self.execute(gql)
7678

7779
def update_node(self, attributes: dict, set_attributes: dict, node_type: str = None, ID: int = None) -> dict:
7880
# demo: "MATCH (n:opsgptkg_employee {@ID: xxxx}) SET n.originname = 'xxx', n.description = 'xxx'"
79-
set_str = ", ".join([f"n.{k}='{v}'" if isinstance(v, str) else f"n.{k}={v}" for k, v in set_attributes.items()])
81+
set_str = ", ".join([f"n.{k}='{v}'" if isinstance(v, (str, bool)) else f"n.{k}={v}" for k, v in set_attributes.items()])
8082

8183
if (ID is None) or (not isinstance(ID, int)):
8284
ID = self.get_current_nodeID(attributes, node_type)
@@ -89,7 +91,7 @@ def update_edge(self, src_id, dst_id, set_attributes: dict, edge_type: str = Non
8991
src_id, dst_id, timestamp = self.get_current_edgeID(src_id, dst_id, edge_type)
9092
src_type, dst_type = self.get_nodetypes_by_edgetype(edge_type)
9193
# src_id, dst_id = double_hashing(src_id), double_hashing(dst_id)
92-
set_str = ", ".join([f"e.{k}='{v}'" if isinstance(v, str) else f"e.{k}={v}" for k, v in set_attributes.items()])
94+
set_str = ", ".join([f"e.{k}='{v}'" if isinstance(v, (str, bool)) else f"e.{k}={v}" for k, v in set_attributes.items()])
9395
# demo: MATCH ()-[r:PlayFor{@src_id:1, @dst_id:100, @timestamp:0}]->() SET r.contract = 0;
9496
# gql = f"MATCH ()-[e:{edge_type}{{@src_id:{src_id}, @dst_id:{dst_id}, timestamp:{timestamp}}}]->() SET {set_str}"
9597
gql = f"MATCH (n0:{src_type} {{@id: {src_id}}})-[e]->(n1:{dst_type} {{@id:{dst_id}}}) SET {set_str}"
@@ -102,7 +104,7 @@ def delete_node(self, attributes: dict, node_type: str = None, ID: int = None) -
102104
gql = f"MATCH (n:{node_type}) WHERE n.@ID={ID} DELETE n"
103105
return self.execute(gql)
104106

105-
def delete_nodes(self, attributes: dict, node_type: str = None, IDs: List[int] = None) -> dict:
107+
def delete_nodes(self, attributes: dict, node_type: str = None, IDs: List[int] = []) -> dict:
106108
if (IDs is None) or len(IDs)==0:
107109
IDs = self.get_nodeIDs(attributes, node_type)
108110
# ID = double_hashing(ID)
@@ -138,7 +140,7 @@ def get_current_edgeID(self, src_id, dst_id, edeg_type:str = None):
138140
if not isinstance(src_id, int) or not isinstance(dst_id, int):
139141
result = self.get_current_edge(src_id, dst_id, edeg_type)
140142
logger.debug(f"{result}")
141-
return result.attributes.get("srcId"), result.attributes.get("dstId"), result.attributes.get("timestamp")
143+
return result.attributes.get("SRCID"), result.attributes.get("DSTID"), result.attributes.get("timestamp")
142144
else:
143145
return src_id, dst_id, 1
144146

@@ -164,7 +166,8 @@ def get_current_nodes(self, attributes: dict, node_type: str = None, return_keys
164166
result = self.execute(gql, return_keys=return_keys)
165167
result = self.decode_result(result, gql)
166168

167-
nodes = result.get("n0", []) or result.get("n0.attr", [])
169+
nodes = result.get("n0", []) or result.get("n0.attr", [])
170+
return self.convert2GNodes(nodes)
168171
return [GNode(id=node["id"], type=node["type"], attributes=node) for node in nodes]
169172

170173
def get_current_edge(self, src_id, dst_id, edge_type:str = None, return_keys: list = []) -> GEdge:
@@ -177,6 +180,7 @@ def get_current_edge(self, src_id, dst_id, edge_type:str = None, return_keys: li
177180
result = self.decode_result(result, gql)
178181

179182
edges = result.get("e", []) or result.get("e.attr", [])
183+
return self.convert2GEdges(edges)[0]
180184
return [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in edges][0]
181185

182186
def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_keys: list = [], reverse=False) -> List[GNode]:
@@ -192,6 +196,7 @@ def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_key
192196
result = self.execute(gql, return_keys=return_keys)
193197
result = self.decode_result(result, gql)
194198
nodes = result.get("n1", []) or result.get("n1.attr", [])
199+
return self.convert2GNodes(nodes)
195200
return [GNode(id=node["id"], type=node["type"], attributes=node) for node in nodes]
196201

197202
def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
@@ -205,6 +210,7 @@ def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_key
205210
result = self.decode_result(result, gql)
206211

207212
edges = result.get("e", []) or result.get("e.attr", [])
213+
return self.convert2GEdges(edges)
208214
return [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in edges]
209215

210216
def check_neighbor_exist(self, attributes: dict, node_type: str = None, check_attributes: dict = {}) -> bool:
@@ -244,8 +250,10 @@ def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, b
244250
last_node_ids, last_node_types, result = self.deduplicate_paths(result, block_attributes, select_attributes)
245251
hop -= hop_max
246252

247-
nodes = [GNode(id=node["id"], type=node["type"], attributes=node) for node in result.get("n1", [])]
248-
edges = [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in result.get("e", [])]
253+
nodes = self.convert2GNodes(result.get("n1", []))
254+
edges = self.convert2GEdges(result.get("e", []))
255+
# nodes = [GNode(id=node["id"], type=node["type"], attributes=node) for node in result.get("n1", [])]
256+
# edges = [GEdge(start_id=edge["start_id"], end_id=edge["end_id"], type=edge["type"], attributes=edge) for edge in result.get("e", [])]
249257
return Graph(nodes=nodes, edges=edges, paths=result.get("p", []))
250258

251259
def get_hop_nodes(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[GNode]:
@@ -387,16 +395,17 @@ def decode_path(self, col_data, k) -> List:
387395
def decode_vertex(self, col_data, k) -> Dict:
388396
vertextVal = col_data.get("vertexVal", {})
389397
node_val_json = {
390-
**{"ID": vertextVal.get("id", ""), "type": vertextVal.get("type", "")},
398+
**{"ID": int(vertextVal.get("id", "")), "type": vertextVal.get("type", "")},
391399
**{k: v.get("strVal", "") or v.get("intVal", "0") for k, v in vertextVal.get("props", {}).items()}
392400
}
401+
node_val_json.pop("biz_node_id", "")
393402
return node_val_json
394403

395404
def decode_edge(self, col_data, k) -> Dict:
396405
def _decode_edge(data):
397406
edgeVal= data.get("edgeVal", {})
398407
edge_val_json = {
399-
**{"srcId": edgeVal.get("srcId", ""), "dstId": edgeVal.get("dstId", ""), "type": edgeVal.get("type", "")},
408+
**{"SRCID": int(edgeVal.get("srcId", "")), "DSTID": int(edgeVal.get("dstId", "")), "type": edgeVal.get("type", "")},
400409
**{k: v.get("strVal", "") or v.get("intVal", "0") for k, v in edgeVal.get("props", {}).items()}
401410
}
402411
# contains business-specific logic
@@ -422,4 +431,21 @@ def get_nodetypes_by_edgetype(self, edge_type: str):
422431
if edge_bridge in edge_type:
423432
src_type, dst_type = edge_type.split(edge_bridge)
424433
break
425-
return src_type, dst_type
434+
return src_type, dst_type
435+
436+
def convert2GNodes(self, raw_nodes: List[Dict]) -> List[GNode]:
437+
nodes = []
438+
for node in raw_nodes:
439+
node_id = node.pop("id")
440+
node_type = node.pop("type")
441+
nodes.append(GNode(id=node_id, type=node_type, attributes=node))
442+
return nodes
443+
444+
def convert2GEdges(self, raw_edges: List[Dict]) -> List[GEdge]:
445+
edges = []
446+
for edge in raw_edges:
447+
start_id = edge.pop("start_id")
448+
end_id = edge.pop("end_id")
449+
edge_type = edge.pop("type")
450+
edges.append(GEdge(start_id=start_id, end_id=end_id, type=edge_type, attributes=edge))
451+
return edges

muagent/schemas/ekg/ekg_graph.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ class NodeSchema(BaseModel):
3131

3232
class EdgeSchema(BaseModel):
3333
# entity_id, ekg_node:{graph_id}:{node_type}:{content_md5}
34-
src_id: int = None
34+
SRCID: int = None
3535
original_src_id1__: str
3636
# entity_id, ekg_node:{graph_id}:{node_type}:{content_md5}
37-
dst_id: int = None
37+
DSTID: int = None
3838
original_dst_id2__: str
3939
#
4040
timestamp: int
@@ -107,10 +107,10 @@ def attrbutes(self, ):
107107
}
108108

109109
class EKGTaskNodeSchema(EKGNodeSchema):
110-
# tool: str
111-
# needCheck: bool
110+
tool: str
111+
needcheck: bool
112112
# when to access
113-
accessCriteria: str
113+
accesscriteria: str
114114
#
115115
# owner: str
116116

@@ -121,19 +121,21 @@ def attrbutes(self, ):
121121
"name": self.name,
122122
"description": self.description,
123123
"teamids": self.teamids,
124-
"accessCriteria": self.accessCriteria
124+
"accesscriteria": self.accesscriteria,
125+
"needcheck": self.needcheck,
126+
"tool": self.tool
125127
},
126128
**extra_attr
127129
}
128130

129131

130132
class EKGAnalysisNodeSchema(EKGNodeSchema):
131133
# when to access
132-
accessCriteria: str
134+
accesscriteria: str
133135
# do summary or not
134-
summarySwtich: bool
136+
summaryswtich: bool
135137
# summary template
136-
dslTemplate: str
138+
dsltemplate: str
137139

138140
def attrbutes(self, ):
139141
extra_attr = json.loads(self.extra)
@@ -142,9 +144,9 @@ def attrbutes(self, ):
142144
"name": self.name,
143145
"description": self.description,
144146
"teamids": self.teamids,
145-
"accessCriteria": self.accessCriteria,
146-
"summarySwtich": self.summarySwtich,
147-
"dslTemplate": self.dslTemplate
147+
"accesscriteria": self.accesscriteria,
148+
"summaryswtich": self.summaryswtich,
149+
"dsltemplate": self.dsltemplate
148150
},
149151
**extra_attr
150152
}

0 commit comments

Comments
 (0)