diff --git a/docs/api_reference/agents/index.md b/docs/api_reference/agents/index.md
index 2608a8eb4..ec2a0d655 100644
--- a/docs/api_reference/agents/index.md
+++ b/docs/api_reference/agents/index.md
@@ -9,3 +9,5 @@
 ::: ragbits.agents.AgentResultStreaming
 
 ::: ragbits.agents.a2a.server.create_agent_server
+
+::: ragbits.agents.post_processors.base
diff --git a/docs/how-to/agents/use_post_processors.md b/docs/how-to/agents/use_post_processors.md
new file mode 100644
index 000000000..aef223be7
--- /dev/null
+++ b/docs/how-to/agents/use_post_processors.md
@@ -0,0 +1,64 @@
+# How-To: Use Post-Processors with Ragbits Agents
+
+Ragbits Agents can be enhanced with post-processors that intercept, log, filter, and modify their outputs. This guide explains how to implement and use post-processors to customize agent responses.
+
+## Post-Processors Overview
+
+Ragbits provides two types of post-processors:
+
+- **PostProcessor**: Processes the final output after generation; ideal for batch processing.
+- **StreamingPostProcessor**: Processes outputs as they are generated; suitable for real-time applications.
+
+### Implementing a custom Post-Processor
+
+To create a custom post-processor, inherit from the appropriate base class ([`PostProcessor`][ragbits.agents.post_processors.base.PostProcessor] or [`StreamingPostProcessor`][ragbits.agents.post_processors.base.StreamingPostProcessor]) and implement the required method.
+
+#### Post-Processor Example
+
+A non-streaming post-processor applies transformations after the entire content has been generated.
+
+```python
+class TruncateProcessor(PostProcessor):
+    def __init__(self, max_length: int = 50) -> None:
+        self.max_length = max_length
+
+    async def process(self, result, agent):
+        content = result.content
+        if len(content) > self.max_length:
+            content = content[: self.max_length] + "... [TRUNCATED]"
+        result.content = content
+        return result
+```
+
+#### Streaming Post-Processor Example
+
+A streaming post-processor can manipulate all information returned during generation, including text, tool calls, etc.
+
+```python
+class UpperCaseStreamingProcessor(StreamingPostProcessor):
+    async def process_streaming(self, chunk, agent):
+        if isinstance(chunk, str):
+            return chunk.upper()
+        return chunk
+```
+
+## Using Post-Processors
+
+To use post-processors, pass them to the `run` or `run_streaming` methods of the `Agent` class. If you pass a non-streaming processor to `run_streaming`, set `allow_non_streaming=True`. This allows streaming processors to handle content piece by piece during generation, while non-streaming processors apply transformations after the entire output is generated.
+
+```python
+async def main() -> None:
+    llm = LiteLLM("gpt-4.1-mini")
+    agent = Agent(llm=llm, prompt="You are a helpful assistant.")
+    post_processors = [
+        UpperCaseStreamingProcessor(),
+        TruncateProcessor(max_length=50),
+    ]
+    stream_result = agent.run_streaming("Tell me about the history of AI.", post_processors=post_processors, allow_non_streaming=True)
+    async for chunk in stream_result:
+        if isinstance(chunk, str):
+            print(chunk, end="")
+    print(f"\nFinal answer:\n{stream_result.content}")
+```
+
+Post-processors offer a flexible way to tailor agent outputs, whether filtering content in real time or transforming final outputs.
diff --git a/examples/README.md b/examples/README.md
index 981a2dc7a..a5291f3a0 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -20,6 +20,7 @@ All necessary details are provided in the comments at the top of each script.
 | [Multimodal Prompt with PDF Input](/examples/core/prompt/multimodal_with_pdf.py) | [ragbits-core](/packages/ragbits-core) | Example of how to use the `Prompt` class to answer the question using an LLM with both text and PDF inputs. |
 | [Multimodal Prompt with Few Shots](/examples/core/prompt/multimodal_with_few_shots.py) | [ragbits-core](/packages/ragbits-core) | Example of how to use the `Prompt` class to generate themed text using an LLM with multimodal inputs and few-shot examples. |
 | [Tool Use with LLM](/examples/core/llms/tool_use.py) | [ragbits-core](/packages/ragbits-core) | Example of how to provide tools and return tool calls from LLM. |
+| [Reasoning with LLM](/examples/core/llms/reasoning.py) | [ragbits-core](/packages/ragbits-core) | Example of how to use reasoning with an LLM. |
 | [OpenTelemetry Audit](/examples/core/audit/otel.py) | [ragbits-core](/packages/ragbits-core) | Example of how to collect traces and metrics using Ragbits audit module with OpenTelemetry. |
 | [Logfire Audit](/examples/core/audit/logfire_.py) | [ragbits-core](/packages/ragbits-core) | Example of how to collect traces and metrics using Ragbits audit module with Logfire. |
 | [Basic Document Search](/examples/document-search/basic.py) | [ragbits-document-search](/packages/ragbits-document-search) | Example of how to use the `DocumentSearch` class to search for documents with the `InMemoryVectorStore` class to store the embeddings. |
@@ -38,6 +39,9 @@ All necessary details are provided in the comments at the top of each script.
 | [Recontextualize Last Message](/examples/chat/recontextualize_message.py) | [ragbits-chat](/packages/ragbits-chat) | Example of how to use the `StandaloneMessageCompressor` compressor to recontextualize the last message in a conversation history. |
 | [Agents Tool Use](/examples/agents/tool_use.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use agent with tools. |
 | [Agents OpenAI Native Tool Use](/examples/agents/openai_native_tool_use.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use agent with OpenAI native tools. |
+| [Agents Post Processors](/examples/agents/post_processors.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use post-processors with an agent. |
+| [Agents CLI](/examples/agents/cli_agent.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use an agent in the CLI. |
 | [MCP Local](/examples/agents/mcp/local.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use the `Agent` class to connect with a local MCP server. |
 | [MCP SSE](/examples/agents/mcp/sse.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use the `Agent` class to connect with a remote MCP server via SSE. |
 | [MCP Streamable HTTP](/examples/agents/mcp/streamable_http.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use the `Agent` class to connect with a remote MCP server via HTTP. |
+| [A2A Orchestration](/examples/agents/a2a/run_orchestrator.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to set up A2A orchestration. |
diff --git a/examples/agents/post_processors.py b/examples/agents/post_processors.py
new file mode 100644
index 000000000..5e2e00695
--- /dev/null
+++ b/examples/agents/post_processors.py
@@ -0,0 +1,97 @@
+"""
+Ragbits Agents Example: Post-Processors
+
+This example demonstrates how to use post-processors with Agent.run() and Agent.run_streaming() methods.
+
+To run the script, execute the following command:
+
+    ```bash
+    uv run examples/agents/post_processors.py
+    ```
+"""
+
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "ragbits-core",
+#     "ragbits-agents",
+# ]
+# ///
+
+import asyncio
+from types import SimpleNamespace
+
+from ragbits.agents import Agent, AgentResult, PostProcessor, StreamingPostProcessor, ToolCallResult
+from ragbits.core.llms.base import BasePrompt, ToolCall, Usage
+from ragbits.core.llms.litellm import LiteLLM
+
+
+class CustomStreamingPostProcessor(StreamingPostProcessor):
+    """
+    Streaming post-processor that checks for forbidden words.
+    """
+
+    def __init__(self, forbidden_words: list[str]) -> None:
+        self.forbidden_words = forbidden_words
+
+    async def process_streaming(
+        self, chunk: str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage, agent: Agent
+    ) -> str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage:
+        """
+        Process chunks during streaming.
+        """
+        if isinstance(chunk, str) and chunk.lower().strip() in self.forbidden_words:
+            return "[FORBIDDEN_WORD]"
+        return chunk
+
+
+class CustomPostProcessor(PostProcessor):
+    """
+    Non-streaming post-processor that truncates the content.
+    """
+
+    def __init__(self, max_length: int = 200) -> None:
+        self.max_length = max_length
+
+    async def process(self, result: AgentResult, agent: Agent) -> AgentResult:
+        """
+        Process the agent result.
+        """
+        content = result.content
+        content_length = len(content)
+
+        if content_length > self.max_length:
+            content = content[: self.max_length]
+            content += f"... [TRUNCATED] ({content_length} > {self.max_length} chars)"
+
+        return AgentResult(
+            content=content,
+            metadata=result.metadata,
+            tool_calls=result.tool_calls,
+            history=result.history,
+            usage=result.usage,
+        )
+
+
+async def main() -> None:
+    """
+    Run the example.
+    """
+    llm = LiteLLM("gpt-4.1-mini")
+    agent: Agent = Agent(llm=llm, prompt="You are a helpful assistant.")
+    stream_result = agent.run_streaming(
+        "What is Python?",
+        post_processors=[
+            CustomStreamingPostProcessor(forbidden_words=["python"]),
+            CustomPostProcessor(max_length=200),
+        ],
+        allow_non_streaming=True,
+    )
+    async for chunk in stream_result:
+        if isinstance(chunk, str):
+            print(chunk, end="")
+    print(f"\n\nFinal answer:\n{stream_result.content}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/mkdocs.yml b/mkdocs.yml
index e90e422ad..8c534094d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -40,6 +40,7 @@ nav:
       - "Define and use agents": how-to/agents/define_and_use_agents.md
       - "Provide tools with MCP": how-to/agents/provide_mcp_tools.md
       - "Serve agents with A2A": how-to/agents/serve_ragbits_agents.md
+      - "Use post-processors": how-to/agents/use_post_processors.md
     - Guardrails:
       - "Setup guardrails": how-to/guardrails/use_guardrails.md
     - Chatbots:
diff --git a/packages/ragbits-agents/CHANGELOG.md b/packages/ragbits-agents/CHANGELOG.md
index 2a722f747..c0eef472b 100644
--- a/packages/ragbits-agents/CHANGELOG.md
+++ b/packages/ragbits-agents/CHANGELOG.md
@@ -4,6 +4,7 @@
 
 - Support wrapping downstream agents as tools (#818)
 - Add syntax sugar allowing easier Agents definition (#820)
+- Add post-processors (#821)
 
 ## 1.3.0 (2025-09-11)
 ### Changed
diff --git a/packages/ragbits-agents/src/ragbits/agents/__init__.py b/packages/ragbits-agents/src/ragbits/agents/__init__.py
index 7a7036895..bd1809337 100644
--- a/packages/ragbits-agents/src/ragbits/agents/__init__.py
+++ b/packages/ragbits-agents/src/ragbits/agents/__init__.py
@@ -7,6 +7,7 @@
     AgentRunContext,
     ToolCallResult,
 )
+from ragbits.agents.post_processors.base import PostProcessor, StreamingPostProcessor
 from ragbits.agents.types import QuestionAnswerAgent, QuestionAnswerPromptInput, QuestionAnswerPromptOutput
 
 __all__ = [
@@ -16,8 +17,10 @@
     "AgentResult",
    "AgentResultStreaming",
     "AgentRunContext",
+    "PostProcessor",
     "QuestionAnswerAgent",
     "QuestionAnswerPromptInput",
     "QuestionAnswerPromptOutput",
+    "StreamingPostProcessor",
     "ToolCallResult",
 ]
diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py
index 6b15a00e6..b98b06245 100644
--- a/packages/ragbits-agents/src/ragbits/agents/_main.py
+++ b/packages/ragbits-agents/src/ragbits/agents/_main.py
@@ -8,7 +8,7 @@
 from datetime import timedelta
 from inspect import iscoroutinefunction
 from types import ModuleType, SimpleNamespace
-from typing import Any, ClassVar, Generic, TypeVar, cast, overload
+from typing import Any, ClassVar, Generic, Literal, TypeVar, cast, overload
 
 from pydantic import (
     BaseModel,
@@ -18,6 +18,7 @@
 
 from ragbits import agents
 from ragbits.agents.exceptions import (
+    AgentInvalidPostProcessorError,
    AgentInvalidPromptInputError,
     AgentMaxTokensExceededError,
     AgentMaxTurnsExceededError,
@@ -29,6 +30,12 @@
 )
 from ragbits.agents.mcp.server import MCPServer, MCPServerStdio, MCPServerStreamableHttp
 from ragbits.agents.mcp.utils import get_tools
+from ragbits.agents.post_processors.base import (
+    BasePostProcessor,
+    PostProcessor,
+    StreamingPostProcessor,
+    stream_with_post_processing,
+)
 from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice
 from ragbits.core.audit.traces import trace
 from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMOptions, LLMResponseWithMetadata, ToolCall, Usage
@@ -157,7 +164,7 @@ class AgentRunContext(BaseModel, Generic[DepsT]):
     """The usage of the agent."""
 
 
-class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]):
+class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace]):
     """
     An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned by
     `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
@@ -174,38 +181,40 @@ def __init__(
         self.history: ChatFormat
         self.usage: Usage = Usage()
 
-    def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult]:
+    def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace]:
         return self
 
-    async def __anext__(self) -> str | ToolCall | ToolCallResult:
+    async def __anext__(self) -> str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace:
         try:
             item = await self._generator.__anext__()
+
             match item:
                 case str():
                     self.content += item
+                    return item
                 case ToolCall():
-                    pass
+                    return item
                 case ToolCallResult():
                     if self.tool_calls is None:
                         self.tool_calls = []
                     self.tool_calls.append(item)
+                    return item
                 case BasePrompt():
                     item.add_assistant_message(self.content)
                     self.history = item.chat
-                    item = await self._generator.__anext__()
-                    item = cast(SimpleNamespace, item)
-                    item.result = {
-                        "content": self.content,
-                        "metadata": self.metadata,
-                        "tool_calls": self.tool_calls,
-                    }
-                    raise StopAsyncIteration
+                    return item
                 case Usage():
                     self.usage = item
                     return await self.__anext__()
+                case SimpleNamespace():
+                    result_dict = getattr(item, "result", {})
+                    self.content = result_dict.get("content", self.content)
+                    self.metadata = result_dict.get("metadata", self.metadata)
+                    self.tool_calls = result_dict.get("tool_calls", self.tool_calls)
+                    return item
                 case _:
                     raise ValueError(f"Unexpected item: {item}")
-            return item
+
         except StopAsyncIteration:
             raise
@@ -292,6 +301,7 @@ async def run(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: list[PostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
     ) -> AgentResult[PromptOutputT]: ...
 
     @overload
@@ -301,6 +311,7 @@ async def run(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: list[PostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
     ) -> AgentResult[PromptOutputT]: ...
 
     async def run(
@@ -309,6 +320,7 @@ async def run(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: list[PostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
     ) -> AgentResult[PromptOutputT]:
         """
         Run the agent. The method is experimental, inputs and outputs may change in the future.
@@ -325,6 +337,7 @@ async def run(
             - "none": do not call tool
             - "required: enforce tool usage (model decides which one)
             - Callable: one of provided tools
+            post_processors: List of post-processors to apply to the response in order.
 
         Returns:
             The result of the agent run.
@@ -336,6 +349,24 @@ async def run(
            AgentInvalidPromptInputError: If the prompt/input combination is invalid.
             AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
         """
+        result = await self._run_without_post_processing(input, options, context, tool_choice)
+
+        if post_processors:
+            for processor in post_processors:
+                result = await processor.process(result, self)
+
+        return result
+
+    async def _run_without_post_processing(
+        self,
+        input: str | PromptInputT | None,
+        options: AgentOptions[LLMClientOptionsT] | None = None,
+        context: AgentRunContext | None = None,
+        tool_choice: ToolChoice | None = None,
+    ) -> AgentResult[PromptOutputT]:
+        """
+        Run the agent without applying post-processors.
+        """
         if context is None:
             context = AgentRunContext()
 
@@ -403,6 +434,33 @@ def run_streaming(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: list[StreamingPostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
+        *,
+        allow_non_streaming: bool = False,
+    ) -> AgentResultStreaming: ...
+
+    @overload
+    def run_streaming(
+        self: "Agent[LLMClientOptionsT, None, PromptOutputT]",
+        input: str | None = None,
+        options: AgentOptions[LLMClientOptionsT] | None = None,
+        context: AgentRunContext | None = None,
+        tool_choice: ToolChoice | None = None,
+        post_processors: list[BasePostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
+        *,
+        allow_non_streaming: Literal[True],
+    ) -> AgentResultStreaming: ...
+
+    @overload
+    def run_streaming(
+        self: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
+        input: PromptInputT,
+        options: AgentOptions[LLMClientOptionsT] | None = None,
+        context: AgentRunContext | None = None,
+        tool_choice: ToolChoice | None = None,
+        post_processors: list[StreamingPostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
+        *,
+        allow_non_streaming: bool = False,
     ) -> AgentResultStreaming: ...
 
     @overload
@@ -412,6 +470,9 @@ def run_streaming(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: list[BasePostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]] | None = None,
+        *,
+        allow_non_streaming: Literal[True],
     ) -> AgentResultStreaming: ...
 
     def run_streaming(
@@ -420,6 +481,13 @@ def run_streaming(
         options: AgentOptions[LLMClientOptionsT] | None = None,
         context: AgentRunContext | None = None,
         tool_choice: ToolChoice | None = None,
+        post_processors: (
+            list[StreamingPostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]]
+            | list[BasePostProcessor[LLMClientOptionsT, PromptInputT, PromptOutputT]]
+            | None
+        ) = None,
+        *,
+        allow_non_streaming: bool = False,
     ) -> AgentResultStreaming:
         """
         This method returns an `AgentResultStreaming` object that can be asynchronously
@@ -434,6 +502,8 @@ def run_streaming(
             - "none": do not call tool
             - "required: enforce tool usage (model decides which one)
             - Callable: one of provided tools
+            post_processors: List of post-processors to apply to the response in order.
+            allow_non_streaming: Whether to allow non-streaming post-processors.
 
         Returns:
             A `StreamingResult` object for iteration and collection.
@@ -444,8 +514,17 @@ def run_streaming(
             AgentToolNotAvailableError: If the selected tool is not available.
             AgentInvalidPromptInputError: If the prompt/input combination is invalid.
             AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
+            AgentInvalidPostProcessorError: If the post-processor is invalid.
         """
         generator = self._stream_internal(input, options, context, tool_choice)
+
+        if post_processors:
+            if not allow_non_streaming and any(not p.supports_streaming for p in post_processors):
+                raise AgentInvalidPostProcessorError(
+                    reason="Non-streaming post-processors are not allowed when allow_non_streaming is False"
+                )
+            generator = stream_with_post_processing(generator, post_processors, self)
+
         return AgentResultStreaming(generator)
 
     async def _stream_internal(
diff --git a/packages/ragbits-agents/src/ragbits/agents/exceptions.py b/packages/ragbits-agents/src/ragbits/agents/exceptions.py
index c474e6ac7..a80e76314 100644
--- a/packages/ragbits-agents/src/ragbits/agents/exceptions.py
+++ b/packages/ragbits-agents/src/ragbits/agents/exceptions.py
@@ -106,3 +106,13 @@
         self.limit = limit
         self.actual = actual
         self.next_prompt_tokens = next_prompt_tokens
+
+
+class AgentInvalidPostProcessorError(AgentError):
+    """
+    Raised when the post-processor is invalid.
+    """
+
+    def __init__(self, reason: str) -> None:
+        super().__init__(f"Invalid post-processor: {reason}")
+        self.reason = reason
diff --git a/packages/ragbits-agents/src/ragbits/agents/post_processors/__init__.py b/packages/ragbits-agents/src/ragbits/agents/post_processors/__init__.py
new file mode 100644
index 000000000..a00a4d209
--- /dev/null
+++ b/packages/ragbits-agents/src/ragbits/agents/post_processors/__init__.py
@@ -0,0 +1,7 @@
+"""
+Post-processors for agent responses.
+"""
+
+from .base import BasePostProcessor, PostProcessor, StreamingPostProcessor
+
+__all__ = ["BasePostProcessor", "PostProcessor", "StreamingPostProcessor"]
diff --git a/packages/ragbits-agents/src/ragbits/agents/post_processors/base.py b/packages/ragbits-agents/src/ragbits/agents/post_processors/base.py
new file mode 100644
index 000000000..4f38e13ca
--- /dev/null
+++ b/packages/ragbits-agents/src/ragbits/agents/post_processors/base.py
@@ -0,0 +1,153 @@
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Generic, TypeVar, cast
+
+from ragbits.agents.tool import ToolCallResult
+from ragbits.core.llms.base import LLMOptions, ToolCall, Usage
+from ragbits.core.prompt.base import BasePrompt
+from ragbits.core.prompt.prompt import PromptInputT, PromptOutputT
+
+if TYPE_CHECKING:
+    from ragbits.agents._main import Agent, AgentResult
+
+
+LLMOptionsT = TypeVar("LLMOptionsT", bound=LLMOptions)
+
+
+class BasePostProcessor(ABC, Generic[LLMOptionsT, PromptInputT, PromptOutputT]):
+    """Base class for post-processors."""
+
+    @property
+    @abstractmethod
+    def supports_streaming(self) -> bool:
+        """
+        Whether this post-processor supports streaming mode.
+
+        If True, the processor can work with content during streaming
+        via the process_streaming() method.
+
+        If False, the processor will only be called after streaming is complete
+        with the full result via the process() method.
+        """
+
+
+class PostProcessor(ABC, BasePostProcessor[LLMOptionsT, PromptInputT, PromptOutputT]):
+    """Base class for non-streaming post-processors."""
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Whether this post-processor supports streaming mode."""
+        return False
+
+    @abstractmethod
+    async def process(
+        self,
+        result: "AgentResult[PromptOutputT]",
+        agent: "Agent[LLMOptionsT, PromptInputT, PromptOutputT]",
+    ) -> "AgentResult[PromptOutputT]":
+        """
+        Process the complete agent result.
+
+        Args:
+            result: The complete AgentResult from the agent or previous post-processor.
+            agent: The Agent instance that generated the result. Can be used to re-run
+                the agent with modified input if needed.
+
+        Returns:
+            Modified AgentResult to pass to the next processor or return as the final result.
+        """
+
+
+class StreamingPostProcessor(ABC, BasePostProcessor[LLMOptionsT, PromptInputT, PromptOutputT]):
+    """Base class for streaming post-processors."""
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Whether this post-processor supports streaming mode."""
+        return True
+
+    @abstractmethod
+    async def process_streaming(
+        self,
+        chunk: str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage,
+        agent: "Agent[LLMOptionsT, PromptInputT, PromptOutputT]",
+    ) -> str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage:
+        """
+        Process chunks during streaming.
+
+        Args:
+            chunk: The current chunk being streamed.
+            agent: The Agent instance generating the content.
+
+        Returns:
+            The (possibly modified) chunk to yield. Return the chunk unchanged
+            if no modification is needed.
+        """
+
+
+async def stream_with_post_processing(
+    generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage],
+    post_processors: (
+        list[StreamingPostProcessor[LLMOptionsT, PromptInputT, PromptOutputT]]
+        | list[BasePostProcessor[LLMOptionsT, PromptInputT, PromptOutputT]]
+    ),
+    agent: "Agent[LLMOptionsT, PromptInputT, PromptOutputT]",
+) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]:
+    """
+    Stream with support for both streaming and non-streaming post-processors.
+
+    Streaming processors get chunks in real-time via process_streaming().
+    Non-streaming processors get the complete result via process().
+    """
+    from ragbits.agents import AgentResult
+
+    streaming_processors = [p for p in post_processors or [] if isinstance(p, StreamingPostProcessor)]
+    non_streaming_processors = [p for p in post_processors or [] if isinstance(p, PostProcessor)]
+
+    accumulated_content = ""
+    tool_call_results: list[ToolCallResult] = []
+    usage: Usage = Usage()
+    prompt_with_history: BasePrompt | None = None
+
+    async for chunk in generator:
+        processed_chunk = chunk
+        for streaming_processor in streaming_processors:
+            processed_chunk = await streaming_processor.process_streaming(chunk=processed_chunk, agent=agent)
+            if processed_chunk is None:
+                break
+
+        if isinstance(processed_chunk, str):
+            accumulated_content += processed_chunk
+        elif isinstance(processed_chunk, ToolCallResult):
+            tool_call_results.append(processed_chunk)
+        elif isinstance(processed_chunk, Usage):
+            usage = processed_chunk
+        elif isinstance(processed_chunk, BasePrompt):
+            prompt_with_history = processed_chunk
+
+        if processed_chunk is not None:
+            yield processed_chunk
+
+    if non_streaming_processors and prompt_with_history:
+        agent_result = AgentResult(
+            content=cast(PromptOutputT, accumulated_content),
+            metadata={},
+            tool_calls=tool_call_results or None,
+            history=prompt_with_history.chat,
+            usage=usage,
+        )
+
+        current_result = agent_result
+        for non_streaming_processor in non_streaming_processors:
+            current_result = await non_streaming_processor.process(current_result, agent)
+
+        yield current_result.usage
+        yield prompt_with_history
+        yield SimpleNamespace(
+            result={
+                "content": current_result.content,
+                "metadata": current_result.metadata,
+                "tool_calls": current_result.tool_calls,
+            }
+        )
diff --git a/packages/ragbits-agents/tests/unit/test_post_processors.py b/packages/ragbits-agents/tests/unit/test_post_processors.py
new file mode 100644
index 000000000..b71216267
--- /dev/null
+++ b/packages/ragbits-agents/tests/unit/test_post_processors.py
@@ -0,0 +1,167 @@
+from types import SimpleNamespace
+
+import pytest
+
+from ragbits.agents import Agent, AgentResult, ToolCallResult
+from ragbits.agents.exceptions import AgentInvalidPostProcessorError
+from ragbits.agents.post_processors.base import PostProcessor, StreamingPostProcessor
+from ragbits.core.llms.base import BasePrompt, ToolCall, Usage
+from ragbits.core.llms.mock import MockLLM, MockLLMOptions
+
+
+class MockPostProcessor(PostProcessor):
+    def __init__(self, append_content: str = " - processed"):
+        self.append_content = append_content
+
+    async def process(self, result: AgentResult, agent: Agent) -> AgentResult:
+        result.content += self.append_content
+        return result
+
+
+class MockStreamingPostProcessor(StreamingPostProcessor):
+    def __init__(self, append_content: str = " - streamed"):
+        self.append_content = append_content
+
+    async def process_streaming(
+        self, chunk: str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage, agent: Agent
+    ):
+        if isinstance(chunk, str):
+            return chunk + self.append_content
+        return chunk
+
+
+@pytest.fixture
+def mock_llm() -> MockLLM:
+    options = MockLLMOptions(response="Initial response")
+    return MockLLM(default_options=options)
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_post_processor(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    post_processor = MockPostProcessor()
+
+    result = await agent.run(post_processors=[post_processor])
+
+    assert result.content == "Initial response - processed"
+
+
+@pytest.mark.asyncio
+async def test_streaming_post_processor(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    post_processor = MockStreamingPostProcessor()
+
+    result = agent.run_streaming(post_processors=[post_processor])
+    async for chunk in result:
+        if isinstance(chunk, str):
+            assert chunk.endswith(" - streamed")
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_processor_in_streaming_mode_raises_error(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    post_processor = MockPostProcessor()
+
+    with pytest.raises(AgentInvalidPostProcessorError):
+        await anext(agent.run_streaming(post_processors=[post_processor]))  # type: ignore  # ignore type-checking to test raising the error
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_processor_in_streaming_mode_with_allow_non_streaming(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    post_processor = MockPostProcessor()
+
+    result = agent.run_streaming(post_processors=[post_processor], allow_non_streaming=True)
+
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - processed"
+
+
+@pytest.mark.asyncio
+async def test_streaming_and_non_streaming_processors(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    non_streaming_processor = MockPostProcessor()
+    streaming_processor = MockStreamingPostProcessor()
+
+    result = agent.run_streaming(
+        post_processors=[streaming_processor, non_streaming_processor], allow_non_streaming=True
+    )
+
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - streamed - processed"
+
+
+@pytest.mark.asyncio
+async def test_streaming_processor_always_runs_before_non_streaming_processor(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    non_streaming_processor = MockPostProcessor()
+    streaming_processor = MockStreamingPostProcessor()
+
+    result = agent.run_streaming(
+        post_processors=[streaming_processor, non_streaming_processor], allow_non_streaming=True
+    )
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - streamed - processed"
+
+    result = agent.run_streaming(
+        post_processors=[non_streaming_processor, streaming_processor], allow_non_streaming=True
+    )
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - streamed - processed"
+
+
+@pytest.mark.asyncio
+async def test_multiple_non_streaming_processors_order(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    non_streaming_processor_1 = MockPostProcessor(append_content=" - processed 1")
+    non_streaming_processor_2 = MockPostProcessor(append_content=" - processed 2")
+
+    result = await agent.run(post_processors=[non_streaming_processor_2, non_streaming_processor_1])
+
+    assert result.content == "Initial response - processed 2 - processed 1"
+
+
+@pytest.mark.asyncio
+async def test_multiple_streaming_processors_order(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    streaming_processor_1 = MockStreamingPostProcessor(append_content=" - streamed 1")
+    streaming_processor_2 = MockStreamingPostProcessor(append_content=" - streamed 2")
+
+    result = agent.run_streaming(
+        post_processors=[streaming_processor_2, streaming_processor_1],
+    )
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - streamed 2 - streamed 1"
+
+
+@pytest.mark.asyncio
+async def test_multiple_streaming_and_non_streaming_processors_order(mock_llm: MockLLM):
+    agent: Agent = Agent(llm=mock_llm, prompt="Test prompt")
+    streaming_processor_1 = MockStreamingPostProcessor(append_content=" - streamed 1")
+    streaming_processor_2 = MockStreamingPostProcessor(append_content=" - streamed 2")
+    non_streaming_processor_1 = MockPostProcessor(append_content=" - processed 1")
+    non_streaming_processor_2 = MockPostProcessor(append_content=" - processed 2")
+
+    result = agent.run_streaming(
+        post_processors=[
+            non_streaming_processor_2,
+            streaming_processor_1,
+            non_streaming_processor_1,
+            streaming_processor_2,
+        ],
+        allow_non_streaming=True,
+    )
+    async for _ in result:
+        pass
+
+    assert result.content == "Initial response - streamed 1 - streamed 2 - processed 2 - processed 1"
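
---

Reviewer note: the ordering semantics exercised by the tests in this diff (streaming processors transform each chunk as it arrives, in list order; non-streaming processors then run once, in list order, on the accumulated result) can be illustrated with a small self-contained sketch. This is plain Python with no ragbits imports; the names `run_pipeline`, `Suffix`, and `Truncate` are illustrative and not part of the ragbits API, and the sketch deliberately ignores tool calls, usage, and history.

```python
import asyncio


class StreamingProcessor:
    """Simplified stand-in for StreamingPostProcessor: sees every chunk."""

    async def process_streaming(self, chunk: str) -> str:
        raise NotImplementedError


class Processor:
    """Simplified stand-in for PostProcessor: sees only the final result."""

    async def process(self, content: str) -> str:
        raise NotImplementedError


class Suffix(StreamingProcessor):
    def __init__(self, suffix: str) -> None:
        self.suffix = suffix

    async def process_streaming(self, chunk: str) -> str:
        return chunk + self.suffix


class Truncate(Processor):
    def __init__(self, max_length: int) -> None:
        self.max_length = max_length

    async def process(self, content: str) -> str:
        return content[: self.max_length]


async def run_pipeline(chunks: list[str], post_processors: list[object]) -> str:
    # Split processors by kind, mirroring stream_with_post_processing.
    streaming = [p for p in post_processors if isinstance(p, StreamingProcessor)]
    non_streaming = [p for p in post_processors if isinstance(p, Processor)]

    accumulated = ""
    for chunk in chunks:
        # Streaming processors run per chunk, in list order.
        for p in streaming:
            chunk = await p.process_streaming(chunk)
        accumulated += chunk

    # Non-streaming processors run once on the accumulated content,
    # in list order, regardless of where they sat in the mixed list.
    for p in non_streaming:
        accumulated = await p.process(accumulated)
    return accumulated


content = asyncio.run(run_pipeline(["Hello", " world"], [Truncate(8), Suffix("!")]))
print(content)  # "Hello! w" - the suffix is applied per chunk before truncation
```

Even though `Truncate` is listed first, the streamed `Suffix` runs first, which is exactly the behavior pinned down by `test_streaming_processor_always_runs_before_non_streaming_processor` above.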