diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index cf39beaf..c1eb63b9 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -54,6 +54,26 @@ property so that callbacks can access information in a framework-agnostic way. You can see what attributes are available for LLM Calls and Tool Executions by examining the [`GenAI`][any_agent.tracing.attributes.GenAI] class. +### Framework State + +In addition to the span attributes, callbacks can access and modify framework-specific objects through [`Context.framework_state`][any_agent.callbacks.context.Context.framework_state]. + +This allows callbacks to directly manipulate the agent's execution, such as: + +- Modifying messages before they're sent to the LLM +- Injecting prompts mid-execution +- Changing user queries dynamically + +#### Helper Methods + +The `framework_state` provides helper methods to work with messages in a normalized format: + +**`get_messages()`**: Get messages as a list of dicts with `role` and `content` keys + +**`set_messages()`**: Set messages from a list of dicts with `role` and `content` keys + +These methods handle framework-specific message formats internally, providing a consistent API across frameworks. + ## Implementing Callbacks All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.Callback] class and can choose to implement any subset of the available callback methods. These methods include: @@ -136,7 +156,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -157,7 +177,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -272,3 +292,55 @@ class LimitToolExecutions(Callback): return context ``` + +## Example: Modifying prompts dynamically + +You can use callbacks to modify the prompt being sent to the LLM. This is useful for injecting instructions or reminders mid-execution: + +```python +from any_agent.callbacks.base import Callback +from any_agent.callbacks.context import Context + +class InjectReminderCallback(Callback): + def __init__(self, reminder: str, every_n_calls: int = 5): + self.reminder = reminder + self.every_n_calls = every_n_calls + self.call_count = 0 + + def before_llm_call(self, context: Context, *args, **kwargs) -> Context: + self.call_count += 1 + + if self.call_count % self.every_n_calls == 0: + try: + messages = context.framework_state.get_messages() + if messages: + messages[-1]["content"] += f"\n\n{self.reminder}" + context.framework_state.set_messages(messages) + except NotImplementedError: + pass + + return context +``` + +Example usage: + +```python +from any_agent import AgentConfig, AnyAgent + +callback = InjectReminderCallback( + reminder="Remember to use the Todo tool to track your tasks!", + every_n_calls=5 +) + +config = AgentConfig( + model_id="openai:gpt-4o-mini", + instructions="You are a helpful assistant.", + callbacks=[callback], +) +# ... Continue to create and run agent +``` + +!!! tip + + Use try/except to gracefully handle frameworks that don't support message modification yet. The callback will simply skip modification for unsupported frameworks. +``` diff --git a/docs/api/callbacks.md b/docs/api/callbacks.md index ebca1863..ed827a0e 100644 --- a/docs/api/callbacks.md +++ b/docs/api/callbacks.md @@ -4,6 +4,8 @@ ::: any_agent.callbacks.context.Context +::: any_agent.callbacks.context.FrameworkState + ::: any_agent.callbacks.span_print.ConsolePrintSpan ::: any_agent.callbacks.get_default_callbacks diff --git a/src/any_agent/callbacks/__init__.py b/src/any_agent/callbacks/__init__.py index 7e26da88..836625a4 100644 --- a/src/any_agent/callbacks/__init__.py +++ b/src/any_agent/callbacks/__init__.py @@ -1,8 +1,8 @@ from .base import Callback -from .context import Context +from .context import Context, FrameworkState from .span_print import ConsolePrintSpan -__all__ = ["Callback", "ConsolePrintSpan", "Context"] +__all__ = ["Callback", "ConsolePrintSpan", "Context", "FrameworkState"] def get_default_callbacks() -> list[Callback]: diff --git a/src/any_agent/callbacks/context.py b/src/any_agent/callbacks/context.py index e3479991..7574bdb2 100644 --- a/src/any_agent/callbacks/context.py +++ b/src/any_agent/callbacks/context.py @@ -1,14 +1,87 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Callable + from opentelemetry.trace import Span, Tracer from any_agent.tracing.agent_trace import AgentTrace +@dataclass +class FrameworkState: + """Framework-specific state that can be accessed and modified by callbacks. + + This object provides a consistent interface for accessing framework state across + different agent frameworks, while the actual content is framework-specific. + """ + + messages: list[dict[str, Any]] = field(default_factory=list) + """Internal storage for messages. Use get_messages() and set_messages() instead.""" + + _message_getter: Callable[[], list[dict[str, Any]]] | None = field( + default=None, repr=False + ) + """Framework-specific message getter function.""" + + _message_setter: Callable[[list[dict[str, Any]]], None] | None = field( + default=None, repr=False + ) + """Framework-specific message setter function.""" + + def get_messages(self) -> list[dict[str, Any]]: + """Get messages in a normalized dict format. + + Returns a list of message dicts with 'role' and 'content' keys. + Works consistently across all frameworks. + + Returns: + List of message dicts with 'role' and 'content' keys. + + Raises: + NotImplementedError: If the framework doesn't support message access yet. + + Example: + ```python + messages = context.framework_state.get_messages() + # [{"role": "user", "content": "Hello"}] + ``` + + """ + if self._message_getter is None: + msg = "get_messages() is not implemented for this framework yet" + raise NotImplementedError(msg) + return self._message_getter() + + def set_messages(self, messages: list[dict[str, Any]]) -> None: + """Set messages from a normalized dict format. + + Accepts a list of message dicts with 'role' and 'content' keys and + converts them to the framework-specific format. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Raises: + NotImplementedError: If the framework doesn't support message modification yet. + + Example: + ```python + messages = context.framework_state.get_messages() + messages[-1]["content"] = "Say hello" + context.framework_state.set_messages(messages) + ``` + + """ + if self._message_setter is None: + msg = "set_messages() is not implemented for this framework yet" + raise NotImplementedError(msg) + self._message_setter(messages) + + @dataclass class Context: """Object that will be shared across callbacks. @@ -31,3 +104,20 @@ class Context: shared: dict[str, Any] """Can be used to store arbitrary information for sharing across callbacks.""" + + framework_state: FrameworkState + """Framework-specific state that can be accessed and modified by callbacks. + + Provides consistent access to framework state across different agent frameworks. + See [`FrameworkState`][any_agent.callbacks.context.FrameworkState] for available attributes. + + Example: + ```python + class ModifyPromptCallback(Callback): + def before_llm_call(self, context: Context, *args, **kwargs) -> Context: + # Modify the last message content + if context.framework_state.messages: + context.framework_state.messages[-1]["content"] = "Say hello" + return context + ``` + """ diff --git a/src/any_agent/callbacks/wrappers/tinyagent.py b/src/any_agent/callbacks/wrappers/tinyagent.py index 47183706..a473cf5e 100644 --- a/src/any_agent/callbacks/wrappers/tinyagent.py +++ b/src/any_agent/callbacks/wrappers/tinyagent.py @@ -26,6 +26,20 @@ async def wrap_call_model(**kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] + + if "messages" in kwargs: + context.framework_state.messages = kwargs["messages"] + + def get_messages(): + return context.framework_state.messages + + def set_messages(messages): + context.framework_state.messages = messages + kwargs["messages"] = messages + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, **kwargs) @@ -42,6 +56,7 @@ async def wrapped_tool_execution(original_call, request): context = self.callback_context[ get_current_span().get_span_context().trace_id ] + for callback in agent.config.callbacks: context = callback.before_tool_execution(context, request) diff --git a/src/any_agent/frameworks/any_agent.py b/src/any_agent/frameworks/any_agent.py index b1a1da41..34971fe2 100644 --- a/src/any_agent/frameworks/any_agent.py +++ b/src/any_agent/frameworks/any_agent.py @@ -8,7 +8,7 @@ from any_llm.utils.aio import run_async_in_sync from opentelemetry import trace as otel_trace -from any_agent.callbacks.context import Context +from any_agent.callbacks.context import Context, FrameworkState from any_agent.callbacks.wrappers import ( _get_wrapper_by_framework, ) @@ -217,6 +217,7 @@ async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace: trace=AgentTrace(), tracer=self._tracer, shared={}, + framework_state=FrameworkState(), ) if len(self._wrapper.callback_context) == 1: diff --git a/tests/integration/callbacks/__init__.py b/tests/integration/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py new file mode 100644 index 00000000..971c756d --- /dev/null +++ b/tests/integration/callbacks/test_framework_state.py @@ -0,0 +1,39 @@ +from any_agent import AgentConfig, AnyAgent +from any_agent.callbacks import Callback, Context +from any_agent.config import AgentFramework +from any_agent.testing.helpers import DEFAULT_SMALL_MODEL_ID +from typing import Any + + +class LLMInputModifier(Callback): + """Callback that modifies LLM input messages.""" + + def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + messages = context.framework_state.get_messages() + messages[-1]["content"] = "Say hello" + context.framework_state.set_messages(messages) + return context + + +async def test_modify_llm_input(agent_framework: AgentFramework) -> None: + """Test that framework_state message modification works via helper methods.""" + modifier = LLMInputModifier() + config = AgentConfig( + model_id=DEFAULT_SMALL_MODEL_ID, + instructions="You are a helpful assistant.", + callbacks=[modifier], + ) + + agent = await AnyAgent.create_async(agent_framework, config) + + try: + result = await agent.run_async("Say goodbye") + assert result.final_output is not None + assert isinstance(result.final_output, str) + + assert "hello" in result.final_output.lower(), ( + "Expected 'hello' in the final output" + ) + + finally: + await agent.cleanup_async()