diff --git a/python/packages/ag-ui/tests/__init__.py b/python/packages/ag-ui/tests/__init__.py deleted file mode 100644 index 2a50eae894..0000000000 --- a/python/packages/ag-ui/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py index 09570c1be4..d46561c2c1 100644 --- a/python/packages/ag-ui/tests/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/test_ag_ui_client.py @@ -9,13 +9,13 @@ from agent_framework import ( ChatMessage, ChatOptions, + ChatResponse, ChatResponseUpdate, FunctionCallContent, Role, TextContent, ai_function, ) -from agent_framework._types import ChatResponse from pytest import MonkeyPatch from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 281b81c968..a553cc838e 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -9,12 +9,11 @@ from typing import Any import pytest -from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub +from utils_test_ag_ui import StreamingChatClientStub async def test_agent_initialization_basic(): diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py index 1ea164beef..01ab789c20 100644 --- a/python/packages/ag-ui/tests/test_document_writer_flow.py +++ b/python/packages/ag-ui/tests/test_document_writer_flow.py @@ -3,8 +3,7 @@ """Tests for document writer predictive state flow with confirm_changes.""" from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent -from agent_framework import FunctionCallContent, FunctionResultContent, TextContent -from agent_framework._types import AgentRunResponseUpdate +from agent_framework import AgentRunResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent from agent_framework_ag_ui._events import AgentFrameworkEventBridge diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index e8b4c86a49..02b23a544b 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -6,8 +6,7 @@ import sys from pathlib import Path -from agent_framework import ChatAgent, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatResponseUpdate, TextContent from fastapi import FastAPI from fastapi.testclient import TestClient @@ -15,7 +14,7 @@ from agent_framework_ag_ui._endpoint import add_agent_framework_fastapi_endpoint sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates +from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py index 55a2869c91..8e932d4aea 100644 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py @@ -2,8 +2,7 @@ """Tests for human in the loop (function approval requests).""" -from agent_framework import FunctionApprovalRequestContent, FunctionCallContent -from agent_framework._types import AgentRunResponseUpdate +from agent_framework import AgentRunResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent from agent_framework_ag_ui._events import AgentFrameworkEventBridge diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 8c00602538..a3803f04d0 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -6,8 +6,7 @@ from types import SimpleNamespace from typing import Any -from agent_framework import AgentRunResponseUpdate, TextContent, ai_function -from agent_framework._tools import FunctionInvocationConfiguration +from agent_framework import AgentRunResponseUpdate, FunctionInvocationConfiguration, TextContent, ai_function from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 041e25c3d2..e10e84a754 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -20,7 +20,7 @@ from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StubAgent, TestExecutionContext +from utils_test_ag_ui import StubAgent, TestExecutionContext @ai_function(approval_mode="always_require") diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py index 36f80b9d47..469f5f5ad8 100644 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ b/python/packages/ag-ui/tests/test_shared_state.py @@ -8,14 +8,13 @@ import pytest from ag_ui.core import StateSnapshotEvent -from agent_framework import ChatAgent, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatResponseUpdate, TextContent from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._events import AgentFrameworkEventBridge sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates +from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates @pytest.fixture diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index c5f9719938..aee8fe6dc8 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -8,12 +8,11 @@ from pathlib import Path from typing import Any -from agent_framework import ChatAgent, ChatMessage, ChatOptions, TextContent -from agent_framework._types import ChatResponseUpdate +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub, stream_from_updates +from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates class RecipeOutput(BaseModel): diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py similarity index 89% rename from python/packages/ag-ui/tests/test_helpers_ag_ui.py rename to python/packages/ag-ui/tests/utils_test_ag_ui.py index fc82b11510..3c86c284da 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -11,12 +11,13 @@ AgentRunResponse, AgentRunResponseUpdate, AgentThread, + BaseChatClient, ChatMessage, ChatOptions, + ChatResponse, + ChatResponseUpdate, TextContent, ) -from agent_framework._clients import BaseChatClient -from agent_framework._types import ChatResponse, ChatResponseUpdate from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history from agent_framework_ag_ui._orchestrators import ExecutionContext @@ -79,31 +80,15 @@ def __init__( chat_options: Any | None = None, chat_client: Any | None = None, ) -> None: - self._id = agent_id - self._name = agent_name - self._description = "stub agent" + self.id = agent_id + self.name = agent_name + self.description = "stub agent" self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None) self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None) self.messages_received: list[Any] = [] self.tools_received: list[Any] | None = None - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - return self._name or self._id - - @property - def description(self) -> str | None: - return self._description - async def run( self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 4cef22e023..d3b98b05ff 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -182,32 +182,11 @@ def __init__(self, context: AgentOrchestrationContextType, agent_name: str): """ self.context = context self.agent_name = agent_name - self._id = str(uuid.uuid4()) - self._name = agent_name - self._display_name = agent_name - self._description = f"Durable agent proxy for {agent_name}" + self.id = str(uuid.uuid4()) + self.name = agent_name + self.description = f"Durable agent proxy for {agent_name}" logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name) - @property - def id(self) -> str: - """Get the unique identifier for this agent.""" - return self._id - - @property - def name(self) -> str | None: - """Get the name of the agent.""" - return self._name - - @property - def display_name(self) -> str: - """Get the display name of the agent.""" - return self._display_name - - @property - def description(self) -> str | None: - """Get the description of the agent.""" - return self._description - # We return an AgentTask here which is a TaskBase subclass. # This is an intentional deviation from AgentProtocol which defines run() as async. # The AgentTask can be yielded in Durable Functions orchestrations and will provide diff --git a/python/packages/azurefunctions/tests/integration_tests/__init__.py b/python/packages/azurefunctions/tests/integration_tests/__init__.py deleted file mode 100644 index 2a50eae894..0000000000 --- a/python/packages/azurefunctions/tests/integration_tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/azurefunctions/tests/integration_tests/conftest.py b/python/packages/azurefunctions/tests/integration_tests/conftest.py index e2f19d6037..ee81028b80 100644 --- a/python/packages/azurefunctions/tests/integration_tests/conftest.py +++ b/python/packages/azurefunctions/tests/integration_tests/conftest.py @@ -6,13 +6,18 @@ """ import subprocess +import sys from collections.abc import Iterator, Mapping +from pathlib import Path from typing import Any import pytest import requests -from .testutils import ( +# Add the integration_tests directory to the path so testutils can be imported +sys.path.insert(0, str(Path(__file__).parent)) + +from testutils import ( FunctionAppStartupError, build_base_url, cleanup_function_app, diff --git a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py index cd93e6a352..2f152552e2 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_01_single_agent.py @@ -15,11 +15,10 @@ """ import pytest +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled from agent_framework_azurefunctions._constants import THREAD_ID_HEADER -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled - # Module-level markers - applied to all tests in this file pytestmark = [ pytest.mark.sample("01_single_agent"), diff --git a/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py b/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py index f473a2be11..7a4adfd8dd 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_02_multi_agent.py @@ -15,8 +15,7 @@ """ import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py b/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py index 44fb8efb2f..032935ee29 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_03_reliable_streaming.py @@ -19,8 +19,7 @@ import pytest import requests - -from .testutils import ( +from testutils import ( SampleTestHelper, skip_if_azure_functions_integration_tests_disabled, ) diff --git a/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py b/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py index e4bb1cd930..fff06c9d8d 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_04_single_agent_orchestration_chaining.py @@ -19,8 +19,7 @@ """ import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py b/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py index aac8f361c6..d2d9cbbed8 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_05_multi_agent_orchestration_concurrency.py @@ -19,8 +19,7 @@ """ import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py b/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py index d7f13777bb..0b2a9f7073 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_06_multi_agent_orchestration_conditionals.py @@ -19,8 +19,7 @@ """ import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py b/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py index ade46033bc..f21410ebf5 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_07_single_agent_orchestration_hitl.py @@ -21,8 +21,7 @@ import time import pytest - -from .testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled +from testutils import SampleTestHelper, skip_if_azure_functions_integration_tests_disabled # Module-level markers - applied to all tests in this file pytestmark = [ diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index b0dd313b0b..0fd774d1cd 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -218,12 +218,10 @@ def test_has_agent_protocol_properties(self) -> None: assert hasattr(agent, "id") assert hasattr(agent, "name") assert hasattr(agent, "description") - assert hasattr(agent, "display_name") # Verify values assert agent.name == "TestAgent" assert agent.description == "Durable agent proxy for TestAgent" - assert agent.display_name == "TestAgent" assert agent.id is not None # Auto-generated UUID def test_get_new_thread(self) -> None: diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 35abd6dec7..a80fb7221a 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -4,11 +4,10 @@ from typing import Any, ClassVar from agent_framework import ( - AgentMiddlewares, + AgentMiddlewareTypes, AgentRunResponse, AgentRunResponseUpdate, AgentThread, - AggregateContextProvider, BaseAgent, ChatMessage, ContextProvider, @@ -79,8 +78,8 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, - middleware: AgentMiddlewares | list[AgentMiddlewares] | None = None, + context_provider: ContextProvider | None = None, + middleware: list[AgentMiddlewareTypes] | None = None, environment_id: str | None = None, agent_identifier: str | None = None, client_id: str | None = None, @@ -107,8 +106,8 @@ def __init__( id: id of the CopilotAgent name: Name of the CopilotAgent description: Description of the CopilotAgent - context_providers: Context Providers, to be used by the copilot agent. - middleware: Agent middlewares used by the agent. + context_provider: Context Provider, to be used by the copilot agent. + middleware: Agent middleware used by the agent, should be a list of AgentMiddlewareTypes. environment_id: Environment ID of the Power Platform environment containing the Copilot Studio app. Can also be set via COPILOTSTUDIOAGENT__ENVIRONMENTID environment variable. @@ -138,7 +137,7 @@ def __init__( id=id, name=name, description=description, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, ) if not client: diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index f260b1a6d3..78164ad2a0 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -18,7 +18,7 @@ from ._clients import BaseChatClient, ChatClientProtocol from ._logging import get_logger from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._memory import AggregateContextProvider, Context, ContextProvider +from ._memory import Context, ContextProvider from ._middleware import Middleware, use_agent_middleware from ._serialization import SerializationMixin from ._threads import AgentThread, ChatMessageStoreProtocol @@ -116,24 +116,9 @@ class AgentProtocol(Protocol): # No need to inherit from AgentProtocol or use any framework classes class CustomAgent: def __init__(self): - self._id = "custom-agent-001" - self._name = "Custom Agent" - - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - return self.name or self.id - - @property - def description(self) -> str | None: - return "A fully custom agent implementation" + self.id = "custom-agent-001" + self.name = "Custom Agent" + self.description = "A fully custom agent implementation" async def run(self, messages=None, *, thread=None, **kwargs): # Your custom implementation @@ -160,25 +145,9 @@ def get_new_thread(self, **kwargs): assert isinstance(instance, AgentProtocol) """ - @property - def id(self) -> str: - """Returns the ID of the agent.""" - ... - - @property - def name(self) -> str | None: - """Returns the name of the agent.""" - ... - - @property - def display_name(self) -> str: - """Returns the display name of the agent.""" - ... - - @property - def description(self) -> str | None: - """Returns the description of the agent.""" - ... + id: str + name: str | None + description: str | None async def run( self, @@ -289,7 +258,6 @@ async def _stream(): # Access agent properties print(agent.id) # Custom or auto-generated UUID - print(agent.display_name) # Returns name or id """ DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"} @@ -300,8 +268,8 @@ def __init__( id: str | None = None, name: str | None = None, description: str | None = None, - context_providers: ContextProvider | Sequence[ContextProvider] | None = None, - middleware: Middleware | Sequence[Middleware] | None = None, + context_provider: ContextProvider | None = None, + middleware: Sequence[Middleware] | None = None, additional_properties: MutableMapping[str, Any] | None = None, **kwargs: Any, ) -> None: @@ -312,8 +280,8 @@ def __init__( a new UUID will be generated. name: The name of the agent, can be None. description: The description of the agent. - context_providers: The collection of multiple context providers to include during agent invocation. - middleware: List of middleware to intercept agent and function invocations. + context_provider: The context provider to include during agent invocation. + middleware: List of middleware. additional_properties: Additional properties set on the agent. kwargs: Additional keyword arguments (merged into additional_properties). """ @@ -322,11 +290,10 @@ def __init__( self.id = id self.name = name self.description = description - self.context_provider = self._prepare_context_providers(context_providers) - if middleware is None or isinstance(middleware, Sequence): - self.middleware: list[Middleware] | None = cast(list[Middleware], middleware) if middleware else None - else: - self.middleware = [middleware] + self.context_provider = context_provider + self.middleware: list[Middleware] | None = ( + cast(list[Middleware], middleware) if middleware is not None else None + ) # Merge kwargs into additional_properties self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {}) @@ -356,14 +323,6 @@ async def _notify_thread_of_new_messages( if thread.context_provider: await thread.context_provider.invoked(input_messages, response_messages, **kwargs) - @property - def display_name(self) -> str: - """Returns the display name of the agent. - - This is the name if present, otherwise the id. - """ - return self.name or self.id - def get_new_thread(self, **kwargs: Any) -> AgentThread: """Return a new AgentThread instance that is compatible with the agent. @@ -499,18 +458,6 @@ def _normalize_messages( return [ChatMessage(role=Role.USER, text=msg) if isinstance(msg, str) else msg for msg in messages] - def _prepare_context_providers( - self, - context_providers: ContextProvider | Sequence[ContextProvider] | None = None, - ) -> AggregateContextProvider | None: - if not context_providers: - return None - - if isinstance(context_providers, AggregateContextProvider): - return context_providers - - return AggregateContextProvider(context_providers) - # region ChatAgent @@ -594,8 +541,8 @@ def __init__( name: str | None = None, description: str | None = None, chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, - middleware: Middleware | list[Middleware] | None = None, + context_provider: ContextProvider | None = None, + middleware: Sequence[Middleware] | None = None, # chat options allow_multiple_tool_calls: bool | None = None, conversation_id: str | None = None, @@ -639,8 +586,8 @@ def __init__( description: A brief description of the agent's purpose. chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. If not provided, the default in-memory store will be used. - context_providers: The collection of multiple context providers to include during agent invocation. - middleware: List of middleware to intercept agent and function invocations. + context_provider: The context provider to include during agent invocation. + middleware: List of middleware to intercept agent, chat and function invocations. allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response. conversation_id: The conversation ID for service-managed threads. Cannot be used together with chat_message_store_factory. @@ -683,7 +630,7 @@ def __init__( id=id, name=name, description=description, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, **kwargs, ) diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 6743902475..ca273ea536 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from ._logging import get_logger -from ._memory import AggregateContextProvider, ContextProvider +from ._memory import ContextProvider from ._middleware import ( ChatMiddleware, ChatMiddlewareCallable, @@ -336,12 +336,7 @@ def __init__( self, *, middleware: ( - ChatMiddleware - | ChatMiddlewareCallable - | FunctionMiddleware - | FunctionMiddlewareCallable - | list[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable] - | None + Sequence[ChatMiddleware | ChatMiddlewareCallable | FunctionMiddleware | FunctionMiddlewareCallable] | None ) = None, additional_properties: dict[str, Any] | None = None, **kwargs: Any, @@ -722,8 +717,8 @@ def create_agent( description: str | None = None, instructions: str | None = None, chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None, - context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None, - middleware: Middleware | list[Middleware] | None = None, + context_provider: ContextProvider | None = None, + middleware: Sequence[Middleware] | None = None, allow_multiple_tool_calls: bool | None = None, conversation_id: str | None = None, frequency_penalty: float | None = None, @@ -761,8 +756,8 @@ def create_agent( These will be put into the messages sent to the chat client service as a system message. chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol. If not provided, the default in-memory store will be used. - context_providers: Context providers to include during agent invocation. - middleware: List of middleware to intercept agent and function invocations. + context_provider: Context provider to include during agent invocation. + middleware: List of middleware to intercept chat and function invocations. allow_multiple_tool_calls: Whether to allow multiple tool calls per agent turn. conversation_id: The conversation ID to associate with the agent's messages. frequency_penalty: The frequency penalty to use. @@ -813,7 +808,7 @@ def create_agent( description=description, instructions=instructions, chat_message_store_factory=chat_message_store_factory, - context_providers=context_providers, + context_provider=context_provider, middleware=middleware, allow_multiple_tool_calls=allow_multiple_tool_calls, conversation_id=conversation_id, diff --git a/python/packages/core/agent_framework/_memory.py b/python/packages/core/agent_framework/_memory.py index a5b53fc39f..5e46b1749d 100644 --- a/python/packages/core/agent_framework/_memory.py +++ b/python/packages/core/agent_framework/_memory.py @@ -1,22 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio import sys from abc import ABC, abstractmethod from collections.abc import MutableSequence, Sequence -from contextlib import AsyncExitStack from types import TracebackType -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final from ._types import ChatMessage if TYPE_CHECKING: from ._tools import ToolProtocol -if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover -else: - from typing_extensions import override # type: ignore[import] # pragma: no cover if sys.version_info >= (3, 11): from typing import Self # pragma: no cover else: @@ -24,7 +18,7 @@ # region Context -__all__ = ["AggregateContextProvider", "Context", "ContextProvider"] +__all__ = ["Context", "ContextProvider"] class Context: @@ -100,7 +94,7 @@ async def invoking(self, messages, **kwargs): # Use with a chat agent async with CustomContextProvider() as provider: - agent = ChatAgent(chat_client=client, name="assistant", context_providers=provider) + agent = ChatAgent(chat_client=client, name="assistant", context_provider=provider) """ # Default prompt to be used by all context providers when assembling memories/instructions @@ -183,130 +177,3 @@ async def __aexit__( exc_tb: The exception traceback if an exception occurred, None otherwise. """ pass - - -# region AggregateContextProvider - - -class AggregateContextProvider(ContextProvider): - """A ContextProvider that contains multiple context providers. - - It delegates events to multiple context providers and aggregates responses from those - events before returning. This allows you to combine multiple context providers into a - single provider. - - Note: - An AggregateContextProvider is created automatically when you pass a single context - provider or a sequence of context providers to the agent constructor. - - Examples: - .. code-block:: python - - from agent_framework import AggregateContextProvider, ChatAgent - - # Create multiple context providers - provider1 = CustomContextProvider1() - provider2 = CustomContextProvider2() - provider3 = CustomContextProvider3() - - # Pass them to the agent - AggregateContextProvider is created automatically - agent = ChatAgent(chat_client=client, name="assistant", context_providers=[provider1, provider2, provider3]) - - # Verify that an AggregateContextProvider was created - assert isinstance(agent.context_providers, AggregateContextProvider) - - # Add additional providers to the agent - provider4 = CustomContextProvider4() - agent.context_providers.add(provider4) - """ - - def __init__(self, context_providers: ContextProvider | Sequence[ContextProvider] | None = None) -> None: - """Initialize the AggregateContextProvider with context providers. - - Args: - context_providers: The context provider(s) to add. - """ - if isinstance(context_providers, ContextProvider): - self.providers = [context_providers] - else: - self.providers = cast(list[ContextProvider], context_providers) or [] - self._exit_stack: AsyncExitStack | None = None - - def add(self, context_provider: ContextProvider) -> None: - """Add a new context provider. - - Args: - context_provider: The context provider to add. - """ - self.providers.append(context_provider) - - @override - async def thread_created(self, thread_id: str | None = None) -> None: - await asyncio.gather(*[x.thread_created(thread_id) for x in self.providers]) - - @override - async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: - contexts = await asyncio.gather(*[provider.invoking(messages, **kwargs) for provider in self.providers]) - instructions: str = "" - return_messages: list[ChatMessage] = [] - tools: list["ToolProtocol"] = [] - for ctx in contexts: - if ctx.instructions: - instructions += ctx.instructions - if ctx.messages: - return_messages.extend(ctx.messages) - if ctx.tools: - tools.extend(ctx.tools) - return Context(instructions=instructions, messages=return_messages, tools=tools) - - @override - async def invoked( - self, - request_messages: ChatMessage | Sequence[ChatMessage], - response_messages: ChatMessage | Sequence[ChatMessage] | None = None, - invoke_exception: Exception | None = None, - **kwargs: Any, - ) -> None: - await asyncio.gather(*[ - x.invoked( - request_messages=request_messages, - response_messages=response_messages, - invoke_exception=invoke_exception, - **kwargs, - ) - for x in self.providers - ]) - - @override - async def __aenter__(self) -> "Self": - """Enter the async context manager and set up all providers. - - Returns: - The AggregateContextProvider instance for chaining. - """ - self._exit_stack = AsyncExitStack() - await self._exit_stack.__aenter__() - - # Enter all context providers - for provider in self.providers: - await self._exit_stack.enter_async_context(provider) - - return self - - @override - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async context manager and clean up all providers. - - Args: - exc_type: The exception type if an exception occurred, None otherwise. - exc_val: The exception value if an exception occurred, None otherwise. - exc_tb: The exception traceback if an exception occurred, None otherwise. - """ - if self._exit_stack is not None: - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - self._exit_stack = None diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 4e36cb764a..00a465f72b 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -2,10 +2,10 @@ import inspect from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence +from collections.abc import AsyncIterable, Awaitable, Callable, MutableSequence, Sequence from enum import Enum from functools import update_wrapper -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar from ._serialization import SerializationMixin from ._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, prepare_messages @@ -23,7 +23,7 @@ __all__ = [ "AgentMiddleware", - "AgentMiddlewares", + "AgentMiddlewareTypes", "AgentRunContext", "ChatContext", "ChatMiddleware", @@ -305,7 +305,7 @@ async def process(self, context: AgentRunContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=RetryMiddleware()) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[RetryMiddleware()]) """ @abstractmethod @@ -373,7 +373,7 @@ async def process(self, context: FunctionInvocationContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=CachingMiddleware()) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[CachingMiddleware()]) """ @abstractmethod @@ -432,7 +432,9 @@ async def process(self, context: ChatContext, next): # Use with an agent agent = ChatAgent( - chat_client=client, name="assistant", middleware=SystemPromptMiddleware("You are a helpful assistant.") + chat_client=client, + name="assistant", + middleware=[SystemPromptMiddleware("You are a helpful assistant.")], ) """ @@ -480,7 +482,7 @@ async def process( | ChatMiddleware | ChatMiddlewareCallable ) -AgentMiddlewares: TypeAlias = AgentMiddleware | AgentMiddlewareCallable +AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable # region Middleware type markers for decorators @@ -511,7 +513,7 @@ async def logging_middleware(context: AgentRunContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as agent middleware func._middleware_type: MiddlewareType = MiddlewareType.AGENT # type: ignore @@ -544,7 +546,7 @@ async def logging_middleware(context: FunctionInvocationContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as function middleware func._middleware_type: MiddlewareType = MiddlewareType.FUNCTION # type: ignore @@ -577,7 +579,7 @@ async def logging_middleware(context: ChatContext, next): # Use with an agent - agent = ChatAgent(chat_client=client, name="assistant", middleware=logging_middleware) + agent = ChatAgent(chat_client=client, name="assistant", middleware=[logging_middleware]) """ # Add marker attribute to identify this as chat middleware func._middleware_type: MiddlewareType = MiddlewareType.CHAT # type: ignore @@ -609,7 +611,7 @@ class BaseMiddlewarePipeline(ABC): def __init__(self) -> None: """Initialize the base middleware pipeline.""" - self._middlewares: list[Any] = [] + self._middleware: list[Any] = [] @abstractmethod def _register_middleware(self, middleware: Any) -> None: @@ -624,12 +626,12 @@ def _register_middleware(self, middleware: Any) -> None: @property def has_middlewares(self) -> bool: - """Check if there are any middlewares registered. + """Check if there are any middleware registered. Returns: - True if middlewares are registered, False otherwise. + True if middleware are registered, False otherwise. """ - return bool(self._middlewares) + return bool(self._middleware) def _register_middleware_with_wrapper( self, @@ -645,9 +647,9 @@ def _register_middleware_with_wrapper( expected_type: The expected middleware base class type. """ if isinstance(middleware, expected_type): - self._middlewares.append(middleware) + self._middleware.append(middleware) elif callable(middleware): - self._middlewares.append(MiddlewareWrapper(middleware)) # type: ignore[arg-type] + self._middleware.append(MiddlewareWrapper(middleware)) # type: ignore[arg-type] def _create_handler_chain( self, @@ -667,7 +669,7 @@ def _create_handler_chain( """ def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: - if index >= len(self._middlewares): + if index >= len(self._middleware): async def final_wrapper(c: Any) -> None: # Execute actual handler and populate context for observability @@ -677,7 +679,7 @@ async def final_wrapper(c: Any) -> None: return final_wrapper - middleware = self._middlewares[index] + middleware = self._middleware[index] next_handler = create_next_handler(index + 1) async def current_handler(c: Any) -> None: @@ -705,7 +707,7 @@ def _create_streaming_handler_chain( """ def create_next_handler(index: int) -> Callable[[Any], Awaitable[None]]: - if index >= len(self._middlewares): + if index >= len(self._middleware): async def final_wrapper(c: Any) -> None: # If terminate was set, skip execution @@ -724,7 +726,7 @@ async def final_wrapper(c: Any) -> None: return final_wrapper - middleware = self._middlewares[index] + middleware = self._middleware[index] next_handler = create_next_handler(index + 1) async def current_handler(c: Any) -> None: @@ -745,20 +747,20 @@ class AgentMiddlewarePipeline(BaseMiddlewarePipeline): to process the agent invocation and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[AgentMiddleware | AgentMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[AgentMiddlewareTypes] | None = None): """Initialize the agent middleware pipeline. Args: - middlewares: The list of agent middleware to include in the pipeline. + middleware: The list of agent middleware to include in the pipeline. """ super().__init__() - self._middlewares: list[AgentMiddleware] = [] + self._middleware: list[AgentMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) - def _register_middleware(self, middleware: AgentMiddleware | AgentMiddlewareCallable) -> None: + def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None: """Register an agent middleware item. Args: @@ -789,7 +791,7 @@ async def execute( context.messages = messages context.is_streaming = False - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result @@ -841,7 +843,7 @@ async def execute_stream( context.messages = messages context.is_streaming = True - if not self._middlewares: + if not self._middleware: async for update in final_handler(context): yield update return @@ -874,18 +876,18 @@ class FunctionMiddlewarePipeline(BaseMiddlewarePipeline): to process the function invocation and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[FunctionMiddleware | FunctionMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[FunctionMiddleware | FunctionMiddlewareCallable] | None = None): """Initialize the function middleware pipeline. Args: - middlewares: The list of function middleware to include in the pipeline. + middleware: The list of function middleware to include in the pipeline. """ super().__init__() - self._middlewares: list[FunctionMiddleware] = [] + self._middleware: list[FunctionMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) def _register_middleware(self, middleware: FunctionMiddleware | FunctionMiddlewareCallable) -> None: """Register a function middleware item. @@ -917,7 +919,7 @@ async def execute( context.function = function context.arguments = arguments - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result @@ -947,18 +949,18 @@ class ChatMiddlewarePipeline(BaseMiddlewarePipeline): to process the chat request and pass control to the next middleware in the chain. """ - def __init__(self, middlewares: list[ChatMiddleware | ChatMiddlewareCallable] | None = None): + def __init__(self, middleware: Sequence[ChatMiddleware | ChatMiddlewareCallable] | None = None): """Initialize the chat middleware pipeline. Args: - middlewares: The list of chat middleware to include in the pipeline. + middleware: The list of chat middleware to include in the pipeline. """ super().__init__() - self._middlewares: list[ChatMiddleware] = [] + self._middleware: list[ChatMiddleware] = [] - if middlewares: - for middleware in middlewares: - self._register_middleware(middleware) + if middleware: + for mdlware in middleware: + self._register_middleware(mdlware) def _register_middleware(self, middleware: ChatMiddleware | ChatMiddlewareCallable) -> None: """Register a chat middleware item. @@ -995,7 +997,7 @@ async def execute( context.messages = messages context.chat_options = chat_options - if not self._middlewares: + if not self._middleware: return await final_handler(context) # Store the final result @@ -1045,7 +1047,7 @@ async def execute_stream( context.chat_options = chat_options context.is_streaming = True - if not self._middlewares: + if not self._middleware: async for update in final_handler(context): yield update return @@ -1182,8 +1184,8 @@ async def run_stream(self, messages, **kwargs): original_run_stream = agent_class.run_stream # type: ignore[attr-defined] def _build_middleware_pipelines( - agent_level_middlewares: Middleware | list[Middleware] | None, - run_level_middlewares: Middleware | list[Middleware] | None = None, + agent_level_middlewares: Sequence[Middleware] | None, + run_level_middlewares: Sequence[Middleware] | None = None, ) -> tuple[AgentMiddlewarePipeline, FunctionMiddlewarePipeline, list[ChatMiddleware | ChatMiddlewareCallable]]: """Build fresh agent and function middleware pipelines from the provided middleware lists. @@ -1191,7 +1193,7 @@ def _build_middleware_pipelines( agent_level_middlewares: Agent-level middleware (executed first) run_level_middlewares: Run-level middleware (executed after agent middleware) """ - middleware = categorize_middleware(agent_level_middlewares, run_level_middlewares) + middleware = categorize_middleware(*(agent_level_middlewares or ()), *(run_level_middlewares or ())) return ( AgentMiddlewarePipeline(middleware["agent"]), # type: ignore[arg-type] @@ -1204,7 +1206,7 @@ async def middleware_enabled_run( messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: Any = None, - middleware: Middleware | list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> AgentRunResponse: """Middleware-enabled run method.""" @@ -1253,7 +1255,7 @@ def middleware_enabled_run_stream( messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, *, thread: Any = None, - middleware: Middleware | list[Middleware] | None = None, + middleware: Sequence[Middleware] | None = None, **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: """Middleware-enabled run_stream method.""" @@ -1461,9 +1463,15 @@ def final_handler(ctx: ChatContext) -> Any: return chat_client_class +class MiddlewareDict(TypedDict): + agent: list[AgentMiddleware | AgentMiddlewareCallable] + function: list[FunctionMiddleware | FunctionMiddlewareCallable] + chat: list[ChatMiddleware | ChatMiddlewareCallable] + + def categorize_middleware( - *middleware_sources: Any | list[Any] | None, -) -> dict[str, list[Any]]: + *middleware_sources: Middleware | None, +) -> MiddlewareDict: """Categorize middleware from multiple sources into agent, function, and chat types. Args: @@ -1472,7 +1480,7 @@ def categorize_middleware( Returns: Dict with keys "agent", "function", "chat" containing lists of categorized middleware. """ - result: dict[str, list[Any]] = {"agent": [], "function": [], "chat": []} + result: MiddlewareDict = {"agent": [], "function": [], "chat": []} # Merge all middleware sources into a single list all_middleware: list[Any] = [] @@ -1495,11 +1503,11 @@ def categorize_middleware( # Always call _determine_middleware_type to ensure proper validation middleware_type = _determine_middleware_type(middleware) if middleware_type == MiddlewareType.AGENT: - result["agent"].append(middleware) + result["agent"].append(middleware) # type: ignore elif middleware_type == MiddlewareType.FUNCTION: - result["function"].append(middleware) + result["function"].append(middleware) # type: ignore elif middleware_type == MiddlewareType.CHAT: - result["chat"].append(middleware) + result["chat"].append(middleware) # type: ignore else: # Fallback to agent middleware for unknown types result["agent"].append(middleware) @@ -1508,7 +1516,7 @@ def categorize_middleware( def create_function_middleware_pipeline( - *middleware_sources: list[Middleware] | None, + *middleware_sources: Middleware, ) -> FunctionMiddlewarePipeline | None: """Create a function middleware pipeline from multiple middleware sources. @@ -1518,28 +1526,10 @@ def create_function_middleware_pipeline( Returns: A FunctionMiddlewarePipeline if function middleware is found, None otherwise. """ - middleware = categorize_middleware(*middleware_sources) - function_middlewares = middleware["function"] + function_middlewares = categorize_middleware(*middleware_sources)["function"] return FunctionMiddlewarePipeline(function_middlewares) if function_middlewares else None # type: ignore[arg-type] -def _merge_and_filter_chat_middleware( - instance_middleware: Any | list[Any] | None, - call_middleware: Any | list[Any] | None, -) -> list[ChatMiddleware | ChatMiddlewareCallable]: - """Merge instance-level and call-level middleware, filtering for chat middleware only. - - Args: - instance_middleware: Middleware defined at the instance level. - call_middleware: Middleware provided at the call level. - - Returns: - A merged list of chat middleware only. - """ - middleware = categorize_middleware(instance_middleware, call_middleware) - return middleware["chat"] # type: ignore[return-value] - - def extract_and_merge_function_middleware( chat_client: Any, kwargs: dict[str, Any] ) -> "FunctionMiddlewarePipeline | None": @@ -1556,7 +1546,7 @@ def extract_and_merge_function_middleware( existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline") # Get middleware sources - client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None + client_middleware = getattr(chat_client, "middleware", None) run_level_middleware = kwargs.get("middleware") # If we have an existing pipeline but no additional middleware sources, return it directly @@ -1564,15 +1554,15 @@ def extract_and_merge_function_middleware( return existing_pipeline # If we have an existing pipeline with additional middleware, we need to merge - # Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility - existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None + # Extract existing pipeline middleware if present - cast to list[Middleware] for type compatibility + existing_middleware: list[Middleware] | None = list(existing_pipeline._middleware) if existing_pipeline else None # Create combined pipeline from all sources using existing helper combined_pipeline = create_function_middleware_pipeline( - client_middleware, run_level_middleware, existing_middlewares + *(client_middleware or ()), *(run_level_middleware or ()), *(existing_middleware or ()) ) - # If we have an existing pipeline but combined is None (no new middlewares), return existing + # If we have an existing pipeline but combined is None (no new middleware), return existing if existing_pipeline and combined_pipeline is None: return existing_pipeline diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index cf28df2f4f..8aa9b6adcf 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -259,7 +259,7 @@ def __init__(self, **kwargs): agent = CustomAgent( - context_providers=[...], + context_provider=[...], middleware=[...] ) diff --git a/python/packages/core/agent_framework/_threads.py b/python/packages/core/agent_framework/_threads.py index 92469a78d5..e44c362324 100644 --- a/python/packages/core/agent_framework/_threads.py +++ b/python/packages/core/agent_framework/_threads.py @@ -3,7 +3,7 @@ from collections.abc import MutableMapping, Sequence from typing import Any, Protocol, TypeVar -from ._memory import AggregateContextProvider +from ._memory import ContextProvider from ._serialization import SerializationMixin from ._types import ChatMessage from .exceptions import AgentThreadException @@ -327,7 +327,7 @@ def __init__( *, service_thread_id: str | None = None, message_store: ChatMessageStoreProtocol | None = None, - context_provider: AggregateContextProvider | None = None, + context_provider: ContextProvider | None = None, ) -> None: """Initialize an AgentThread, do not use this method manually, always use: ``agent.get_new_thread()``. diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 33c533c5e5..2dbf3bbb6e 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -110,7 +110,7 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: name=agent.name, description=agent.description, chat_message_store_factory=agent.chat_message_store_factory, - context_providers=agent.context_provider, + context_provider=agent.context_provider, middleware=middleware, # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. allow_multiple_tool_calls=False, @@ -1041,7 +1041,7 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han Each participant must have a unique identifier (name for agents, id for executors). The workflow will automatically create an alias map so agents can be referenced by - their name, display_name, or executor id when routing. + their name, id, or executor id when routing. Args: participants: Sequence of AgentProtocol or Executor instances. Each must have @@ -1091,7 +1091,7 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han if isinstance(participant, Executor): identifier = participant.id elif isinstance(participant, AgentProtocol): - identifier = participant.display_name + identifier = participant.name or participant.id else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." @@ -1838,7 +1838,7 @@ def _resolve_executors(self) -> tuple[dict[str, Executor], dict[str, str]]: if isinstance(instance, Executor): identifier = instance.id elif isinstance(instance, AgentProtocol): - identifier = instance.display_name + identifier = instance.name or instance.id else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(instance).__name__}." @@ -2026,9 +2026,9 @@ def _prepare_agent_with_handoffs( tool_targets = self._apply_auto_tools(cloned_agent, target_agents) if tool_targets: middleware = _AutoHandoffMiddleware(tool_targets) - existing_middleware = list(cloned_agent.middleware or []) - existing_middleware.append(middleware) - cloned_agent.middleware = existing_middleware + existing_middlewares = list(cloned_agent.middleware or []) + existing_middlewares.append(middleware) + cloned_agent.middleware = existing_middlewares new_executor = AgentExecutor( cloned_agent, diff --git a/python/packages/core/agent_framework/_workflows/_participant_utils.py b/python/packages/core/agent_framework/_workflows/_participant_utils.py index a6f1cf2a84..2d0a259ac8 100644 --- a/python/packages/core/agent_framework/_workflows/_participant_utils.py +++ b/python/packages/core/agent_framework/_workflows/_participant_utils.py @@ -53,7 +53,7 @@ def wrap_participant(participant: AgentProtocol | Executor, *, executor_id: str f"Participants must implement AgentProtocol or be Executor instances. Got {type(participant).__name__}." ) - executor_id = executor_id or participant.display_name + executor_id = executor_id or participant.name or participant.id return AgentExecutor(participant, id=executor_id) @@ -89,11 +89,11 @@ def _register(values: Iterable[str | None]) -> None: if isinstance(participant, AgentProtocol): name = getattr(participant, "name", None) - display = getattr(participant, "display_name", None) - _register([name, display]) + agent_id = getattr(participant, "id", None) + _register([name, agent_id]) else: - display = getattr(participant, "display_name", None) - _register([display]) + participant_id = getattr(participant, "id", None) + _register([participant_id]) return aliases diff --git a/python/packages/core/agent_framework/_workflows/_sequential.py b/python/packages/core/agent_framework/_workflows/_sequential.py index 24ae4cda29..0c394574fa 100644 --- a/python/packages/core/agent_framework/_workflows/_sequential.py +++ b/python/packages/core/agent_framework/_workflows/_sequential.py @@ -290,7 +290,7 @@ def build(self) -> Workflow: for p in participants: if isinstance(p, (AgentProtocol, AgentExecutor)): - label = p.id if isinstance(p, AgentExecutor) else p.display_name + label = p.id if isinstance(p, AgentExecutor) else p.name if self._request_info_enabled: # Insert request info interceptor BEFORE the agent diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 26c261038b..2a6a20db14 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1346,7 +1346,7 @@ async def trace_run( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, agent_id=self.id, - agent_name=self.display_name, + agent_name=self.name or self.id, agent_description=self.description, thread_id=thread.service_thread_id if thread else None, chat_options=getattr(self, "chat_options", None), @@ -1418,7 +1418,7 @@ async def trace_run_streaming( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, agent_id=self.id, - agent_name=self.display_name, + agent_name=self.name or self.id, agent_description=self.description, thread_id=thread.service_thread_id if thread else None, chat_options=getattr(self, "chat_options", None), diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index ca524a4144..ab3aedb7d3 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -221,11 +221,6 @@ def name(self) -> str | None: """Returns the name of the agent.""" return "Name" - @property - def display_name(self) -> str: - """Returns the name of the agent.""" - return "Display Name" - @property def description(self) -> str | None: return "Description" diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 7611df0cb0..7d0c08e494 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -13,7 +13,6 @@ AgentRunResponse, AgentRunResponseUpdate, AgentThread, - AggregateContextProvider, ChatAgent, ChatClientProtocol, ChatMessage, @@ -66,7 +65,6 @@ async def test_chat_client_agent_init(chat_client: ChatClientProtocol) -> None: assert agent.id == agent_id assert agent.name is None assert agent.description == "Test" - assert agent.display_name == agent_id # Display name defaults to id if name is None async def test_chat_client_agent_init_with_name(chat_client: ChatClientProtocol) -> None: @@ -76,7 +74,6 @@ async def test_chat_client_agent_init_with_name(chat_client: ChatClientProtocol) assert agent.id == agent_id assert agent.name == "Test Agent" assert agent.description == "Test" - assert agent.display_name == "Test Agent" # Display name is the name if present async def test_chat_client_agent_run(chat_client: ChatClientProtocol) -> None: @@ -255,7 +252,7 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * async def test_chat_agent_context_providers_model_invoking(chat_client: ChatClientProtocol) -> None: """Test that context providers' invoking is called during agent run.""" mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Test context instructions")]) - agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) await agent.run("Hello") @@ -272,7 +269,7 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Cha ) ] - agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client_base, context_provider=mock_provider) await agent.run("Hello") @@ -283,7 +280,7 @@ async def test_chat_agent_context_providers_thread_created(chat_client_base: Cha async def test_chat_agent_context_providers_messages_adding(chat_client: ChatClientProtocol) -> None: """Test that context providers' invoked is called during agent run.""" mock_provider = MockContextProvider() - agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) await agent.run("Hello") @@ -295,7 +292,7 @@ async def test_chat_agent_context_providers_messages_adding(chat_client: ChatCli async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClientProtocol) -> None: """Test that AI context instructions are included in messages.""" mock_provider = MockContextProvider(messages=[ChatMessage(role="system", text="Context-specific instructions")]) - agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_provider=mock_provider) # We need to test the _prepare_thread_and_messages method directly _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] @@ -314,7 +311,7 @@ async def test_chat_agent_context_instructions_in_messages(chat_client: ChatClie async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtocol) -> None: """Test behavior when AI context has no instructions.""" mock_provider = MockContextProvider() - agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, instructions="Agent instructions", context_provider=mock_provider) _, _, messages = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")] @@ -329,7 +326,7 @@ async def test_chat_agent_no_context_instructions(chat_client: ChatClientProtoco async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientProtocol) -> None: """Test that context providers work with run_stream method.""" mock_provider = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Stream context instructions")]) - agent = ChatAgent(chat_client=chat_client, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client, context_provider=mock_provider) # Collect all stream updates updates: list[AgentRunResponseUpdate] = [] @@ -343,44 +340,6 @@ async def test_chat_agent_run_stream_context_providers(chat_client: ChatClientPr assert mock_provider.invoked_called -async def test_chat_agent_multiple_context_providers(chat_client: ChatClientProtocol) -> None: - """Test that multiple context providers work together.""" - provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First provider instructions")]) - provider2 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second provider instructions")]) - - agent = ChatAgent(chat_client=chat_client, context_providers=[provider1, provider2]) - - await agent.run("Hello") - - # Both providers should be called - assert provider1.invoking_called - assert not provider1.thread_created_called - assert provider1.invoked_called - - assert provider2.invoking_called - assert not provider2.thread_created_called - assert provider2.invoked_called - - -async def test_chat_agent_aggregate_context_provider_combines_instructions() -> None: - """Test that AggregateContextProvider combines instructions from multiple providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="First instruction")]) - provider2 = MockContextProvider(messages=[ChatMessage(role=Role.SYSTEM, text="Second instruction")]) - - aggregate = AggregateContextProvider() - aggregate.providers.append(provider1) - aggregate.providers.append(provider2) - - # Test invoking combines instructions - result = await aggregate.invoking([ChatMessage(role=Role.USER, text="Test")]) - - assert result.messages - assert isinstance(result.messages[0], ChatMessage) - assert isinstance(result.messages[1], ChatMessage) - assert result.messages[0].text == "First instruction" - assert result.messages[1].text == "Second instruction" - - async def test_chat_agent_context_providers_with_thread_service_id(chat_client_base: ChatClientProtocol) -> None: """Test context providers with service-managed thread.""" mock_provider = MockContextProvider() @@ -391,7 +350,7 @@ async def test_chat_agent_context_providers_with_thread_service_id(chat_client_b ) ] - agent = ChatAgent(chat_client=chat_client_base, context_providers=mock_provider) + agent = ChatAgent(chat_client=chat_client_base, context_provider=mock_provider) # Use existing service-managed thread thread = agent.get_new_thread(service_thread_id="existing-thread-id") diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index f3750f20e2..6cc7ba436e 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -2,10 +2,9 @@ from collections.abc import MutableSequence from typing import Any -from unittest.mock import AsyncMock, Mock -from agent_framework import ChatMessage, Role, TextContent -from agent_framework._memory import AggregateContextProvider, Context, ContextProvider +from agent_framework import ChatMessage, Role +from agent_framework._memory import Context, ContextProvider class MockContextProvider(ContextProvider): @@ -45,252 +44,50 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context -class TestAggregateContextProvider: - """Tests for AggregateContextProvider class.""" - - def test_init_with_no_providers(self) -> None: - """Test initialization with no providers.""" - aggregate = AggregateContextProvider() - assert aggregate.providers == [] - - def test_init_with_none_providers(self) -> None: - """Test initialization with None providers.""" - aggregate = AggregateContextProvider(None) - assert aggregate.providers == [] - - def test_init_with_providers(self) -> None: - """Test initialization with providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - providers = [provider1, provider2, provider3] - - aggregate = AggregateContextProvider(providers) - assert len(aggregate.providers) == 3 - assert aggregate.providers[0] is provider1 - assert aggregate.providers[1] is provider2 - assert aggregate.providers[2] is provider3 - - def test_add_provider(self) -> None: - """Test adding a provider.""" - aggregate = AggregateContextProvider() - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - - aggregate.add(provider) - assert len(aggregate.providers) == 1 - assert aggregate.providers[0] is provider - - def test_add_multiple_providers(self) -> None: - """Test adding multiple providers.""" - aggregate = AggregateContextProvider() - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - - aggregate.add(provider1) - aggregate.add(provider2) - - assert len(aggregate.providers) == 2 - assert aggregate.providers[0] is provider1 - assert aggregate.providers[1] is provider2 - - async def test_thread_created_with_no_providers(self) -> None: - """Test thread_created with no providers.""" - aggregate = AggregateContextProvider() - - # Should not raise an exception - await aggregate.thread_created("thread-123") - - async def test_thread_created_with_providers(self) -> None: - """Test thread_created calls all providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - aggregate = AggregateContextProvider([provider1, provider2]) - - thread_id = "thread-123" - await aggregate.thread_created(thread_id) - - assert provider1.thread_created_called - assert provider1.thread_created_thread_id == thread_id - assert provider2.thread_created_called - assert provider2.thread_created_thread_id == thread_id - - async def test_thread_created_with_none_thread_id(self) -> None: - """Test thread_created with None thread_id.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - aggregate = AggregateContextProvider([provider]) - - await aggregate.thread_created(None) +class TestContext: + """Tests for Context class.""" + def test_context_default_values(self) -> None: + """Test Context has correct default values.""" + context = Context() + assert context.instructions is None + assert context.messages == [] + assert context.tools == [] + + def test_context_with_values(self) -> None: + """Test Context can be initialized with values.""" + messages = [ChatMessage(role=Role.USER, text="Test message")] + context = Context(instructions="Test instructions", messages=messages) + assert context.instructions == "Test instructions" + assert len(context.messages) == 1 + assert context.messages[0].text == "Test message" + + +class TestContextProvider: + """Tests for ContextProvider class.""" + + async def test_thread_created(self) -> None: + """Test thread_created is called.""" + provider = MockContextProvider() + await provider.thread_created("test-thread-id") assert provider.thread_created_called - assert provider.thread_created_thread_id is None - - async def test_messages_adding_with_no_providers(self) -> None: - """Test invoked with no providers.""" - aggregate = AggregateContextProvider() - message = ChatMessage(text="Hello", role=Role.USER) - - # Should not raise an exception - await aggregate.invoked(message) - - async def test_messages_adding_with_single_message(self) -> None: - """Test invoked with a single message.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - aggregate = AggregateContextProvider([provider1, provider2]) - - message = ChatMessage(text="Hello", role=Role.USER) - await aggregate.invoked(message) - - assert provider1.invoked_called - assert provider1.new_messages == message - assert provider2.invoked_called - assert provider2.new_messages == message - - async def test_messages_adding_with_message_sequence(self) -> None: - """Test invoked with a sequence of messages.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions")]) - aggregate = AggregateContextProvider([provider]) - - messages = [ - ChatMessage(text="Hello", role=Role.USER), - ChatMessage(text="Hi there", role=Role.ASSISTANT), - ] - await aggregate.invoked(messages) + assert provider.thread_created_thread_id == "test-thread-id" + async def test_invoked(self) -> None: + """Test invoked is called.""" + provider = MockContextProvider() + message = ChatMessage(role=Role.USER, text="Test message") + await provider.invoked(message) assert provider.invoked_called - assert provider.new_messages == messages - - async def test_model_invoking_with_no_providers(self) -> None: - """Test invoking with no providers.""" - aggregate = AggregateContextProvider() - message = ChatMessage(text="Hello", role=Role.USER) - - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert not context.messages - - async def test_model_invoking_with_single_provider(self) -> None: - """Test invoking with a single provider.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")]) - aggregate = AggregateContextProvider([provider]) - - message = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(message) + assert provider.new_messages == message + async def test_invoking(self) -> None: + """Test invoking is called and returns context.""" + provider = MockContextProvider(messages=[ChatMessage(role=Role.USER, text="Context message")]) + message = ChatMessage(role=Role.USER, text="Test message") + context = await provider.invoking(message) assert provider.invoking_called assert provider.model_invoking_messages == message - assert isinstance(context, Context) - - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert context.messages[0].text == "Test instructions" - - async def test_model_invoking_with_multiple_providers(self) -> None: - """Test invoking combines contexts from multiple providers.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 2")]) - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - aggregate = AggregateContextProvider([provider1, provider2, provider3]) - - messages = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(messages) - - assert provider1.invoking_called - assert provider1.model_invoking_messages == messages - assert provider2.invoking_called - assert provider2.model_invoking_messages == messages - assert provider3.invoking_called - assert provider3.model_invoking_messages == messages - - assert isinstance(context, Context) - - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert isinstance(context.messages[1].contents[0], TextContent) - assert isinstance(context.messages[2].contents[0], TextContent) - assert context.messages[0].text == "Instructions 1" - assert context.messages[1].text == "Instructions 2" - assert context.messages[2].text == "Instructions 3" - - async def test_model_invoking_with_none_instructions(self) -> None: - """Test invoking filters out None instructions.""" - provider1 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 1")]) - provider2 = MockContextProvider(messages=None) # None instructions - provider3 = MockContextProvider(messages=[ChatMessage(role="user", text="Instructions 3")]) - aggregate = AggregateContextProvider([provider1, provider2, provider3]) - - message = ChatMessage(text="Hello", role=Role.USER) - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert isinstance(context.messages[1].contents[0], TextContent) - assert context.messages[0].text == "Instructions 1" - assert context.messages[1].text == "Instructions 3" - - async def test_model_invoking_with_all_none_instructions(self) -> None: - """Test invoking when all providers return None instructions.""" - provider1 = MockContextProvider(None) - provider2 = MockContextProvider(None) - aggregate = AggregateContextProvider([provider1, provider2]) - - message = ChatMessage(text="Hello", role=Role.USER) - context = await aggregate.invoking(message) - - assert isinstance(context, Context) - assert not context.messages - - async def test_model_invoking_with_mutable_sequence(self) -> None: - """Test invoking with MutableSequence of messages.""" - provider = MockContextProvider(messages=[ChatMessage(role="user", text="Test instructions")]) - aggregate = AggregateContextProvider([provider]) - - messages = [ChatMessage(text="Hello", role=Role.USER)] - context = await aggregate.invoking(messages) - - assert provider.invoking_called - assert provider.model_invoking_messages == messages - assert isinstance(context, Context) - assert context.messages - assert isinstance(context.messages[0].contents[0], TextContent) - assert context.messages[0].text == "Test instructions" - - async def test_async_methods_concurrent_execution(self) -> None: - """Test that async methods execute providers concurrently.""" - # Use AsyncMock to verify concurrent execution - provider1 = Mock(spec=ContextProvider) - provider1.thread_created = AsyncMock() - provider1.invoked = AsyncMock() - provider1.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 1")])) - - provider2 = Mock(spec=ContextProvider) - provider2.thread_created = AsyncMock() - provider2.invoked = AsyncMock() - provider2.invoking = AsyncMock(return_value=Context(messages=[ChatMessage(role="user", text="Test 2")])) - - aggregate = AggregateContextProvider([provider1, provider2]) - - # Test thread_created - await aggregate.thread_created("thread-123") - provider1.thread_created.assert_called_once_with("thread-123") - provider2.thread_created.assert_called_once_with("thread-123") - - # Test invoked - message = ChatMessage(text="Hello", role=Role.USER) - await aggregate.invoked(message) - provider1.invoked.assert_called_once_with( - request_messages=message, response_messages=None, invoke_exception=None - ) - provider2.invoked.assert_called_once_with( - request_messages=message, response_messages=None, invoke_exception=None - ) - - # Test invoking - context = await aggregate.invoking(message) - provider1.invoking.assert_called_once_with(message) - provider2.invoking.assert_called_once_with(message) - assert context.messages - assert context.messages[0].text == "Test 1" - assert context.messages[1].text == "Test 2" + assert context.messages is not None + assert len(context.messages) == 1 + assert context.messages[0].text == "Context message" diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index a84c8927d0..552d3a3c48 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -148,7 +148,7 @@ async def process(self, context: AgentRunContext, next: Any) -> None: context.terminate = True def test_init_empty(self) -> None: - """Test AgentMiddlewarePipeline initialization with no middlewares.""" + """Test AgentMiddlewarePipeline initialization with no middleware.""" pipeline = AgentMiddlewarePipeline() assert not pipeline.has_middlewares @@ -457,7 +457,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert execution_order == ["handler"] def test_init_empty(self) -> None: - """Test FunctionMiddlewarePipeline initialization with no middlewares.""" + """Test FunctionMiddlewarePipeline initialization with no middleware.""" pipeline = FunctionMiddlewarePipeline() assert not pipeline.has_middlewares @@ -539,7 +539,7 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai context.terminate = True def test_init_empty(self) -> None: - """Test ChatMiddlewarePipeline initialization with no middlewares.""" + """Test ChatMiddlewarePipeline initialization with no middleware.""" pipeline = ChatMiddlewarePipeline() assert not pipeline.has_middlewares @@ -979,7 +979,7 @@ class TestMultipleMiddlewareOrdering: """Test cases for multiple middleware execution order.""" async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) -> None: - """Test that multiple agent middlewares execute in registration order.""" + """Test that multiple agent middleware execute in registration order.""" execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): @@ -1006,8 +1006,8 @@ async def process( await next(context) execution_order.append("third_after") - middlewares = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] - pipeline = AgentMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] + pipeline = AgentMiddlewarePipeline(middleware) # type: ignore messages = [ChatMessage(role=Role.USER, text="test")] context = AgentRunContext(agent=mock_agent, messages=messages) @@ -1030,7 +1030,7 @@ async def final_handler(ctx: AgentRunContext) -> AgentRunResponse: assert execution_order == expected_order async def test_function_middleware_execution_order(self, mock_function: AIFunction[Any, Any]) -> None: - """Test that multiple function middlewares execute in registration order.""" + """Test that multiple function middleware execute in registration order.""" execution_order: list[str] = [] class FirstMiddleware(FunctionMiddleware): @@ -1053,8 +1053,8 @@ async def process( await next(context) execution_order.append("second_after") - middlewares = [FirstMiddleware(), SecondMiddleware()] - pipeline = FunctionMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstMiddleware(), SecondMiddleware()] + pipeline = FunctionMiddlewarePipeline(middleware) # type: ignore arguments = FunctionTestArgs(name="test") context = FunctionInvocationContext(function=mock_function, arguments=arguments) @@ -1069,7 +1069,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert execution_order == expected_order async def test_chat_middleware_execution_order(self, mock_chat_client: Any) -> None: - """Test that multiple chat middlewares execute in registration order.""" + """Test that multiple chat middleware execute in registration order.""" execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): @@ -1090,8 +1090,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai await next(context) execution_order.append("third_after") - middlewares = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] - pipeline = ChatMiddlewarePipeline(middlewares) # type: ignore + middleware = [FirstChatMiddleware(), SecondChatMiddleware(), ThirdChatMiddleware()] + pipeline = ChatMiddlewarePipeline(middleware) # type: ignore messages = [ChatMessage(role=Role.USER, text="test")] chat_options = ChatOptions() context = ChatContext(chat_client=mock_chat_client, messages=messages, chat_options=chat_options) @@ -1542,7 +1542,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: assert context.result is None async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) -> None: - """Test that when first middleware doesn't call next(), subsequent middlewares are not called.""" + """Test that when first middleware doesn't call next(), subsequent middleware are not called.""" execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): @@ -1641,7 +1641,7 @@ async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: assert context.result is None async def test_multiple_chat_middlewares_early_stop(self, mock_chat_client: Any) -> None: - """Test that when first chat middleware doesn't call next(), subsequent middlewares are not called.""" + """Test that when first chat middleware doesn't call next(), subsequent middleware are not called.""" execution_order: list[str] = [] class FirstChatMiddleware(ChatMiddleware): diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 6cb41f674b..147c474800 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -418,7 +418,7 @@ class TestChatAgentMultipleMiddlewareOrdering: """Test cases for multiple middleware execution order with ChatAgent.""" async def test_multiple_agent_middleware_execution_order(self, chat_client: "MockChatClient") -> None: - """Test that multiple agent middlewares execute in correct order with ChatAgent.""" + """Test that multiple agent middleware execute in correct order with ChatAgent.""" execution_order: list[str] = [] class OrderedMiddleware(AgentMiddleware): @@ -432,12 +432,12 @@ async def process( await next(context) execution_order.append(f"{self.name}_after") - # Create multiple middlewares + # Create multiple middleware middleware1 = OrderedMiddleware("first") middleware2 = OrderedMiddleware("second") middleware3 = OrderedMiddleware("third") - # Create ChatAgent with multiple middlewares + # Create ChatAgent with multiple middleware agent = ChatAgent(chat_client=chat_client, middleware=[middleware1, middleware2, middleware3]) # Execute the agent @@ -453,7 +453,7 @@ async def process( assert execution_order == expected_order async def test_mixed_middleware_types_with_chat_agent(self, chat_client: "MockChatClient") -> None: - """Test mixed class and function-based middlewares with ChatAgent.""" + """Test mixed class and function-based middleware with ChatAgent.""" execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): @@ -507,8 +507,8 @@ async def function_function_middleware( assert response is not None assert chat_client.call_count == 1 - # Verify that agent middlewares were executed in correct order - # (Function middlewares won't execute since no functions are called) + # Verify that agent middleware were executed in correct order + # (Function middleware won't execute since no functions are called) expected_order = ["class_agent_before", "function_agent_before", "function_agent_after", "class_agent_after"] assert execution_order == expected_order @@ -999,7 +999,7 @@ async def test_run_level_middleware_isolation(self, chat_client: "MockChatClient # Clear execution log execution_log.clear() - # Fourth run with both run middlewares - should see both + # Fourth run with both run middleware - should see both await agent.run("Test message 4", middleware=[run_middleware1, run_middleware2]) assert execution_log == ["run1_start", "run2_start", "run2_end", "run1_end"] diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 8528295406..38835fbc91 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -342,7 +342,6 @@ class MockChatClientAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" self.description = "Test agent description" async def run(self, messages=None, *, thread=None, **kwargs): @@ -384,7 +383,6 @@ class MockAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" async def run(self, messages=None, *, thread=None, **kwargs): return Mock() @@ -406,7 +404,6 @@ class MockChatClientAgent: def __init__(self): self.id = "test_agent_id" self.name = "test_agent" - self.display_name = "Test Agent" self.description = "Test agent description" self.chat_options = ChatOptions(model_id="TestModel") @@ -441,10 +438,10 @@ async def test_agent_instrumentation_enabled( spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.name == "invoke_agent Test Agent" + assert span.name == "invoke_agent test_agent" assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent" + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel" assert span.attributes[OtelAttr.INPUT_TOKENS] == 15 @@ -469,10 +466,10 @@ async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator( spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.name == "invoke_agent Test Agent" + assert span.name == "invoke_agent test_agent" assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent" + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel" if enable_sensitive_data: diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 2815c3152c..1485c9f523 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -38,7 +38,7 @@ async def run( # type: ignore[override] ) -> AgentRunResponse: self.call_count += 1 return AgentRunResponse( - messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.display_name}")] + messages=[ChatMessage(role=Role.ASSISTANT, text=f"Response #{self.call_count}: {self.name}")] ) async def run_stream( # type: ignore[override] @@ -49,7 +49,7 @@ async def run_stream( # type: ignore[override] **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: self.call_count += 1 - yield AgentRunResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.display_name}")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.name}")]) async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index d0d5092323..1daeef3943 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -78,7 +78,7 @@ def __init__( text_handoff: bool = False, extra_properties: dict[str, object] | None = None, ) -> None: - super().__init__(id=name, name=name, display_name=name) + super().__init__(id=name, name=name) self._agent_name = name self.handoff_to = handoff_to self.calls: list[list[ChatMessage]] = [] @@ -102,7 +102,7 @@ async def run( # type: ignore[override] reply = ChatMessage( role=Role.ASSISTANT, contents=contents, - author_name=self.display_name, + author_name=self.name, additional_properties=additional_properties, ) return AgentRunResponse(messages=[reply]) diff --git a/python/packages/core/tests/workflow/test_sequential.py b/python/packages/core/tests/workflow/test_sequential.py index 8ff0098c38..15d6e0c822 100644 --- a/python/packages/core/tests/workflow/test_sequential.py +++ b/python/packages/core/tests/workflow/test_sequential.py @@ -35,7 +35,7 @@ async def run( # type: ignore[override] thread: AgentThread | None = None, **kwargs: Any, ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} reply")]) async def run_stream( # type: ignore[override] self, @@ -45,7 +45,7 @@ async def run_stream( # type: ignore[override] **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: # Minimal async generator with one assistant update - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.name} reply")]) class _SummarizerExec(Executor): diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 5b7637057b..d0140dc893 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -60,7 +60,7 @@ async def run( **kwargs: Any, ) -> AgentRunResponse: self.captured_kwargs.append(dict(kwargs)) - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} response")]) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} response")]) async def run_stream( self, @@ -70,7 +70,7 @@ async def run_stream( **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: self.captured_kwargs.append(dict(kwargs)) - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} response")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.name} response")]) class _EchoAgent(BaseAgent): @@ -83,7 +83,7 @@ async def run( thread: AgentThread | None = None, **kwargs: Any, ) -> AgentRunResponse: - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.display_name} reply")]) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} reply")]) async def run_stream( self, @@ -92,7 +92,7 @@ async def run_stream( thread: AgentThread | None = None, **kwargs: Any, ) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.display_name} reply")]) + yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.name} reply")]) # region Sequential Builder Tests diff --git a/python/packages/devui/agent_framework_devui/_discovery.py b/python/packages/devui/agent_framework_devui/_discovery.py index 9f8fcf0542..46169481f2 100644 --- a/python/packages/devui/agent_framework_devui/_discovery.py +++ b/python/packages/devui/agent_framework_devui/_discovery.py @@ -346,8 +346,8 @@ async def create_entity_info_from_object( instructions = None model = None chat_client_type = None - context_providers_list = None - middleware_list = None + context_provider_list = None + middlewares_list = None if entity_type == "agent": from ._utils import extract_agent_metadata @@ -356,8 +356,8 @@ async def create_entity_info_from_object( instructions = agent_meta["instructions"] model = agent_meta["model"] chat_client_type = agent_meta["chat_client_type"] - context_providers_list = agent_meta["context_providers"] - middleware_list = agent_meta["middleware"] + context_provider_list = agent_meta["context_provider"] + middlewares_list = agent_meta["middleware"] # Log helpful info about agent capabilities (before creating EntityInfo) if entity_type == "agent": @@ -395,8 +395,8 @@ async def create_entity_info_from_object( instructions=instructions, model_id=model, chat_client_type=chat_client_type, - context_providers=context_providers_list, - middleware=middleware_list, + context_provider=context_provider_list, + middleware=middlewares_list, executors=tools_list if entity_type == "workflow" else [], input_schema={"type": "string"}, # Default schema start_executor_id=tools_list[0] if tools_list and entity_type == "workflow" else None, @@ -829,8 +829,8 @@ async def _register_entity_from_object( instructions = None model = None chat_client_type = None - context_providers_list = None - middleware_list = None + context_provider_list = None + middlewares_list = None if obj_type == "agent": from ._utils import extract_agent_metadata @@ -839,8 +839,8 @@ async def _register_entity_from_object( instructions = agent_meta["instructions"] model = agent_meta["model"] chat_client_type = agent_meta["chat_client_type"] - context_providers_list = agent_meta["context_providers"] - middleware_list = agent_meta["middleware"] + context_provider_list = agent_meta["context_provider"] + middlewares_list = agent_meta["middleware"] entity_info = EntityInfo( id=entity_id, @@ -852,8 +852,8 @@ async def _register_entity_from_object( instructions=instructions, model_id=model, chat_client_type=chat_client_type, - context_providers=context_providers_list, - middleware=middleware_list, + context_provider=context_provider_list, + middleware=middlewares_list, metadata={ "module_path": module_path, "entity_type": obj_type, diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 3c17c072f7..5b0e218f7a 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -32,7 +32,7 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]: "instructions": None, "model": None, "chat_client_type": None, - "context_providers": None, + "context_provider": None, "middleware": None, } @@ -60,20 +60,20 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]: and entity_object.context_provider and hasattr(entity_object.context_provider, "__class__") ): - metadata["context_providers"] = [entity_object.context_provider.__class__.__name__] # type: ignore + metadata["context_provider"] = [entity_object.context_provider.__class__.__name__] # type: ignore # Try to get middleware if hasattr(entity_object, "middleware") and entity_object.middleware: - middleware_list: list[str] = [] + middlewares_list: list[str] = [] for m in entity_object.middleware: # Try multiple ways to get a good name for middleware if hasattr(m, "__name__"): # Function or callable - middleware_list.append(m.__name__) + middlewares_list.append(m.__name__) elif hasattr(m, "__class__"): # Class instance - middleware_list.append(m.__class__.__name__) + middlewares_list.append(m.__class__.__name__) else: - middleware_list.append(str(m)) - metadata["middleware"] = middleware_list # type: ignore + middlewares_list.append(str(m)) + metadata["middleware"] = middlewares_list # type: ignore return metadata diff --git a/python/packages/devui/agent_framework_devui/models/_discovery_models.py b/python/packages/devui/agent_framework_devui/models/_discovery_models.py index 382639b277..ff217a48d2 100644 --- a/python/packages/devui/agent_framework_devui/models/_discovery_models.py +++ b/python/packages/devui/agent_framework_devui/models/_discovery_models.py @@ -43,7 +43,7 @@ class EntityInfo(BaseModel): instructions: str | None = None model_id: str | None = None chat_client_type: str | None = None - context_providers: list[str] | None = None + context_provider: list[str] | None = None middleware: list[str] | None = None # Workflow-specific fields (populated only for detailed info requests) diff --git a/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx b/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx index 990231552b..f9fa4480a0 100644 --- a/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx +++ b/python/packages/devui/frontend/src/components/features/agent/agent-details-modal.tsx @@ -179,10 +179,10 @@ export function AgentDetailsModal({ )} - {/* Middleware */} + {/* Middlewares */} {agent.middleware && agent.middleware.length > 0 && ( } >