Skip to content

Commit 6a8fa07

Browse files
committed
[bugfix][graph db search circle bug]
1 parent a4ee39d commit 6a8fa07

File tree

5 files changed

+21
-11
lines changed

5 files changed

+21
-11
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ dist
1717
.ipynb_checkpoints
1818
zdatafront*
1919
*antgroup*
20-
*ipynb
20+
*ipynb
21+
*log

docker-compose.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version: '3.4'
1+
version: 'Beta'
22
services:
33
metad0:
44
# image: docker.io/vesoft/nebula-metad:v3.8.0
@@ -129,7 +129,7 @@ services:
129129
- 11434:11434
130130
volumes:
131131
# - //d/models/ollama:/root/.ollama # windows path
132-
- /Users/wangyunpeng/Downloads/models:/root/.ollama # linux/mac path
132+
# - /Users/wangyunpeng/Downloads/models:/root/.ollama # linux/mac path
133133
networks:
134134
- ekg-net
135135
restart: on-failure
@@ -170,4 +170,4 @@ services:
170170
networks:
171171
ekg-net:
172172
# driver: bridge
173-
# external: true
173+
external: true

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,17 +301,20 @@ def deduplicate_paths(self, result, block_attributes: List[dict] = {}, select_at
301301
# deduplicate the paths
302302
path_strs = ["&&".join(_p) for _p in p]
303303
new_p = []
304+
add_path_strs = set()
304305
for path_str, _p in zip(path_strs, p):
306+
if path_str in add_path_strs: continue
305307
if not any(path_str in other for other in path_strs if path_str != other):
306308
new_p.append(_p)
309+
add_path_strs.add(path_str)
307310

308311
# 根据保留路径进行合并
309312
nodeid2type = {i["id"]: i["type"] for i in n0+n1}
310313
unique_node_ids = [j for i in new_p for j in i]
311314
if reverse:
312-
last_node_ids = list(set([i[0] for i in new_p if len(i)>=hop]))
315+
last_node_ids = list(set([i[0] for i in new_p if len(i)>=hop and i[0] not in i[1:]]))
313316
else:
314-
last_node_ids = list(set([i[-1] for i in new_p if len(i)>=hop]))
317+
last_node_ids = list(set([i[-1] for i in new_p if len(i)>=hop and i[-1] not in i[:-1]]))
315318

316319
last_node_types = [nodeid2type[i] for i in last_node_ids]
317320
new_n0 = deduplicate_dict([i for i in n0 if i["id"] in unique_node_ids])
@@ -393,12 +396,17 @@ def decode_result(self, geabase_result, gql: str) -> Dict:
393396
def decode_path(self, col_data, k) -> List:
394397
steps = col_data.get("pathVal", {}).get("steps", [])
395398
connections = {}
396-
for step in steps:
399+
head = None
400+
path = []
401+
for idx, step in enumerate(steps):
397402
props = step["props"]
398403
start = props["original_src_id1__"].get("strVal", "") or props["original_src_id1__"].get("intVal", -1)
399404
end = props["original_dst_id2__"].get("strVal", "") or props["original_dst_id2__"].get("intVal", -1)
400405
connections[start] = end
401406

407+
head = start if idx==0 else head
408+
path = [start] if idx==0 else path
409+
402410
# 找到头部(1)
403411
for k in connections:
404412
if k not in connections.values():
@@ -409,7 +417,8 @@ def decode_path(self, col_data, k) -> List:
409417
while head in connections:
410418
head = connections[head]
411419
path.append(head)
412-
420+
if head == path[0]:
421+
break
413422
return path
414423

415424
def decode_vertex(self, col_data, k) -> Dict:

muagent/db_handler/vector_db_handler/tbase_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def insert_data_hash(
6363
self,
6464
data_list: Union[list[dict], dict],
6565
key: str = "message_index",
66-
expire_time: int = 86400,
66+
expire_time: int = None,
6767
need_etime: bool = True
6868
):
6969
'''

muagent/service/ekg_construct/ekg_construct_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,8 +630,8 @@ def get_graph_by_nodeid(
630630
) -> Graph:
631631
if hop<2:
632632
raise Exception(f"hop must be smaller than 2, now hop is {hop}")
633-
if hop >= 14:
634-
raise Exception(f"hop can't be larger than 14, now hop is {hop}")
633+
if hop >= 20:
634+
raise Exception(f"hop can't be larger than 20, now hop is {hop}")
635635
# filter the node which dont match teamid
636636
result = self.gb.get_hop_infos(
637637
{'id': nodeid}, node_type=node_type,

0 commit comments

Comments
 (0)