feat[agent]: add BaseAgent and standardize planner, memory, tool mana… #385
base: main
Agent package `__init__.py`:

```diff
@@ -1,10 +1,23 @@
-from .react import DEFAULT_REACT_AGENT_SYSTEM_PROMPT, ReActAgent
-from adalflow.utils.registry import EntityMapping
+"""Agent components for AdalFlow."""
+
+from .react import ReActAgent as LegacyReActAgent
+from .react_agent import ReActAgent as NewReActAgent
+from .base_agent import (
+    BaseAgent,
+    BasePlanner,
+    BaseToolManager,
+    BaseMemory,
+    Step,
+    AgentOutput,
+)
 
 __all__ = [
-    "ReActAgent",
-    "DEFAULT_REACT_AGENT_SYSTEM_PROMPT",
+    "LegacyReActAgent",  # Old implementation for backward compatibility
+    "NewReActAgent",  # New implementation using base agent
+    "BaseAgent",  # Base agent class
+    "BasePlanner",  # Base planner interface
+    "BaseToolManager",  # Base tool manager interface
+    "BaseMemory",  # Base memory interface
+    "Step",  # Step data class
+    "AgentOutput",  # Output data class
 ]
-
-for name in __all__:
-    EntityMapping.register(name, globals()[name])
```
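For context, a downstream import under the re-exported names above might look like the following sketch. It assumes the package path `adalflow.components.agent`, which matches the absolute imports used elsewhere in this PR, and is only an illustration of the intended API surface:

```python
# Hypothetical usage of the names re-exported by the updated __init__.
from adalflow.components.agent import (
    LegacyReActAgent,  # the pre-existing ReAct implementation
    NewReActAgent,     # the new implementation built on BaseAgent
    BaseAgent,
    BasePlanner,
)
```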
New module `base_agent.py` (`@@ -0,0 +1,231 @@`):

```python
"""Base agent implementation with standardized interfaces."""

from typing import List, Union, Callable, Optional, Any, Dict
from dataclasses import dataclass, field
from adalflow.core.base_data_class import DataClass
import logging

from adalflow.core.func_tool import FunctionTool, AsyncCallable
from adalflow.core.component import Component
from adalflow.core.types import (
    Function,
)
from adalflow.core.model_client import ModelClient
from adalflow.utils.logger import printc

log = logging.getLogger(__name__)


@dataclass
class Step(DataClass):
    """Standardized step structure for all agents."""

    step_number: int = field(metadata={"desc": "The step number"})
    action: Optional[Function] = field(
        default=None, metadata={"desc": "The action taken in this step"}
    )
    observation: Any = field(
        default=None, metadata={"desc": "The observation from this step"}
    )
    metadata: Dict = field(
        default_factory=dict, metadata={"desc": "Additional metadata for this step"}
    )


@dataclass
class AgentOutput(DataClass):
    """Standardized output structure for all agents."""

    id: Optional[str] = field(
        default=None, metadata={"desc": "The unique id of the output"}
    )
    step_history: List[Step] = field(
        metadata={"desc": "The history of steps."}, default_factory=list
    )
    answer: Any = field(metadata={"desc": "The final answer."}, default=None)
    metadata: Dict = field(
        default_factory=dict, metadata={"desc": "Additional metadata"}
    )

    def validate(self) -> bool:
        """Validate the output structure."""
        if not isinstance(self.step_history, list):
            return False
        if not all(isinstance(step, Step) for step in self.step_history):
            return False
        return True


class BasePlanner(Component):
    """Base interface for planning strategies."""

    def __init__(self, model_client: ModelClient, model_kwargs: Dict = {}):
        super().__init__()
        self.model_client = model_client
        self.model_kwargs = model_kwargs

    def plan(self, input: str, context: Dict) -> Function:
        """Plan the next action based on input and context."""
        raise NotImplementedError


class BaseToolManager(Component):
    """Base interface for tool management."""

    def __init__(self, tools: List[Union[Callable, AsyncCallable, FunctionTool]]):
        super().__init__()
        self.tools = tools

    def execute(self, action: Function) -> Any:
        """Execute an action using the appropriate tool."""
        raise NotImplementedError


class BaseMemory(Component):
    """Base interface for memory management."""

    def __init__(self):
        super().__init__()
        self.steps: List[Step] = []

    def store(self, step: Step) -> None:
        """Store a step in memory."""
        self.steps.append(step)

    def retrieve(self, query: str) -> List[Step]:
        """Retrieve relevant steps from memory."""
        raise NotImplementedError


class BaseAgent(Component):
```
**Member** commented:

> Another complexity of the agent is auto-optimization. We can make this new agent experimental; it will likely need multiple iterations to mature.
(class `BaseAgent` continues)

```python
    """Base agent class with standardized interfaces."""

    def __init__(
        self,
        planner: BasePlanner,
        tool_manager: BaseToolManager,
        memory: Optional[BaseMemory] = None,
        max_steps: int = 10,
        context_variables: Optional[Dict] = None,
        use_cache: bool = True,
        debug: bool = False,
    ):
        super().__init__()
        self.planner = planner
        self.tool_manager = tool_manager
        self.memory = memory
        self.max_steps = max_steps
        self.context_variables = context_variables
        self.use_cache = use_cache
        self.debug = debug

    def _handle_training(self, step: Step) -> Step:
        """Handle training mode specific logic."""
        if not self.training:
            return step
        # Add training specific logic here
        return step

    def _handle_evaluation(self, step: Step) -> Step:
        """Handle evaluation mode specific logic."""
        return step

    def _format_output(self, step_history: List[Step], answer: Any) -> AgentOutput:
        """Format the final output."""
        return AgentOutput(
            step_history=step_history, answer=answer, metadata=self._get_metadata()
        )

    def _get_metadata(self) -> Dict:
        """Get metadata for the output."""
        return {
            "max_steps": self.max_steps,
            "use_cache": self.use_cache,
            "context_variables": self.context_variables,
        }

    def _run_one_step(
        self,
        step_number: int,
        input: str,
        context: Dict,
        step_history: List[Step] = [],
    ) -> Step:
        """Run one step of the agent."""
        if self.debug:
            printc(f"Running step {step_number}", color="yellow")

        # Plan the next action
        action = self.planner.plan(input, context)

        # Execute the action
        observation = self.tool_manager.execute(action)

        # Create step
        step = Step(
            step_number=step_number,
            action=action,
            observation=observation,
            metadata={"context": context},
        )

        # Handle training/evaluation mode
        if self.training:
            step = self._handle_training(step)
        else:
            step = self._handle_evaluation(step)

        # Store in memory if available
        if self.memory:
            self.memory.store(step)

        return step

    def call(self, input: str, **kwargs) -> AgentOutput:
        """Main entry point for the agent."""
        step_history: List[Step] = []
        context = {
            "input": input,
            "step_history": step_history,
            **(self.context_variables or {}),
            **kwargs,
        }

        for step_number in range(1, self.max_steps + 1):
            step = self._run_one_step(
                step_number=step_number,
                input=input,
                context=context,
                step_history=step_history,
            )
            step_history.append(step)

            # Check if we should stop
            if self._should_stop(step):
                break

        # Get final answer
        answer = self._get_answer(step_history)

        # Format and return output
        output = self._format_output(step_history, answer)
        if not output.validate():
            raise ValueError("Invalid output format")

        return output

    def _should_stop(self, step: Step) -> bool:
        """Determine if the agent should stop."""
        raise NotImplementedError

    def _get_answer(self, step_history: List[Step]) -> Any:
        """Get the final answer from step history."""
        raise NotImplementedError

    def train_step(self, input: str, target: Any) -> Dict:
        """Standard training step interface."""
        raise NotImplementedError

    def eval_step(self, input: str) -> AgentOutput:
        """Standard evaluation step interface."""
        raise NotImplementedError
```
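To show how these interfaces are intended to compose, here is a minimal, hypothetical sketch of a concrete agent built on them. It is not part of the PR: the import path `adalflow.components.agent.base_agent` is inferred from the `__init__.py` change above, the `EchoPlanner`, `SimpleToolManager`, and `EchoAgent` names are made up, the planner skips the LLM entirely by passing `model_client=None`, and it assumes `Component` exposes the `training` flag that the diff itself relies on.

```python
"""Minimal sketch: wiring a concrete agent from the BaseAgent interfaces above."""
from typing import Any, Dict, List

from adalflow.core.types import Function
# Assumed import path, consistent with the relative imports in the __init__.py diff.
from adalflow.components.agent.base_agent import (
    BaseAgent,
    BasePlanner,
    BaseToolManager,
    BaseMemory,
    Step,
)


class EchoPlanner(BasePlanner):
    """Toy planner that always proposes a single 'finish' action (no LLM call)."""

    def plan(self, input: str, context: Dict) -> Function:
        return Function(name="finish", kwargs={"answer": f"echo: {input}"})


class SimpleToolManager(BaseToolManager):
    """Toy tool manager that dispatches to plain callables by function name."""

    def execute(self, action: Function) -> Any:
        tool_map = {getattr(t, "__name__", str(t)): t for t in self.tools}
        return tool_map[action.name](**action.kwargs)


class EchoAgent(BaseAgent):
    """Concrete agent filling in the two abstract hooks of BaseAgent."""

    def _should_stop(self, step: Step) -> bool:
        # Stop as soon as the planner emits a 'finish' action.
        return step.action is not None and step.action.name == "finish"

    def _get_answer(self, step_history: List[Step]) -> Any:
        return step_history[-1].observation if step_history else None


def finish(answer: str) -> str:
    return answer


if __name__ == "__main__":
    agent = EchoAgent(
        planner=EchoPlanner(model_client=None),  # stand-in; no model client needed here
        tool_manager=SimpleToolManager(tools=[finish]),
        memory=BaseMemory(),  # store() works; retrieve() is never called in this sketch
        max_steps=3,
    )
    output = agent.call("hello")
    print(output.answer)  # -> "echo: hello"
```

The only hooks a subclass has to fill in are `_should_stop` and `_get_answer`; planning and tool execution are delegated to the injected planner and tool manager.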
New example script (`@@ -0,0 +1,134 @@`):

```python
"""Example demonstrating ReAct agent with vector memory support."""
```

**Member** commented:

> Examples should not be in the package; please move it to tutorials.
(file continues)

```python
from adalflow.components.agent.react_agent import ReActAgent
from adalflow.core.func_tool import FunctionTool
from adalflow.components.model_client import OpenAIClient
from adalflow.core.types import Function
import logging

# from adalflow.components.memory import Memory
from adalflow.components.memory.memory import Memory

from dotenv import load_dotenv

load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    # Initialize model client
    model_client = OpenAIClient()
    model_kwargs = {
        "model": "gpt-3.5-turbo",
        "temperature": 0.7,
    }
    memory = Memory()

    # Define some example tools
    def calculate(expression: str, **kwargs) -> str:
        """Calculate the result of a mathematical expression."""
        try:
            return str(eval(expression))
        except Exception as e:
            return f"Error calculating: {str(e)}"

    def get_factorial(n: int, **kwargs) -> str:
        """Calculate factorial of a number."""
        try:
            result = 1
            for i in range(1, n + 1):
                result *= i
            return str(result)
        except Exception as e:
            return f"Error calculating factorial: {str(e)}"

    def finish(answer: str, **kwargs) -> str:
        """Finish the conversation with a final answer."""
        return answer

    def extract_result(history: str, **kwargs) -> str:
        """Get the previous context."""
        return memory.call()

    def square(n: int, **kwargs) -> str:
        """Square a number."""
        return str(n * n)

    # Create example functions for the agent
    examples = [
        Function(
            thought="I need to calculate a simple arithmetic expression.",
            name="calculate",
            kwargs={"expression": "2 + 2"},
        ),
        Function(
            thought="I need to calculate the factorial of the number 5.",
            name="get_factorial",
            kwargs={"n": 5},
        ),
        Function(
            thought="I need the context data of the previous conversation.",
            name="extract_result",
            kwargs={"history": "history"},
        ),
        Function(
            thought="I need to square the number 3.",
            name="square",
            kwargs={"n": 3},
        ),
    ]
    # Create function tools
    calc_tool = FunctionTool(calculate)
    factorial_tool = FunctionTool(get_factorial)
    extract_result_tool = FunctionTool(extract_result)
    square_tool = FunctionTool(square)

    # Create ReAct agent with vector memory
    agent = ReActAgent(
        tools=[
            calc_tool,
            factorial_tool,
            extract_result_tool,
            square_tool,
        ],
        model_client=model_client,
        model_kwargs=model_kwargs,
        add_llm_as_fallback=True,
        max_steps=5,
        examples=examples,
        debug=True,  # Enable debug output
    )

    # Example 1: Simple calculation
    logger.info("Example 1: Simple calculation")
    print("MEMORY_CALL", memory.call())
    agent.context_variables = {"context_variables": {"history": memory.call()}}
    result1 = agent("What is 2 + 1?")
    logger.info(f"Result 1: {result1.answer}")
    memory.add_dialog_turn("What is 2 + 1?", result1.step_history)

    # Example 2: Using previous context
    print("MEMORY_CALL", memory.call())
    logger.info("\nExample 2: Using previous context")
    agent.context_variables = {"context_variables": {"history": memory.call()}}
    result2 = agent("what is the result of the previous question?")
    logger.info(f"Result 2: {result2.answer}")
    memory.add_dialog_turn(
        "what is the result of the previous question?", result2.step_history
    )

    # Example 3: Square of previous final result
    print("MEMORY_CALL", memory.call())
    logger.info("\nExample 3: Square of previous final result")
    agent.context_variables = {"context_variables": {"history": memory.call()}}
    result3 = agent("what is the square of the previous final result?")
    logger.info(f"Result 3: {result3.answer}")
    memory.add_dialog_turn(
        "what is the square of the previous final result?", result3.step_history
    )


if __name__ == "__main__":
    main()
```
**Member** commented:

> We don't need another base tool manager; the existing ToolManager is already abstracted enough. The goal is not to create more than three levels of abstraction, especially ones that don't do much work.
>
> The abstraction of the agent should focus more on the design pattern of agents, as in OpenAI's Agents SDK, which makes it easy to create an agent from the user's perspective while still keeping control over the prompts if the user wants to modify them.
>
> We are also trying to make it easy to hand over the context (memory) and the step history to all other agents (as a tool).
>
> I have some code and can share it later. We can do another quick sync call next week; you can check out the OpenAI Agents SDK before that.
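For readers unfamiliar with the pattern the reviewer is pointing at, here is a rough, hypothetical sketch of wrapping one agent as a tool for another so that context and step history can be handed over. This is not the code the reviewer mentions; the helper name `as_tool` is made up, and the only APIs assumed are `FunctionTool` and `ReActAgent` as used in the example above:

```python
"""Hypothetical sketch: expose a sub-agent as a FunctionTool for a parent agent."""
from adalflow.core.func_tool import FunctionTool
from adalflow.components.agent.react_agent import ReActAgent  # path used in the example above


def as_tool(sub_agent: ReActAgent, name: str, description: str) -> FunctionTool:
    """Wrap `sub_agent` so a parent agent can call it like any other tool."""

    def run_sub_agent(query: str, **kwargs) -> str:
        # A parent agent could forward its own memory/step history here
        # (e.g. via context variables) so the sub-agent shares the context.
        result = sub_agent(query)
        return str(result.answer)

    # Assumption: FunctionTool derives the tool's name and description from the
    # wrapped callable, so we label the closure accordingly.
    run_sub_agent.__name__ = name
    run_sub_agent.__doc__ = description
    return FunctionTool(run_sub_agent)


# Usage sketch: the wrapped sub-agent goes into the parent's `tools` list,
# alongside ordinary FunctionTools.
```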