Skip to content

Commit 6214025

Browse files
committed
feat: add right paper cite
feat: add right paper cite
1 parent 8d51aa0 commit 6214025

File tree

6 files changed

+133
-69
lines changed

6 files changed

+133
-69
lines changed

backend/.env.dev.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ MAX_RETRIES=5
1313
# E2B_API_KEY=
1414
SERVER_HOST=http://localhost:8000
1515

16+
# 使用 email 注册账号从 https://openalex.org/ 文献
17+
OPENALEX_EMAIL=
18+
1619
LOG_LEVEL=DEBUG
1720
DEBUG=true
1821
# 确保安装 Redis

backend/app/core/agents.py

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
CODER_PROMPT,
88
MODELER_PROMPT,
99
)
10-
from app.core.functions import tools
10+
from app.core.functions import coder_tools, writer_tools
1111
from app.models.model import CoderToWriter
1212
from app.models.user_output import UserOutput
1313
from app.utils.enums import CompTemplate, FormatOutPut
@@ -17,6 +17,7 @@
1717
from app.utils.redis_manager import redis_manager
1818
from app.schemas.response import SystemMessage
1919
from app.tools.base_interpreter import BaseCodeInterpreter
20+
from app.tools.openalex_scholar import OpenAlexScholar
2021

2122

2223
class Agent:
@@ -26,7 +27,7 @@ def __init__(
2627
model: LLM,
2728
max_chat_turns: int = 30, # 单个agent最大对话轮次
2829
user_output: UserOutput = None,
29-
max_memory: int = 20, # 最大记忆轮次
30+
max_memory: int = 25, # 最大记忆轮次
3031
) -> None:
3132
self.task_id = task_id
3233
self.model = model
@@ -85,7 +86,7 @@ def clear_memory(self):
8586
self.chat_history = self.chat_history[:2] + self.chat_history[-5:]
8687

8788

88-
class ModelerAgent(Agent): # 继承自Agent类而不是BaseModel
89+
class ModelerAgent(Agent): # 继承自Agent类
8990
def __init__(
9091
self,
9192
model: LLM,
@@ -168,7 +169,7 @@ async def run(self, prompt: str, subtask_title: str) -> CoderToWriter:
168169
logger.info(f"当前对话轮次: {self.current_chat_turns}")
169170
response = await self.model.chat(
170171
history=self.chat_history,
171-
tools=tools,
172+
tools=coder_tools,
172173
tool_choice="auto",
173174
agent_name=self.__class__.__name__,
174175
)
@@ -274,7 +275,7 @@ async def run(self, prompt: str, subtask_title: str) -> CoderToWriter:
274275

275276
completion_response = await self.model.chat(
276277
history=self.chat_history,
277-
tools=tools,
278+
tools=coder_tools,
278279
tool_choice="auto",
279280
agent_name=self.__class__.__name__,
280281
)
@@ -318,10 +319,12 @@ def __init__(
318319
comp_template: CompTemplate = CompTemplate,
319320
format_output: FormatOutPut = FormatOutPut.Markdown,
320321
user_output: UserOutput = None,
322+
scholar: OpenAlexScholar = None,
321323
) -> None:
322324
super().__init__(task_id, model, max_chat_turns, user_output)
323325
self.format_out_put = format_output
324326
self.comp_template = comp_template
327+
self.scholar = scholar
325328
self.system_prompt = get_writer_prompt(format_output)
326329
self.available_images: list[str] = []
327330

@@ -347,28 +350,86 @@ async def run(
347350
image_prompt = f"\n可用的图片链接列表:\n{image_list}\n请在写作时适当引用这些图片链接。"
348351
prompt = prompt + image_prompt
349352

350-
try:
351-
logger.info(f"{self.__class__.__name__}:开始:执行对话")
352-
self.current_chat_turns = 0 # 重置对话轮次计数器
353+
logger.info(f"{self.__class__.__name__}:开始:执行对话")
354+
self.current_chat_turns += 1 # 重置对话轮次计数器
353355

354-
# 更新对话历史
355-
self.append_chat_history({"role": "system", "content": self.system_prompt})
356-
self.append_chat_history({"role": "user", "content": prompt})
356+
# 更新对话历史
357+
self.append_chat_history({"role": "system", "content": self.system_prompt})
358+
self.append_chat_history({"role": "user", "content": prompt})
357359

358-
# 获取历史消息用于本次对话
359-
response = await self.model.chat(
360-
history=self.chat_history,
361-
agent_name=self.__class__.__name__,
362-
sub_title=sub_title,
363-
)
360+
# 获取历史消息用于本次对话
361+
response = await self.model.chat(
362+
history=self.chat_history,
363+
tools=writer_tools,
364+
tool_choice="auto",
365+
agent_name=self.__class__.__name__,
366+
sub_title=sub_title,
367+
)
368+
369+
if (
370+
hasattr(response.choices[0].message, "tool_calls")
371+
and response.choices[0].message.tool_calls
372+
):
373+
logger.info("检测到工具调用")
374+
tool_call = response.choices[0].message.tool_calls[0]
375+
tool_id = tool_call.id
376+
tool_call.function.name
377+
if tool_call.function.name == "search_papers":
378+
logger.info("调用工具: search_papers")
379+
await redis_manager.publish_message(
380+
self.task_id,
381+
SystemMessage(content=f"写作手调用{tool_call.function.name}工具"),
382+
)
383+
384+
query = json.loads(tool_call.function.arguments)["query"]
385+
386+
full_content = response.choices[0].message.content
387+
# 更新对话历史 - 添加助手的响应
388+
self.append_chat_history(
389+
{
390+
"role": "assistant",
391+
"content": full_content,
392+
"tool_calls": [
393+
{
394+
"id": tool_id,
395+
"type": "function",
396+
"function": {
397+
"name": "search_papers",
398+
"arguments": json.dumps({"query": query}),
399+
},
400+
}
401+
],
402+
}
403+
)
404+
405+
try:
406+
papers = self.scholar.search_papers(query)
407+
except Exception as e:
408+
logger.error(f"搜索文献失败: {str(e)}")
409+
return f"搜索文献失败: {str(e)}"
410+
# TODO: pass to frontend
411+
self.scholar.print_papers(papers)
412+
self.append_chat_history(
413+
{
414+
"role": "tool",
415+
"content": papers,
416+
"tool_call_id": tool_id,
417+
"name": "search_papers",
418+
}
419+
)
420+
next_response = await self.model.chat(
421+
history=self.chat_history,
422+
tools=writer_tools,
423+
tool_choice="auto",
424+
agent_name=self.__class__.__name__,
425+
sub_title=sub_title,
426+
)
427+
response_content = next_response.choices[0].message.content
428+
else:
364429
response_content = response.choices[0].message.content
365-
self.chat_history.append({"role": "assistant", "content": response_content})
366-
logger.info(f"{self.__class__.__name__}:完成:执行对话")
367-
return response_content
368-
except Exception as e:
369-
error_msg = f"执行过程中遇到错误: {str(e)}"
370-
logger.error(f"Agent执行失败: {str(e)}")
371-
return error_msg
430+
self.chat_history.append({"role": "assistant", "content": response_content})
431+
logger.info(f"{self.__class__.__name__}:完成:执行对话")
432+
return response_content
372433

373434
async def summarize(self) -> str:
374435
"""

backend/app/core/functions.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from typing import List, Dict, Any
2-
from semanticscholar import SemanticScholar, PaginatedResults
3-
4-
tools = [
1+
coder_tools = [
52
{
63
"type": "function",
74
"function": {
@@ -33,19 +30,22 @@
3330
# TODO: get_cites
3431

3532

36-
def search_papers(query: str) -> List[Dict[str, Any]]:
37-
"""Search for papers using a query string."""
38-
sch = SemanticScholar()
39-
results: PaginatedResults = sch.search_paper(query, limit=10)
40-
return [
41-
{
42-
"title": paper.title,
43-
"abstract": paper.abstract,
44-
"authorsName": [author.name for author in paper.authors],
45-
"citations": [citation.title for citation in paper.citations],
46-
}
47-
for paper in results
48-
]
49-
50-
5133
## writeragent tools
34+
writer_tools = [
35+
{
36+
"type": "function",
37+
"function": {
38+
"name": "search_papers",
39+
"description": "Search for papers using a query string.",
40+
"strict": True,
41+
"parameters": {
42+
"type": "object",
43+
"properties": {
44+
"query": {"type": "string", "description": "The query string"}
45+
},
46+
},
47+
"required": ["query"],
48+
"additionalProperties": False,
49+
},
50+
},
51+
]

backend/app/core/prompts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
4. The working directory is already set up, and any uploaded files are already in the current directory
2626
5. You can directly access files in the current directory without asking the user about file existence
2727
6. For data analysis tasks, if you see Excel files (.xlsx), use pandas to read them directly
28-
7. try to visualize the data , process and results using seaborn and matplotlibs
28+
7. try to visualize the data , process and results using seaborn first and then matplotlibs,be nature sci style.
2929
3030
For example:
3131
# Correct:
@@ -60,7 +60,7 @@
6060
11. 保存的图片名称需要语义化,方便用户理解
6161
12. 在生成代码时,对于包含单引号的字符串,请使用双引号包裹,避免使用转义字符
6262
13. **你尽量在较少的对话轮次内完成任务。减少反复思考的次数**
63-
14. 在求解问题和建立模型过程中,进行充分可视化
63+
14. 在求解问题和建立模型**过程中**,进行充分可视化
6464
6565
6666
Important:
@@ -89,6 +89,7 @@ def get_writer_prompt(
8989
4. 严格按照参考用户输入的格式模板以及**正确的编号顺序**
9090
5. 不需要询问用户
9191
6. 当提到图片时,请使用提供的图片列表中的文件名
92+
7. when you write,check if you need to use tools search_papers to cite.if you need, markdown Footnote e.g.[^1]
9293
"""
9394

9495

backend/app/core/workflow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from app.core.agents import WriterAgent, CoderAgent
22
from app.core.llm import LLM, simple_chat
3-
from app.models.model import CoderToWriter
43
from app.schemas.request import Problem
54
from app.schemas.response import SystemMessage
5+
from app.tools.openalex_scholar import OpenAlexScholar
66
from app.utils.log_util import logger
77
from app.utils.common_utils import create_work_dir, get_config_template
88
from app.models.user_output import UserOutput
@@ -66,6 +66,10 @@ async def execute(self, problem: Problem):
6666
timeout=3000,
6767
)
6868

69+
# Example usage
70+
71+
scholar = OpenAlexScholar(email=settings.OPENALEX_EMAIL) # 请替换为您的真实邮箱
72+
6973
await redis_manager.publish_message(
7074
self.task_id,
7175
SystemMessage(content="创建完成"),
@@ -91,6 +95,7 @@ async def execute(self, problem: Problem):
9195
model=llm_model,
9296
comp_template=problem.comp_template,
9397
format_output=problem.format_output,
98+
scholar=scholar,
9499
)
95100

96101
################################################ solution steps

backend/app/tools/openalex_scholar.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def search_papers(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
7070
# 添加邮箱参数到请求URL
7171
if self.email:
7272
params["mailto"] = self.email
73+
else:
74+
raise ValueError("配置OpenAlex邮箱获取访问文献权利")
7375

7476
# 设置请求头,包含User-Agent和邮箱信息
7577
headers = {
@@ -146,6 +148,21 @@ def search_papers(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
146148

147149
return papers
148150

151+
def print_papers(self, papers: List[Dict[str, Any]]):
152+
for paper in papers:
153+
print("\n" + "=" * 80)
154+
print(f"标题: {paper['title']}")
155+
print(f"\n摘要: {paper['abstract']}")
156+
print("\n作者:")
157+
for author in paper["authors"]:
158+
print(f"- {author['name']}")
159+
if author["institution"]:
160+
print(f" 所属机构: {author['institution']}")
161+
print(f"\n引用次数: {paper['citations_count']}")
162+
print(f"发表年份: {paper['publication_year']}")
163+
print(f"\n引用格式:\n{paper['citation_format']}")
164+
print("=" * 80)
165+
149166
def _format_citation(self, work: Dict[str, Any]) -> str:
150167
"""Format citation in a readable format."""
151168
# 获取所有作者
@@ -176,26 +193,3 @@ def _format_citation(self, work: Dict[str, Any]) -> str:
176193
citation += f" DOI: {doi}"
177194

178195
return citation
179-
180-
181-
if __name__ == "__main__":
182-
# Example usage
183-
scholar = OpenAlexScholar(email="[email protected]") # 请替换为您的真实邮箱
184-
try:
185-
papers = scholar.search_papers("machine learning")
186-
for paper in papers:
187-
print("\n" + "=" * 80)
188-
print(f"标题: {paper['title']}")
189-
print(f"\n摘要: {paper['abstract']}")
190-
print("\n作者:")
191-
for author in paper["authors"]:
192-
print(f"- {author['name']}")
193-
if author["institution"]:
194-
print(f" 所属机构: {author['institution']}")
195-
print(f"\n引用次数: {paper['citations_count']}")
196-
print(f"发表年份: {paper['publication_year']}")
197-
print(f"\n引用格式:\n{paper['citation_format']}")
198-
print("=" * 80)
199-
except Exception as e:
200-
print(f"发生错误: {e}")
201-
print("请检查您的网络连接或API访问权限。")

0 commit comments

Comments
 (0)