
Commit fae1ae0

Commit message: [bugfix][update_graph]
1 parent aa4b0cd, commit fae1ae0

File tree: 13 files changed (+399, −98 lines)


Dockerfile

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,5 @@
 From python:3.9.18-bookworm
+# FROM python:3.9-slim-bookworm

 WORKDIR /home/user

@@ -11,10 +12,11 @@ COPY ./requirements.txt /home/user/docker_requirements.txt
 # RUN service inetutils-inetd start
 # service inetutils-inetd status

-RUN wget https://oss-cdn.nebula-graph.com.cn/package/3.6.0/nebula-graph-3.6.0.ubuntu1804.amd64.deb
-RUN dpkg -i nebula-graph-3.6.0.ubuntu1804.amd64.deb
+# RUN wget https://oss-cdn.nebula-graph.com.cn/package/3.6.0/nebula-graph-3.6.0.ubuntu1804.amd64.deb
+# RUN dpkg -i nebula-graph-3.6.0.ubuntu1804.amd64.deb

 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
-RUN pip install -r /home/user/docker_requirements.txt
+RUN pip install fastapi uvicorn notebook
+# RUN pip install -r /home/user/docker_requirements.txt

 CMD ["bash"]
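With the NebulaGraph .deb install and the full requirements file commented out, the image only pins fastapi, uvicorn and notebook via pip. A quick import check (a hypothetical snippet, not part of this commit) can confirm the slimmed image still satisfies the test entrypoint:

# Hypothetical smoke check for the slimmed image: confirm the pinned
# packages import inside the container and report their versions.
import fastapi
import uvicorn
import notebook

print("fastapi", fastapi.__version__)
print("uvicorn", uvicorn.__version__)
print("notebook", notebook.__version__)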

docker-compose.yaml

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
version: '0.1'

services:
  metad0:
    image: vesoft/nebula-metad:v3.8.0
    container_name: metad0
    environment:
      USER: root
    command:
      - --meta_server_addrs=metad0:9559
      - --local_ip=metad0
      - --ws_ip=metad0
      - --port=9559
      - --ws_http_port=19559
      - --data_path=/data/meta
      - --log_dir=/logs
      - --v=0
      - --minloglevel=0
    healthcheck:
      test: ["CMD", "curl", "-sf", "http://metad0:19559/status"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 20s
    ports:
      - 9559:9559
      - 19559:19559
      - 19560
    volumes:
      - ./data/meta0:/data/meta
      - ./logs/meta0:/logs
    networks:
      - ekg-net
    restart: on-failure
    cap_add:
      - SYS_PTRACE

  storaged0:
    image: vesoft/nebula-storaged:v3.8.0
    container_name: storaged0
    environment:
      USER: root
      TZ: "${TZ}"
    command:
      - --meta_server_addrs=metad0:9559
      - --local_ip=storaged0
      - --ws_ip=storaged0
      - --port=9779
      - --ws_http_port=19779
      - --data_path=/data/storage
      - --log_dir=/logs
      - --v=0
      - --minloglevel=0
    depends_on:
      - metad0
    healthcheck:
      test: ["CMD", "curl", "-sf", "http://storaged0:19779/status"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 20s
    ports:
      - 9779:9779
      - 19779:19779
      - 19780
    volumes:
      - ./data/storage0:/data/storage
      - ./logs/storage0:/logs
    networks:
      - ekg-net
    restart: on-failure
    cap_add:
      - SYS_PTRACE

  graphd:
    image: vesoft/nebula-graphd:v3.8.0
    container_name: graphd
    environment:
      USER: root
      TZ: "${TZ}"
    command:
      - --meta_server_addrs=metad0:9559
      - --port=9669
      - --local_ip=graphd
      - --ws_ip=graphd
      - --ws_http_port=19669
      - --log_dir=/logs
      - --v=0
      - --minloglevel=0
    depends_on:
      - storaged0
    healthcheck:
      test: ["CMD", "curl", "-sf", "http://graphd:19669/status"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 20s
    ports:
      - 9669:9669
      - 19669:19669
      - 19670
    volumes:
      - ./logs/graph:/logs
    networks:
      - ekg-net
    restart: on-failure
    cap_add:
      - SYS_PTRACE

  redis-stack:
    image: redis/redis-stack:7.4.0-v0
    container_name: redis
    ports:
      - "6379:6379"
      - "8001:8001"
    volumes:
      - ./logs/redis:/var/lib/redis/logs
    networks:
      - ekg-net
    restart: always

  ollama:
    image: ollama/ollama:0.3.6
    container_name: ollama
    environment:
      USER: root
      TZ: "${TZ}"
    ports:
      - 11434:11434
    volumes:
      - //d/models/ollama:/root/.ollama  # windows path
      # - /User/models:/root/.ollama  # linux/mac path
    networks:
      - ekg-net
    restart: on-failure
    cap_add:
      - SYS_PTRACE
    # deploy:
    #   resources:
    #     reservations:
    #       devices:
    #         - driver: nvidia
    #           count: all  # or the number you want, e.g. 1
    #           capabilities: [gpu]

  ekgservice:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: ekgservice
    image: muagent:test
    environment:
      USER: root
      TZ: "${TZ}"
    ports:
      - 5050:3737
      # - 8080:8888
    volumes:
      - ./examples:/home/user/muagent/examples
      - ./muagent:/home/user/muagent/muagent
      - ./tests:/home/user/muagent/tests
    restart: on-failure
    networks:
      - ekg-net
    # command: ["python", "/home/user/muagent/examples/ekg_examples/start.py"]  # script to run
    command: ["python", "/home/user/muagent/tests/httpapis/fastapi_test.py"]  # script to run

networks:
  ekg-net:
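For reference, a minimal connectivity check against the composed services, run from the host after `docker compose up` (a sketch, not part of this commit; it assumes the `nebula3-python` and `redis` clients are installed and the default root/nebula credentials):

# Sketch: confirm graphd (published on 9669) and redis-stack (6379) are reachable.
# Assumes: pip install nebula3-python redis
from nebula3.gclient.net import ConnectionPool
from nebula3.Config import Config
import redis

pool = ConnectionPool()
pool.init([("127.0.0.1", 9669)], Config())       # NebulaGraph graphd
session = pool.get_session("root", "nebula")      # default credentials (assumed)
print(session.execute("SHOW HOSTS;"))
session.release()
pool.close()

r = redis.Redis(host="127.0.0.1", port=6379)      # redis-stack (RedisInsight UI on 8001)
print(r.ping())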

muagent/connector/memory_manager.py

Lines changed: 13 additions & 7 deletions
@@ -655,7 +655,7 @@ def get_memory_pool_by_key_content(self, key: str, content: str, ):
         r = self.th.search(content)
         return self.tbasedoc2Memory(r)

-    def get_memory_pool_by_all(self, search_key_contents: dict):
+    def get_memory_pool_by_all(self, search_key_contents: dict, limit: int = 10):
         '''
         search_key_contents:
         - key: str, key must in message keys
@@ -666,11 +666,17 @@ def get_memory_pool_by_all(self, search_key_contents: dict):
             if not v: continue
             if k == "keyword":
                 querys.append(f"@{k}:{{{v}}}")
+            elif k == "role_tags":
+                tags_str = '|'.join([f"*{tag}*" for tag in v]) if isinstance(v, list) else f"{v}"
+                querys.append(f"@role_tags:{tags_str}")
+            elif k == "start_datetime":
+                query = f"(@start_datetime:[{v[0]} {v[1]}])"
+                querys.append(query)
             else:
                 querys.append(f"@{k}:{v}")

         query = f"({')('.join(querys)})" if len(querys) >= 2 else "".join(querys)
-        r = self.th.search(query)
+        r = self.th.search(query, limit=limit)
         return self.tbasedoc2Memory(r)

     def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, chat_index: str = "default", **kwargs) -> List[Message]:
@@ -707,7 +713,7 @@ def text_retrieval(self, text: str, chat_index: str = "default", **kwargs) -> L
         return self._text_retrieval_from_cache(memory.messages, text)

     def datetime_retrieval(self, chat_index: str, datetime: str, text: str = None, n: int = 5, key: str = "start_datetime", **kwargs) -> List[Message]:
-        intput_timestamp = datefromatToTimestamp(datetime, 1)
+        intput_timestamp = dateformatToTimestamp(datetime, 1000, "%Y-%m-%d %H:%M:%S.%f")
         query = f"(@chat_index:{chat_index})(@{key}:[{intput_timestamp-n*60} {intput_timestamp+n*60}])"
         # logger.debug(f"datetime_retrieval query: {query}")
         r = self.th.search(query)
@@ -787,8 +793,8 @@ def localMessage2TbaseMessage(self, message: Message):
         # if content is not None:
         #     tbase_message["customed_kargs"][key] = content

-        tbase_message["start_datetime"] = datefromatToTimestamp(message.start_datetime, 1)
-        tbase_message["end_datetime"] = datefromatToTimestamp(message.end_datetime, 1)
+        tbase_message["start_datetime"] = dateformatToTimestamp(message.start_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f")
+        tbase_message["end_datetime"] = dateformatToTimestamp(message.end_datetime, 1000, "%Y-%m-%d %H:%M:%S.%f")

         if self.use_vector and self.embed_config:
             vector_dict = get_embedding(
@@ -830,8 +836,8 @@ def tbasedoc2Memory(self, r_docs) -> Memory:
             memory.append(message)

         for message in memory.messages:
-            message.start_datetime = timestampToDateformat(int(message.start_datetime), 1)
-            message.end_datetime = timestampToDateformat(int(message.end_datetime), 1)
+            message.start_datetime = timestampToDateformat(int(message.start_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f")
+            message.end_datetime = timestampToDateformat(int(message.end_datetime), 1000, "%Y-%m-%d %H:%M:%S.%f")

         memory.sort_by_key("end_datetime")
         # for message in memory.message:
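The new branches in get_memory_pool_by_all build RediSearch-style query fragments for role_tags and start_datetime ranges, and the datetime helpers now use millisecond timestamps with the "%Y-%m-%d %H:%M:%S.%f" format. A standalone sketch of the query construction (the sample values are made up, and _to_ms is only a stand-in for muagent's dateformatToTimestamp):

# Standalone sketch of the query strings produced by get_memory_pool_by_all.
# Inputs are hypothetical; _to_ms stands in for
# dateformatToTimestamp(value, 1000, "%Y-%m-%d %H:%M:%S.%f").
from datetime import datetime

def _to_ms(dt_str: str, fmt: str = "%Y-%m-%d %H:%M:%S.%f") -> int:
    return int(datetime.strptime(dt_str, fmt).timestamp() * 1000)

search_key_contents = {
    "keyword": "deploy",
    "role_tags": ["user", "assistant"],
    "start_datetime": [_to_ms("2024-08-01 00:00:00.000000"),
                       _to_ms("2024-08-02 00:00:00.000000")],
}

querys = []
for k, v in search_key_contents.items():
    if not v:
        continue
    if k == "keyword":
        querys.append(f"@{k}:{{{v}}}")
    elif k == "role_tags":
        tags_str = "|".join([f"*{tag}*" for tag in v]) if isinstance(v, list) else f"{v}"
        querys.append(f"@role_tags:{tags_str}")
    elif k == "start_datetime":
        querys.append(f"(@start_datetime:[{v[0]} {v[1]}])")
    else:
        querys.append(f"@{k}:{v}")

query = f"({')('.join(querys)})" if len(querys) >= 2 else "".join(querys)
print(query)
# e.g. (@keyword:{deploy})(@role_tags:*user*|*assistant*)((@start_datetime:[... ...]))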

muagent/connector/schema/message.py

Lines changed: 3 additions & 3 deletions
@@ -79,9 +79,9 @@ def check_datetime(cls, values):
         start_datetime = values.get("start_datetime")
         end_datetime = values.get("end_datetime")
         if start_datetime is None:
-            values["start_datetime"] = getCurrentDatetime()
+            values["start_datetime"] = getCurrentDatetime("%Y-%m-%d %H:%M:%S.%f")
         if end_datetime is None:
-            values["end_datetime"] = getCurrentDatetime()
+            values["end_datetime"] = getCurrentDatetime("%Y-%m-%d %H:%M:%S.%f")
         return values

     @root_validator(pre=True)
@@ -99,7 +99,7 @@ def check_message_index(cls, values):
     def update_attribute(self, key: str, value):
         if hasattr(self, key):
             setattr(self, key, value)
-            self.end_datetime = getCurrentDatetime()
+            self.end_datetime = getCurrentDatetime("%Y-%m-%d %H:%M:%S.%f")
         else:
             raise AttributeError(f"{key} is not a valid property of {self.__class__.__name__}")

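getCurrentDatetime now receives an explicit microsecond format so that Message timestamps round-trip with the millisecond conversions in memory_manager.py. The real helper lives in muagent's utilities; the sketch below only assumes its likely behaviour:

# Assumed behaviour of getCurrentDatetime(fmt): format "now" with the given
# pattern. This is a stand-in sketch, not muagent's implementation.
from datetime import datetime

def getCurrentDatetime(fmt: str = "%Y-%m-%d %H:%M:%S.%f") -> str:
    return datetime.now().strftime(fmt)

print(getCurrentDatetime("%Y-%m-%d %H:%M:%S.%f"))  # e.g. 2024-08-15 10:21:03.512345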

muagent/db_handler/graph_db_handler/base_gb_handler.py

Lines changed: 1 addition & 1 deletion
@@ -64,5 +64,5 @@ def get_neighbor_nodes(self, attributes: dict, node_type: str = None, return_key
     def get_neighbor_edges(self, attributes: dict, node_type: str = None, return_keys: list = []) -> List[GEdge]:
         pass

-    def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse=False) -> Graph:
+    def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = {}, select_attributes: dict = {}, reverse=False) -> Graph:
         pass

muagent/db_handler/graph_db_handler/geabase_handler.py

Lines changed: 7 additions & 6 deletions
@@ -223,7 +223,7 @@ def check_neighbor_exist(self, attributes: dict, node_type: str = None, check_at
         filter_result = [i for i in result if all([item in i.attributes.items() for item in check_attributes.items()])]
         return len(filter_result) > 0

-    def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = {}, select_attributes: dict = {}, reverse=False) -> Graph:
+    def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = [], select_attributes: dict = {}, reverse=False) -> Graph:
         '''
         hop >= 2 indicates that at least two hops are required
         '''
@@ -267,30 +267,31 @@ def get_hop_infos(self, attributes: dict, node_type: str = None, hop: int = 2, b
         edges = self.convert2GEdges(result.get("e", []))
         return Graph(nodes=nodes, edges=edges, paths=result.get("p", []))

-    def get_hop_nodes(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[GNode]:
+    def get_hop_nodes(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = []) -> List[GNode]:
         #
         result = self.get_hop_infos(attributes, node_type, hop, block_attributes)
         return result.nodes

-    def get_hop_edges(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[GEdge]:
+    def get_hop_edges(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = []) -> List[GEdge]:
         #
         result = self.get_hop_infos(attributes, node_type, hop, block_attributes)
         return result.edges

-    def get_hop_paths(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: dict = []) -> List[str]:
+    def get_hop_paths(self, attributes: dict, node_type: str = None, hop: int = 2, block_attributes: List[dict] = []) -> List[str]:
         #
         result = self.get_hop_infos(attributes, node_type, hop, block_attributes)
         return result.paths

-    def deduplicate_paths(self, result, block_attributes: dict = {}, select_attributes: dict = {}, hop: int = None, reverse=False):
+    def deduplicate_paths(self, result, block_attributes: List[dict] = {}, select_attributes: dict = {}, hop: int = None, reverse=False):
         # fetch the data
         n0, n1, e, p = result["n0"], result["n1"], result["e"], result["p"]
         block_node_ids = [
             i["id"]
             for i in n0+n1
+            for block_attribute in block_attributes
             # this also takes effect when block is empty, which is a valid case
             # if block_attributes=={} or all(item in i.items() for item in block_attributes.items())
-            if block_attributes and all(item in i.items() for item in block_attributes.items())
+            if block_attribute and all(item in i.items() for item in block_attribute.items())
         ] + [
             i["id"]
             for i in n0+n1
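block_attributes is now a list of attribute dicts, and the added `for block_attribute in block_attributes` clause means deduplicate_paths blocks a node whenever it matches any one of those dicts. A self-contained illustration of that semantics (the node data and attribute values below are made up):

# Sketch of the new block_attributes semantics: a node id is collected if the
# node matches *any* of the attribute dicts. All data here is hypothetical.
nodes = [
    {"id": "n1", "node_type": "task", "name": "a"},
    {"id": "n2", "node_type": "schedule", "name": "b"},
    {"id": "n3", "node_type": "task", "name": "c"},
]
block_attributes = [{"node_type": "schedule"}, {"name": "c"}]

block_node_ids = [
    i["id"]
    for i in nodes
    for block_attribute in block_attributes
    if block_attribute and all(item in i.items() for item in block_attribute.items())
]
print(block_node_ids)  # ['n2', 'n3']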

muagent/httpapis/ekg_construct/api.py

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ async def embedding_predict(request: EmbeddingsRequest):

 def create_api(llm, embeddings):
     app = init_app(llm, embeddings)
-    uvicorn.run(app, host="localhost", port=3737)
+    uvicorn.run(app, host="127.0.0.1", port=3737)

 # def create_api(ekg_construct_service: EKGConstructService):
 #     app = init_app(ekg_construct_service)