Skip to content

Commit 269f4f6

Browse files
authored
feat: megamind context providers (#687)
1 parent 7d4d136 commit 269f4f6

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/rai_core/rai/agents/langchain/core/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
from .conversational_agent import State as ConversationalAgentState
1616
from .conversational_agent import create_conversational_agent
17-
from .megamind import Executor, create_megamind, get_initial_megamind_state
17+
from .megamind import (
18+
ContextProvider,
19+
Executor,
20+
create_megamind,
21+
get_initial_megamind_state,
22+
)
1823
from .react_agent import (
1924
ReActAgentState,
2025
create_react_runnable,
@@ -23,6 +28,7 @@
2328
from .tool_runner import SubAgentToolRunner, ToolRunner
2429

2530
__all__ = [
31+
"ContextProvider",
2632
"ConversationalAgentState",
2733
"Executor",
2834
"ReActAgentState",

src/rai_core/rai/agents/langchain/core/megamind.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
### NOTE (jmatejcz) this agent is still in process of testing and refining
16+
from abc import ABC, abstractmethod
1617
from dataclasses import dataclass
1718
from functools import partial
1819
from typing import (
@@ -185,6 +186,14 @@ class Executor:
185186
system_prompt: str
186187

187188

189+
class ContextProvider(ABC):
190+
"""Context provider are meant to inject exteral info to megamind prompt"""
191+
192+
@abstractmethod
193+
def get_context(self) -> str:
194+
pass
195+
196+
188197
def get_initial_megamind_state(task: str):
189198
return MegamindState(
190199
{
@@ -198,7 +207,11 @@ def get_initial_megamind_state(task: str):
198207
)
199208

200209

201-
def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindState:
210+
def plan_step(
211+
megamind_agent: BaseChatModel,
212+
state: MegamindState,
213+
context_providers: Optional[List[ContextProvider]] = None,
214+
) -> MegamindState:
202215
"""Initial planning step."""
203216
if "original_task" not in state:
204217
state["original_task"] = state["messages"][0].content[0]["text"]
@@ -208,6 +221,9 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt
208221
state["step"] = None
209222

210223
megamind_prompt = f"You are given objective to complete: {state['original_task']}"
224+
for provider in context_providers:
225+
megamind_prompt += provider.get_context()
226+
megamind_prompt += "\n"
211227
if state["steps_done"]:
212228
megamind_prompt += "\n\n"
213229
megamind_prompt += "Steps that were already done successfully:\n"
@@ -244,17 +260,27 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt
244260

245261
def create_megamind(
246262
megamind_llm: BaseChatModel,
247-
megamind_system_prompt: str,
248263
executors: List[Executor],
264+
megamind_system_prompt: Optional[str] = None,
249265
task_planning_prompt: Optional[str] = None,
266+
context_providers: List[ContextProvider] = [],
250267
) -> CompiledStateGraph:
251268
"""Create a megamind langchain agent
252269
253270
Args:
254271
executors (List[Executor]): Subagents for megamind, each can be a specialist with
255272
its own tools llm and system prompt
273+
274+
megamind_system_prompt (Optional[str]): Prompt for megamind node. If not provided
275+
it will default to informing agent of the avaialble executors and listing their tools.
276+
256277
task_planning_prompt (Optional[str]): Prompt that helps summarize the step in a way
257278
that helps planning task
279+
280+
context_providers (List[ContextProvider]): Each ContextProvider can inject external info
281+
to prompt during planning phase
282+
283+
258284
"""
259285
executor_agents = {}
260286
handoff_tools = []
@@ -295,7 +321,8 @@ def create_megamind(
295321
)
296322

297323
graph = StateGraph(MegamindState).add_node(
298-
"megamind", partial(plan_step, megamind_agent)
324+
"megamind",
325+
partial(plan_step, megamind_agent, context_providers=context_providers),
299326
)
300327
for agent_name, agent in executor_agents.items():
301328
graph.add_node(agent_name, agent)

0 commit comments

Comments
 (0)