Skip to content

Commit 1cab0a8

Browse files
committed
feat: ekg add new node and attributes
1 parent f2b9146 commit 1cab0a8

File tree

9 files changed

+284
-127
lines changed

9 files changed

+284
-127
lines changed

Dockerfile

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,6 @@ WORKDIR /home/user
44

55
COPY ./requirements.txt /home/user/docker_requirements.txt
66

7-
8-
# RUN apt-get update
9-
# RUN apt-get install -y iputils-ping telnetd net-tools vim tcpdump
10-
# RUN echo telnet stream tcp nowait telnetd /usr/sbin/tcpd /usr/sbin/in.telnetd /etc/inetd.conf
11-
# RUN service inetutils-inetd start
12-
# service inetutils-inetd status
13-
14-
# RUN wget https://oss-cdn.nebula-graph.com.cn/package/3.6.0/nebula-graph-3.6.0.ubuntu1804.amd64.deb
15-
# RUN dpkg -i nebula-graph-3.6.0.ubuntu1804.amd64.deb
167
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
178
RUN pip install -r /home/user/docker_requirements.txt --retries 5 --timeout 120
189

docker-compose.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,12 @@ services:
153153
USER:
154154
root
155155
ports:
156-
# - 5050:3737
157156
- 8080:8080
158157
networks:
159158
- ekg-net
159+
volumes:
160+
- ./runtime:/home/user/runtime
161+
command: ["bash", "-c", "cd /home/user/runtime && mvn package && java -jar /home/user/runtime/bootstrap/muagent-runtime.jar"]
160162

161163
ekgservice:
162164
build:

examples/ekg_examples/who_is_spy_game.py

Lines changed: 61 additions & 24 deletions
Large diffs are not rendered by default.

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,17 @@ def delete_edge(self, src_id, dst_id, edge_type: str = None) -> GbaseExecStatus:
145145

146146
def delete_edges(self, id_pairs: List, edge_type: str = None) -> GbaseExecStatus:
147147
# geabase 不支持直接根据边关系进行检索
148-
src_id, dst_id, timestamp = self.get_current_edgeID(src_id, dst_id, edge_type)
149-
# src_id, dst_id = double_hashing(src_id), double_hashing(dst_id)
150-
gql = f"MATCH ()-[e:{edge_type}{{@src_id:{src_id}, @dst_id:{dst_id}}}]->() DELETE e"
151-
gql = f"MATCH (n:opsgptkg_intent )-[r]->(t1) DELETE r"
152-
return self._get_crud_status(self.execute(gql))
148+
try:
149+
src_id, dst_id, timestamp = self.get_current_edgeID(src_id, dst_id, edge_type)
150+
# src_id, dst_id = double_hashing(src_id), double_hashing(dst_id)
151+
gql = f"MATCH ()-[e:{edge_type}{{@src_id:{src_id}, @dst_id:{dst_id}}}]->() DELETE e"
152+
gql = f"MATCH (n:opsgptkg_intent )-[r]->(t1) DELETE r"
153+
return self._get_crud_status(self.execute(gql))
154+
except Exception as e:
155+
return GbaseExecStatus(
156+
errorMessage=e,
157+
errorCode=-1,
158+
)
153159

154160
def get_nodeIDs(self, attributes: dict, node_type: str) -> List[int]:
155161
result = self.get_current_nodes(attributes, node_type)
@@ -236,22 +242,30 @@ def check_neighbor_exist(self, attributes: dict, node_type: str = None, check_at
236242
filter_result = [i for i in result if all([item in i.attributes.items() for item in check_attributes.items()])]
237243
return len(filter_result) > 0
238244

239-
def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = [], select_attributes: dict = {}, reverse=False) -> Graph:
245+
def get_hop_infos(
246+
self,
247+
attributes: dict,
248+
node_type: str = None,
249+
hop: int = 2,
250+
block_attributes: List[dict] = [],
251+
select_attributes: dict = {},
252+
reverse=False
253+
) -> Graph:
240254
'''
241-
hop >= 2, 表面需要至少两跳
255+
hop >= 1
242256
'''
243257
hop_max = self.hop_max
244258
#
245259
where_str = ' and '.join([f"n0.{k}='{v}'" for k, v in attributes.items()])
246260
if reverse:
247-
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})<-[e]-{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
261+
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})<-[e]-{{0,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
248262
else:
249-
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})-[e]->{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
263+
gql = f"MATCH p = (n0:{node_type} WHERE {where_str})-[e]->{{0,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
250264
last_node_ids, last_node_types = [], []
251265

252266
result = {}
253267
iter_index = 0
254-
while hop > 1:
268+
while hop >= 1:
255269
if last_node_ids == [] and iter_index==0:
256270
#
257271
result = self.execute(gql)
@@ -272,7 +286,8 @@ def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, b
272286

273287
result = self.merge_hotinfos(result, _result, reverse=reverse)
274288
#
275-
last_node_ids, last_node_types, result = self.deduplicate_paths(result, block_attributes, select_attributes, hop=min(hop, hop_max)+iter_index*hop_max, reverse=reverse)
289+
last_node_ids, last_node_types, result = self.deduplicate_paths(
290+
result, block_attributes, select_attributes, hop=min(hop, hop_max)+iter_index*hop_max, reverse=reverse)
276291
hop -= hop_max
277292
iter_index += 1
278293

@@ -381,6 +396,7 @@ def decode_result(self, geabase_result, gql: str) -> Dict:
381396
for col_data, rk, sk in zip(row["columns"], return_keys, save_keys):
382397
_decode_func = decode_geabase_result_func_by_key.get(sk, self.decode_attribute)
383398
# print(sk, json.dumps(col_data, ensure_ascii=False, indent=2))
399+
if col_data is None or col_data == {}: continue
384400
decode_reuslt = _decode_func(col_data, rk)
385401
if ".attr" in sk:
386402
attr_dict.setdefault(sk, {}).update(decode_reuslt)
@@ -447,8 +463,17 @@ def decode_edge(self, col_data, k) -> Dict:
447463
def _decode_edge(data):
448464
edgeVal= data.get("edgeVal", {})
449465
edge_val_json = {
450-
**{"SRCID": int(edgeVal.get("srcId", "")), "DSTID": int(edgeVal.get("dstId", "")), "type": edgeVal.get("type", "")},
451-
**{k: v.get("strVal", "") if "strVal" in v else v.get("intVal", "0") for k, v in edgeVal.get("props", {}).items()}
466+
**{
467+
"SRCID": int(edgeVal.get("srcId", "")),
468+
"DSTID": int(edgeVal.get("dstId", "")),
469+
"type": edgeVal.get("type", ""),
470+
"timestamp": int(edgeVal.get("timestamp", "1"))
471+
},
472+
**{
473+
k: v.get("strVal", "")
474+
if "strVal" in v else v.get("intVal", "0")
475+
for k, v in edgeVal.get("props", {}).items()
476+
}
452477
}
453478
# 存在业务逻辑
454479
edge_val_json["start_id"] = edge_val_json.pop("original_src_id1__")

muagent/httpapis/ekg_construct/api.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from muagent.service.ekg_construct.ekg_construct_base import EKGConstructService
1212
from muagent.schemas.apis.ekg_api_schema import *
13+
from muagent.schemas.ekg import *
1314
from muagent.service.ekg_reasoning.src.graph_search.graph_search_main import main
1415

1516

@@ -23,8 +24,39 @@ def wrapping_reponse(result, errorMessage="ok", success=0):
2324
success=success
2425
)
2526

27+
def autofill_nodes(nodes: List[GNode]):
28+
'''
29+
兼容
30+
'''
31+
new_nodes = []
32+
for node in nodes:
33+
schema = TYPE2SCHEMA.get(node.type,)
34+
node.attributes.update(node.attributes.pop("extra", {}))
35+
node_data = schema(
36+
**{**{"id": node.id, "type": node.type}, **node.attributes}
37+
)
38+
node_data = {
39+
k:v
40+
for k, v in node_data.dict().items()
41+
if k not in ["type", "ID", "id", "extra"]
42+
}
43+
new_nodes.append(GNode(**{
44+
"id": node.id,
45+
"type": node.type,
46+
"attributes": {**node_data, **node.attributes}
47+
}))
48+
return new_nodes
49+
2650
#
27-
def init_app(llm, llm_config, embeddings, ekg_construct_service: EKGConstructService, memory_manager, geabase_handler, intention_router):
51+
def init_app(
52+
llm,
53+
llm_config,
54+
embeddings,
55+
ekg_construct_service: EKGConstructService,
56+
memory_manager,
57+
geabase_handler,
58+
intention_router
59+
):
2860

2961
app = FastAPI()
3062

@@ -179,8 +211,6 @@ async def text2graph(request: EKGT2GRequest):
179211

180212

181213
# ~/ekg/graph/update
182-
# @app.post("/ekg/graph/update", response_model=EKGResponse)
183-
# async def update_graph(request: UpdateGraphRequest):
184214
@app.post("/ekg/graph/update", response_model=EKGAIResponse)
185215
async def update_graph(request: EKGFeaturesRequest):
186216
logger.info(request.features.query)
@@ -196,14 +226,17 @@ async def update_graph(request: EKGFeaturesRequest):
196226

197227
# 将 origin_nodes 和 nodes 转换为 GNode 对象
198228
origin_nodes = [GNode(**n) for n in origin_nodes]
229+
origin_nodes = autofill_nodes(origin_nodes)
199230
nodes = [GNode(**n) for n in nodes]
200-
231+
nodes = autofill_nodes(nodes)
232+
201233
# 处理 origin_edges,给每个 edge 设置 type 字段
202234
origin_edges = [
203235
GEdge(
204236
start_id=e['start_id'],
205237
end_id=e['end_id'],
206-
type=f"{nodeid2type_dict.get(e['start_id'], 'unknown')}_route_{nodeid2type_dict.get(e['end_id'], 'unknown')}", # 使用默认值 'unknown' 以防 id 不在字典中
238+
# 使用默认值 'unknown' 以防 id 不在字典中
239+
type=f"{nodeid2type_dict.get(e['start_id'], 'unknown')}_route_{nodeid2type_dict.get(e['end_id'], 'unknown')}",
207240
attributes=e.get("attributes", {})
208241
)
209242
for e in origin_edges
@@ -220,34 +253,6 @@ async def update_graph(request: EKGFeaturesRequest):
220253
for e in edges
221254
]
222255

223-
# 将 GEdge 和 GNode 对象转换回字典以保持原有 JSON 格式
224-
# origin_edges_dict = [
225-
# {
226-
# "start_id": edge.start_id,
227-
# "end_id": edge.end_id,
228-
# "type": edge.type,
229-
# "attributes": edge.attributes
230-
# }
231-
# for edge in origin_edges
232-
# ]
233-
234-
# edges_dict = [
235-
# {
236-
# "start_id": edge.start_id,
237-
# "end_id": edge.end_id,
238-
# "type": edge.type,
239-
# "attributes": edge.attributes
240-
# }
241-
# for edge in edges
242-
# ]
243-
244-
# # 更新 query 的内容,将 edges 和 originEdges 部分重新保存
245-
# request.features.query['originEdges'] = origin_edges_dict
246-
# request.features.query['edges'] = edges_dict
247-
248-
249-
# query = UpdateGraphRequest(**request.features.query)
250-
251256
# 添加预测逻辑的代码
252257
errorMessage = "ok"
253258
successCode = True
@@ -318,9 +323,6 @@ def get_graph(request: EKGFeaturesRequest):
318323
nodeid=query.nodeid, node_type=query.nodeType,
319324
hop=query.hop
320325
)
321-
322-
# nodes = graph.nodes.dict()
323-
# edges = graph.edges.dict()
324326
nodes = graph.nodes
325327
edges = graph.edges
326328
except Exception as e:

muagent/schemas/ekg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
__all__ = [
55
"EKGEdgeSchema", "EKGNodeSchema",
66
"EKGTaskNodeSchema", "EKGIntentNodeSchema", "EKGAnalysisNodeSchema", "EKGScheduleNodeSchema", "EKGPhenomenonNodeSchema",
7+
"EKGToolTypeSchema", "EKGToolSchema", "EKGAgentSchema",
78
"EKGNodeTbaseSchema", "EKGEdgeTbaseSchema", "EKGTbaseData",
89
"EKGGraphSlsSchema", "EKGSlsData",
910
"SHAPE2TYPE", "TYPE2SCHEMA",

0 commit comments

Comments
 (0)