Skip to content

Commit cc8e44a

Browse files
committed
update from internal
1 parent 17fb8ec commit cc8e44a

File tree

4 files changed

+167
-117
lines changed

4 files changed

+167
-117
lines changed

muagent/connector/memory_manager.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,8 +625,20 @@ def append_tools(self, tool_information: dict, chat_index: str, nodeid: str, use
625625
pass
626626
self.append(message)
627627

628-
def get_memory_by_tag(self, tag: str) -> Memory:
629-
return self.get_memory_pool_by_key_content(key='tag', content=f'*{tag}*')
628+
def get_memory_by_chatindex_tags(self, chat_index: str, tags: List[str], limit: int = 10) -> Memory:
629+
'''
630+
:param chat_index: str,
631+
:param tags: List[str], search message by any tag (match or)
632+
'''
633+
tags_str = '|'.join([f"*{tag}*" for tag in tags])
634+
querys = [
635+
f"@chat_index:{chat_index}",
636+
f"@role_tags:{tags_str}",
637+
]
638+
query = f"({')('.join(querys)})" if len(querys) >=2 else "".join(querys)
639+
logger.debug(f"{query}")
640+
r = self.th.search(query, limit=limit)
641+
return self.tbasedoc2Memory(r)
630642

631643
def get_memory_pool(self, chat_index: str = "") -> Memory:
632644
return self.get_memory_pool_by_all({"chat_index": chat_index})

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ def add_edges(self, edges: List[GEdge]) -> dict:
8181

8282
def update_node(self, attributes: dict, set_attributes: dict, node_type: str = None, ID: int = None) -> dict:
8383
# demo: "MATCH (n:opsgptkg_employee {@ID: xxxx}) SET n.originname = 'xxx', n.description = 'xxx'"
84-
set_str = ", ".join([f"n.{k}='{v}'" if isinstance(v, (str, bool)) else f"n.{k}={v}" for k, v in set_attributes.items()])
84+
set_str = ", ".join([
85+
f"n.{k}='{v}'" if isinstance(v, (str, bool)) else f"n.{k}={v}"
86+
for k, v in set_attributes.items()
87+
if k not in ["ID"]
88+
])
8589

8690
if (ID is None) or (not isinstance(ID, int)):
8791
ID = self.get_current_nodeID(attributes, node_type)
@@ -94,7 +98,10 @@ def update_edge(self, src_id, dst_id, set_attributes: dict, edge_type: str = Non
9498
src_id, dst_id, timestamp = self.get_current_edgeID(src_id, dst_id, edge_type)
9599
src_type, dst_type = self.get_nodetypes_by_edgetype(edge_type)
96100
# src_id, dst_id = double_hashing(src_id), double_hashing(dst_id)
97-
set_str = ", ".join([f"e.{k}='{v}'" if isinstance(v, (str, bool)) else f"e.{k}={v}" for k, v in set_attributes.items()])
101+
set_str = ", ".join([
102+
f"e.{k}='{v}'" if isinstance(v, (str, bool)) else f"e.{k}={v}"
103+
for k, v in set_attributes.items()
104+
])
98105
# demo: MATCH ()-[r:PlayFor{@src_id:1, @dst_id:100, @timestamp:0}]->() SET r.contract = 0;
99106
# gql = f"MATCH ()-[e:{edge_type}{{@src_id:{src_id}, @dst_id:{dst_id}, timestamp:{timestamp}}}]->() SET {set_str}"
100107
gql = f"MATCH (n0:{src_type} {{@id: {src_id}}})-[e]->(n1:{dst_type} {{@id:{dst_id}}}) SET {set_str}"
@@ -197,7 +204,6 @@ def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_key
197204
result = self.decode_result(result, gql)
198205
nodes = result.get("n1", []) or result.get("n1.attr", [])
199206
return self.convert2GNodes(nodes)
200-
return [GNode(id=node["id"], type=node["type"], attributes=node) for node in nodes]
201207

202208
def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
203209
#
@@ -233,22 +239,27 @@ def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, b
233239
result = {}
234240
iter_index = 0
235241
while hop > 1:
236-
if last_node_ids == []:
242+
if last_node_ids == [] and iter_index==0:
237243
#
238244
result = self.execute(gql)
239245
result = self.decode_result(result, gql)
246+
elif last_node_ids == []:
247+
pass
240248
else:
241249
for _node_id, _node_type in zip(last_node_ids, last_node_types):
242250
where_str = f"n0.id='{_node_id}'"
243-
gql = f"MATCH p = (n0:{_node_type} WHERE {where_str})-[e]->{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
251+
if reverse:
252+
gql = f"MATCH p = (n0:{_node_type} WHERE {where_str})<-[e]-{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
253+
else:
254+
gql = f"MATCH p = (n0:{_node_type} WHERE {where_str})-[e]->{{1,{min(hop, hop_max)}}}(n1) RETURN n0, n1, e, p"
244255
#
245256
_result = self.execute(gql)
246257
_result = self.decode_result(_result, gql)
247-
# logger.info(f"p_lens: {len(_result['p'])}")
258+
# logger.info(f"p_lens: {_result['p']}")
248259

249-
result = self.merge_hotinfos(result, _result)
260+
result = self.merge_hotinfos(result, _result, reverse=reverse)
250261
#
251-
last_node_ids, last_node_types, result = self.deduplicate_paths(result, block_attributes, select_attributes, hop=min(hop, hop_max)+iter_index*hop_max)
262+
last_node_ids, last_node_types, result = self.deduplicate_paths(result, block_attributes, select_attributes, hop=min(hop, hop_max)+iter_index*hop_max, reverse=reverse)
252263
hop -= hop_max
253264
iter_index += 1
254265

@@ -271,7 +282,7 @@ def get_hop_paths(self, attributes: dict, node_type: str = None, hop: int = 2, b
271282
result = self.get_hop_infos(attributes, node_type, hop, block_attributes)
272283
return result.paths
273284

274-
def deduplicate_paths(self, result, block_attributes: dict = {}, select_attributes: dict = {}, hop:int=None):
285+
def deduplicate_paths(self, result, block_attributes: dict = {}, select_attributes: dict = {}, hop:int=None, reverse=False):
275286
# 获取数据
276287
n0, n1, e, p = result["n0"], result["n1"], result["e"], result["p"]
277288
block_node_ids = [
@@ -292,39 +303,42 @@ def deduplicate_paths(self, result, block_attributes: dict = {}, select_attribut
292303
for path_str, _p in zip(path_strs, p):
293304
if not any(path_str in other for other in path_strs if path_str != other):
294305
new_p.append(_p)
295-
# # 路径去重
296-
# path_strs = ["&&".join(_p) for _p in p]
297-
# new_p = []
298-
# new_path_strs_set = set()
299-
# for path_str, _p in zip(path_strs, p):
300-
# if not any(path_str in other for other in path_strs if path_str != other):
301-
# if path_str not in new_path_strs_set and all([_pid not in block_node_ids for _pid in _p]):
302-
# new_p.append(_p)
303-
# new_path_strs_set.add(path_str)
304306

305307
# 根据保留路径进行合并
306308
nodeid2type = {i["id"]: i["type"] for i in n0+n1}
307309
unique_node_ids = [j for i in new_p for j in i]
308-
last_node_ids = list(set([i[-1] for i in new_p if len(i)>=hop]))
310+
if reverse:
311+
last_node_ids = list(set([i[0] for i in new_p if len(i)>=hop]))
312+
else:
313+
last_node_ids = list(set([i[-1] for i in new_p if len(i)>=hop]))
314+
309315
last_node_types = [nodeid2type[i] for i in last_node_ids]
310316
new_n0 = deduplicate_dict([i for i in n0 if i["id"] in unique_node_ids])
311317
new_n1 = deduplicate_dict([i for i in n1 if i["id"] in unique_node_ids])
312318
new_e = deduplicate_dict([i for i in e if i["start_id"] in unique_node_ids and i["end_id"] in unique_node_ids])
313319

314320
return last_node_ids, last_node_types, {"n0": new_n0, "n1": new_n1, "e": new_e, "p": new_p}
315321

316-
def merge_hotinfos(self, result1, result2) -> Dict:
322+
def merge_hotinfos(self, result1, result2, reverse=False) -> Dict:
317323
old_n0_sets = set([n["id"] for n in result1["n0"]])
318324
old_n1_sets = set([n["id"] for n in result1["n1"]])
319325
new_n0 = result1["n0"] + [n for n in result2["n0"] if n["id"] not in old_n0_sets]
320326
new_n1 = result1["n1"] + [n for n in result2["n1"] if n["id"] not in old_n1_sets]
321327
new_e = result1["e"] + result2["e"]
322-
new_p = result1["p"] + [
323-
p_old_1 + p_old_2[1:]
324-
for p_old_1 in result1["p"]
325-
for p_old_2 in result2["p"]
326-
if p_old_2[0] == p_old_1[-1]
327-
] # + result2["p"]
328+
if reverse:
329+
new_p = result1["p"] + [
330+
p_old_2[:-1] + p_old_1
331+
for p_old_1 in result1["p"]
332+
for p_old_2 in result2["p"]
333+
if p_old_2[-1] == p_old_1[0]
334+
] # + result2["p"]
335+
else:
336+
new_p = result1["p"] + [
337+
p_old_1 + p_old_2[1:]
338+
for p_old_1 in result1["p"]
339+
for p_old_2 in result2["p"]
340+
if p_old_2[0] == p_old_1[-1]
341+
] # + result2["p"]
328342
new_result = {"n0": new_n0, "n1": new_n1, "e": new_e, "p": new_p}
329343
return new_result
330344

0 commit comments

Comments
 (0)