|
@@ -8,8 +8,17 @@
 from a2a.types import AgentSkill, Message
 from gpt_researcher import GPTResearcher
 
+from gpt_researcher_agent.env_patch import with_local_env
 
-from beeai_sdk.a2a.extensions import TrajectoryExtensionServer, TrajectoryExtensionSpec, AgentDetail
+from beeai_sdk.a2a.extensions import (
+    TrajectoryExtensionServer,
+    TrajectoryExtensionSpec,
+    AgentDetail,
+    LLMServiceExtensionServer,
+    LLMServiceExtensionSpec,
+    EmbeddingServiceExtensionServer,
+    EmbeddingServiceExtensionSpec,
+)
 from beeai_sdk.a2a.types import RunYield
 from beeai_sdk.server import Server
 from beeai_sdk.server.context import RunContext
@@ -18,21 +27,12 @@
 
 
 @server.agent(
-    name="GPT Researcher",
+    name="GPT Researcher 2",
     documentation_url=(
         f"https://github.com/i-am-bee/beeai-platform/blob/{os.getenv('RELEASE_VERSION', 'main')}"
         "/agents/community/gpt-researcher"
     ),
-    detail=AgentDetail(
-        interaction_mode="single-turn",
-        user_greeting="What topic do you want to research?",
-        use_cases=[
-            "**Comprehensive Research** – Generates detailed reports using information from multiple sources.",
-            "**Bias Reduction** – Cross-references data from various platforms to minimize misinformation and bias.",
-            "**High Performance** – Utilizes parallelized processes for efficient and swift report generation.",
-            "**Customizable** – Offers customization options to tailor research for specific domains or tasks.",
-        ],
-    ),
+    detail=AgentDetail(interaction_mode="single-turn", user_greeting="What topic do you want to research?"),
     skills=[
         AgentSkill(
             id="deep_research",
@@ -59,37 +59,58 @@
     ],
 )
 async def gpt_researcher(
-    message: Message, context: RunContext, trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()]
+    message: Message,
+    context: RunContext,
+    trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()],
+    llm_ext: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()],
+    embedding_ext: Annotated[EmbeddingServiceExtensionServer, EmbeddingServiceExtensionSpec.single_demand()],
 ) -> AsyncGenerator[RunYield, None]:
     """
     The agent conducts in-depth local and web research using a language model to generate comprehensive reports with
     citations, aimed at delivering factual, unbiased information.
     """
-    os.environ["RETRIEVER"] = "duckduckgo"
-    os.environ["OPENAI_BASE_URL"] = os.getenv("LLM_API_BASE", "http://localhost:11434/v1")
-    os.environ["OPENAI_API_KEY"] = os.getenv("LLM_API_KEY", "dummy")
-    model = os.getenv("LLM_MODEL", "llama3.1")
-    os.environ["LLM_MODEL"] = model
-
-    class CustomLogsHandler:
-        async def send_json(self, data: dict[str, Any]) -> None:
-            if "output" not in data:
-                return
-            match data.get("type"):
-                case "logs":
-                    await context.yield_async(
-                        trajectory.trajectory_metadata(title="log", content=f"{data['output']}\n")
-                    )
-                case "report":
-                    await context.yield_async(data["output"])
-
-    if not message.parts or not (query := message.parts[-1].root.text):
-        yield "Please enter a topic or query."
-        return
-
-    researcher = GPTResearcher(query=query, report_type="research_report", websocket=CustomLogsHandler())
-    await researcher.conduct_research()
-    await researcher.write_report()
+    # Set up local environment for this request
+
+    llm_conf, embedding_conf = None, None
+    if llm_ext and llm_ext.data:
+        [llm_conf] = llm_ext.data.llm_fulfillments.values()
+
+    if embedding_ext and embedding_ext.data:
+        [embedding_conf] = embedding_ext.data.embedding_fulfillments.values()
+
+    model = llm_conf.api_model if llm_conf else os.getenv("LLM_MODEL", "dummy")
+    embedding_model = embedding_conf.api_model if embedding_conf else os.getenv("EMBEDDING_MODEL", "dummy")
+
+    env = {
+        "RETRIEVER": "duckduckgo",
+        "OPENAI_BASE_URL": llm_conf.api_base if llm_conf else os.getenv("LLM_API_BASE", "http://localhost:11434/v1"),
+        "OPENAI_API_KEY": llm_conf.api_key if llm_conf else os.getenv("LLM_API_KEY", "dummy"),
+        "LLM_MODEL": model,
+        "EMBEDDING": f"openai:{embedding_model}",
+        "FAST_LLM": f"openai:{model}",
+        "SMART_LLM": f"openai:{model}",
+    }
+    with with_local_env(env):
+
+        class CustomLogsHandler:
+            async def send_json(self, data: dict[str, Any]) -> None:
+                if "output" not in data:
+                    return
+                match data.get("type"):
+                    case "logs":
+                        await context.yield_async(
+                            trajectory.trajectory_metadata(title="log", content=f"{data['output']}\n")
+                        )
+                    case "report":
+                        await context.yield_async(data["output"])
+
+        if not message.parts or not (query := message.parts[-1].root.text):
+            yield "Please enter a topic or query."
+            return
+
+        researcher = GPTResearcher(query=query, report_type="research_report", websocket=CustomLogsHandler())
+        await researcher.conduct_research()
+        await researcher.write_report()
 
 
 def run():
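
Note: with_local_env is imported from gpt_researcher_agent.env_patch, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming it temporarily overlays os.environ and restores the previous values on exit (the module path and call signature come from the import above; the body is an assumption):

import os
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def with_local_env(env: dict[str, str]) -> Iterator[None]:
    # Sketch only: overlay os.environ for the duration of the block,
    # then restore each variable to its previous value (or unset it).
    saved = {key: os.environ.get(key) for key in env}
    os.environ.update(env)
    try:
        yield
    finally:
        for key, previous in saved.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

Scoping the overrides this way avoids leaving one request's credentials in the process environment after the request finishes, though os.environ is still process-global, so concurrent requests in the same process would still see each other's overrides while a block is active.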
|