Skip to content

Commit 9e1b6aa

Browse files
Merge pull request #72 from codefuse-ai/bug_70
bugfix: #70 acquire llm output by CustomLLMModel class
2 parents 089b68f + d13bebc commit 9e1b6aa

File tree

8 files changed

+408
-12
lines changed

8 files changed

+408
-12
lines changed

examples/ekg_examples/who_is_spy_game.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ def test_whoisspy_datas(ekg_service, ):
286286

287287
logger.info(neighbor_nodes)
288288
logger.info(current_nodes)
289+
290+
291+
logger.info('剧本杀/谁是卧底/智能交互/开始新一轮的讨论')
292+
start_nodetype ='opsgptkg_task'
293+
start_nodeid = hash_id('剧本杀/谁是卧底/智能交互/开始新一轮的讨论')
294+
295+
neighbor_nodes = ekg_service.gb.get_neighbor_nodes(attributes={"id": start_nodeid,},
296+
node_type=start_nodetype)
297+
298+
current_nodes = ekg_service.gb.get_current_nodes(attributes={"id": start_nodeid,},
299+
node_type=start_nodetype)
300+
301+
logger.info(neighbor_nodes)
302+
logger.info(current_nodes)
303+
289304

290305

291306

muagent/codechat/code_analyzer/code_intepreter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ def get_intepretation(self, code_list):
3232
res = {}
3333
for code in code_list:
3434
message = CODE_INTERPERT_TEMPLATE.format(code=code)
35-
message = [HumanMessage(content=message)]
36-
chat_res = chat_model.predict_messages(message)
37-
content = chat_res.content
35+
# message = [HumanMessage(content=message)]
36+
# chat_res = chat_model.predict_messages(message)
37+
# content = chat_res.content
38+
content = chat_model.predict(message)
3839
res[code] = content
3940
return res
4041

muagent/codechat/code_search/cypher_generator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ def get_cypher(self, query: str):
5454
content = self.NGQL_GENERATION_PROMPT.format(schema=schema, question=query)
5555
# logger.info(content)
5656
ans = ''
57-
message = [HumanMessage(content=content)]
58-
chat_res = self.model.predict_messages(message)
59-
ans = chat_res.content
57+
# message = [HumanMessage(content=content)]
58+
# chat_res = self.model.predict_messages(message)
59+
# ans = chat_res.content
60+
self.model.predict(content)
6061

6162
ans = replace_lt_gt(ans)
6263

muagent/db_handler/graph_db_handler/nebula_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def execute_cypher(self, cypher: str, space_name: str = '',ignore_log: bool = Fa
9191

9292
if ignore_log == False:
9393
if resp.is_succeeded():
94-
logger.info(f"Successfully executed Cypher query: {cypher}")
94+
#logger.info(f"Successfully executed Cypher query: {cypher}")
95+
pass
9596

9697
else:
9798
logger.error(f"Failed to execute Cypher query: {cypher}")

muagent/service/ekg_reasoning/src/geabase_handler/geabase_handlerplus.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def get_extra_tag(self, rootNodeId = 'None', rootNodeType = 'opsgptkg_task', key
449449
oneNode = self.geabase_handler.get_current_node(attributes={"id": rootNodeId,},
450450
node_type=rootNodeType)
451451
except:
452-
logging.info('user_input_memory_tag 没有找到合适的数据, 可能原因是当前查找对象不是 opsgptkg_task 类型的节点' )
452+
logging.info('get_extra_tag 没有找到合适的数据, 可能原因是当前查找对象不是 opsgptkg_task 类型的节点' )
453453
return None
454454
# print(oneNode)
455455
if oneNode.attributes['extra'] == '':
@@ -463,6 +463,29 @@ def get_extra_tag(self, rootNodeId = 'None', rootNodeType = 'opsgptkg_task', key
463463
return None
464464
else:
465465
return extra[key]
466+
467+
def get_tag(self, rootNodeId = 'None', rootNodeType = 'opsgptkg_task', key = 'ignorememory'):
468+
# print(f'rootNodeId is {rootNodeId}, rootNodeType is {rootNodeType}')
469+
try:
470+
oneNode = self.geabase_handler.get_current_node(attributes={"id": rootNodeId,},
471+
node_type=rootNodeType)
472+
except:
473+
logging.info(f'get_tag 没有找到合适的数据, 可能原因是当前对象为 {rootNodeType} 类型的节点' )
474+
return None
475+
# print(oneNode)
476+
if key not in oneNode.attributes.keys():
477+
logging.info(f'get_tag 没有找到合适的数据, 可能原因key名错误' )
478+
return None
479+
if oneNode.attributes[key] == '':
480+
logging.info(f'get_tag 没有找到合适的数据, 可能原因为空字符串' )
481+
return None
482+
483+
return oneNode.attributes[key]
484+
485+
486+
487+
488+
466489

467490
def user_input_memory_tag(self, rootNodeId = 'None', rootNodeType = 'opsgptkg_task'):
468491
print(f'rootNodeId is {rootNodeId}, rootNodeType is {rootNodeType}')

muagent/service/ekg_reasoning/src/graph_search/geabase_search_plus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def get_tool_ancestor(self, sessionId, start_nodeid = '为什么余额宝没收
547547

548548
#查祖先节点 reverse=True
549549
neighborNodes = self.geabase_handler.get_neighbor_nodes(attributes={"id": nodeid_now,}, node_type=nodetype_now, reverse=True)
550-
print(nodeid_now, nodetype_now, neighborNodes, '=========')
550+
#print(nodeid_now, nodetype_now, neighborNodes, '=========')
551551

552552

553553
for i in range(len(neighborNodes) ):

muagent/service/ekg_reasoning/src/graph_search/graph_search_main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,11 +605,13 @@ def outputFuc(self):
605605
currentNodeType = self.gst.search_node_type(self.nodeid_in_subtree, self.currentNodeId )
606606
logging.info(f"currentNodeId is {self.currentNodeId} ")
607607
logging.info(f"currentNodeType is {currentNodeType} , currentNodeId is {self.currentNodeId}")
608-
if self.gb_handler.get_extra_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay') == 'True' or self.gb_handler.get_extra_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay') == 'Ture':
608+
if self.gb_handler.get_extra_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay') == 'True' \
609+
or self.gb_handler.get_extra_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay') == 'Ture' \
610+
or self.gb_handler.get_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay') == 'True' :
609611
outputinfo_str = self.memory_handler.get_output(self.sessionId, self.start_datetime, self.end_datetime)
610612
else:
611-
dodisplaystr = self.gb_handler.get_extra_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay')
612-
logging.info(f" 查询dodisplay字段结果为空, 结果为{dodisplaystr}")
613+
dodisplaystr = self.gb_handler.get_tag(rootNodeId = self.currentNodeId, rootNodeType = currentNodeType, key = 'dodisplay')
614+
logging.info(f" 查询dodisplay字段结果为空, 或者为{dodisplaystr},本次不对外输出")
613615
outputinfo_str = None
614616

615617

0 commit comments

Comments
 (0)