13
13
# limitations under the License.
14
14
15
15
### NOTE (jmatejcz) this agent is still in process of testing and refining
16
+ from abc import ABC , abstractmethod
16
17
from dataclasses import dataclass
17
18
from functools import partial
18
19
from typing import (
@@ -185,6 +186,14 @@ class Executor:
185
186
system_prompt : str
186
187
187
188
189
+ class ContextProvider (ABC ):
190
+ """Context provider are meant to inject exteral info to megamind prompt"""
191
+
192
+ @abstractmethod
193
+ def get_context (self ) -> str :
194
+ pass
195
+
196
+
188
197
def get_initial_megamind_state (task : str ):
189
198
return MegamindState (
190
199
{
@@ -198,7 +207,11 @@ def get_initial_megamind_state(task: str):
198
207
)
199
208
200
209
201
- def plan_step (megamind_agent : BaseChatModel , state : MegamindState ) -> MegamindState :
210
+ def plan_step (
211
+ megamind_agent : BaseChatModel ,
212
+ state : MegamindState ,
213
+ context_providers : Optional [List [ContextProvider ]] = None ,
214
+ ) -> MegamindState :
202
215
"""Initial planning step."""
203
216
if "original_task" not in state :
204
217
state ["original_task" ] = state ["messages" ][0 ].content [0 ]["text" ]
@@ -208,6 +221,9 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt
208
221
state ["step" ] = None
209
222
210
223
megamind_prompt = f"You are given objective to complete: { state ['original_task' ]} "
224
+ for provider in context_providers :
225
+ megamind_prompt += provider .get_context ()
226
+ megamind_prompt += "\n "
211
227
if state ["steps_done" ]:
212
228
megamind_prompt += "\n \n "
213
229
megamind_prompt += "Steps that were already done successfully:\n "
@@ -244,17 +260,27 @@ def plan_step(megamind_agent: BaseChatModel, state: MegamindState) -> MegamindSt
244
260
245
261
def create_megamind (
246
262
megamind_llm : BaseChatModel ,
247
- megamind_system_prompt : str ,
248
263
executors : List [Executor ],
264
+ megamind_system_prompt : Optional [str ] = None ,
249
265
task_planning_prompt : Optional [str ] = None ,
266
+ context_providers : List [ContextProvider ] = [],
250
267
) -> CompiledStateGraph :
251
268
"""Create a megamind langchain agent
252
269
253
270
Args:
254
271
executors (List[Executor]): Subagents for megamind, each can be a specialist with
255
272
its own tools llm and system prompt
273
+
274
+ megamind_system_prompt (Optional[str]): Prompt for megamind node. If not provided
275
+ it will default to informing agent of the avaialble executors and listing their tools.
276
+
256
277
task_planning_prompt (Optional[str]): Prompt that helps summarize the step in a way
257
278
that helps planning task
279
+
280
+ context_providers (List[ContextProvider]): Each ContextProvider can inject external info
281
+ to prompt during planning phase
282
+
283
+
258
284
"""
259
285
executor_agents = {}
260
286
handoff_tools = []
@@ -295,7 +321,8 @@ def create_megamind(
295
321
)
296
322
297
323
graph = StateGraph (MegamindState ).add_node (
298
- "megamind" , partial (plan_step , megamind_agent )
324
+ "megamind" ,
325
+ partial (plan_step , megamind_agent , context_providers = context_providers ),
299
326
)
300
327
for agent_name , agent in executor_agents .items ():
301
328
graph .add_node (agent_name , agent )
0 commit comments