diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 034ed5e9ff..3a6d5b818c 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -4,13 +4,14 @@ import re import sys from abc import abstractmethod -from collections.abc import Collection, Sequence +from collections.abc import Callable, Collection, Sequence from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore from datetime import timedelta from functools import partial from typing import TYPE_CHECKING, Any, Literal import httpx +from anyio import ClosedResourceError from mcp import types from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client @@ -21,7 +22,11 @@ from mcp.shared.session import RequestResponder from pydantic import BaseModel, create_model -from ._tools import AIFunction, HostedMCPSpecificApproval, _build_pydantic_model_from_json_schema +from ._tools import ( + AIFunction, + HostedMCPSpecificApproval, + _build_pydantic_model_from_json_schema, +) from ._types import ( ChatMessage, Contents, @@ -329,7 +334,9 @@ def __init__( approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None, allowed_tools: Collection[str] | None = None, load_tools: bool = True, + parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True, load_prompts: bool = True, + parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True, session: ClientSession | None = None, request_timeout: int | None = None, chat_client: "ChatClientProtocol | None" = None, @@ -347,7 +354,9 @@ def __init__( self.allowed_tools = allowed_tools self.additional_properties = additional_properties self.load_tools_flag = load_tools + self.parse_tool_results = parse_tool_results self.load_prompts_flag = load_prompts + self.parse_prompt_results = parse_prompt_results self._exit_stack = AsyncExitStack() self.session = session self.request_timeout = request_timeout @@ -367,15 +376,23 @@ def functions(self) -> list[AIFunction[Any, Any]]: return self._functions return [func for func in self._functions if func.name in self.allowed_tools] - async def connect(self) -> None: + async def connect(self, *, reset: bool = False) -> None: """Connect to the MCP server. Establishes a connection to the MCP server, initializes the session, and loads tools and prompts if configured to do so. + Keyword Args: + reset: If True, forces a reconnection even if already connected. + Raises: ToolException: If connection or session initialization fails. """ + if reset: + await self._exit_stack.aclose() + self.session = None + self.is_connected = False + self._exit_stack = AsyncExitStack() if not self.session: try: transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) @@ -565,86 +582,88 @@ async def load_prompts(self) -> None: """Load prompts from the MCP server. Retrieves available prompts from the connected MCP server and converts - them into AIFunction instances. + them into AIFunction instances. Handles pagination automatically. Raises: ToolExecutionException: If the MCP server is not connected. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") - try: - prompt_list = await self.session.list_prompts() - except Exception as exc: - logger.info( - "Prompt could not be loaded, you can exclude trying to load, by setting: load_prompts=False", - exc_info=exc, - ) - prompt_list = None - # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} - for prompt in prompt_list.prompts if prompt_list else []: - local_name = _normalize_mcp_name(prompt.name) - - # Skip if already loaded - if local_name in existing_names: - continue - - input_model = _get_input_model_from_mcp_prompt(prompt) - approval_mode = self._determine_approval_mode(local_name) - func: AIFunction[BaseModel, list[ChatMessage]] = AIFunction( - func=partial(self.get_prompt, prompt.name), - name=local_name, - description=prompt.description or "", - approval_mode=approval_mode, - input_model=input_model, - ) - self._functions.append(func) - existing_names.add(local_name) + params: types.PaginatedRequestParams | None = None + while True: + # Ensure connection is still valid before each page request + await self._ensure_connected() + + prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] + + for prompt in prompt_list.prompts: + local_name = _normalize_mcp_name(prompt.name) + + # Skip if already loaded + if local_name in existing_names: + continue + + input_model = _get_input_model_from_mcp_prompt(prompt) + approval_mode = self._determine_approval_mode(local_name) + func: AIFunction[BaseModel, list[ChatMessage] | Any | types.GetPromptResult] = AIFunction( + func=partial(self.get_prompt, prompt.name), + name=local_name, + description=prompt.description or "", + approval_mode=approval_mode, + input_model=input_model, + ) + self._functions.append(func) + existing_names.add(local_name) + + # Check if there are more pages + if not prompt_list or not prompt_list.nextCursor: + break + params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor) async def load_tools(self) -> None: """Load tools from the MCP server. Retrieves available tools from the connected MCP server and converts - them into AIFunction instances. + them into AIFunction instances. Handles pagination automatically. Raises: ToolExecutionException: If the MCP server is not connected. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") - try: - tool_list = await self.session.list_tools() - except Exception as exc: - logger.info( - "Tools could not be loaded, you can exclude trying to load, by setting: load_tools=False", - exc_info=exc, - ) - tool_list = None - # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} - for tool in tool_list.tools if tool_list else []: - local_name = _normalize_mcp_name(tool.name) - - # Skip if already loaded - if local_name in existing_names: - continue - - input_model = _get_input_model_from_mcp_tool(tool) - approval_mode = self._determine_approval_mode(local_name) - # Create AIFunctions out of each tool - func: AIFunction[BaseModel, list[Contents]] = AIFunction( - func=partial(self.call_tool, tool.name), - name=local_name, - description=tool.description or "", - approval_mode=approval_mode, - input_model=input_model, - ) - self._functions.append(func) - existing_names.add(local_name) + params: types.PaginatedRequestParams | None = None + while True: + # Ensure connection is still valid before each page request + await self._ensure_connected() + + tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] + + for tool in tool_list.tools: + local_name = _normalize_mcp_name(tool.name) + + # Skip if already loaded + if local_name in existing_names: + continue + + input_model = _get_input_model_from_mcp_tool(tool) + approval_mode = self._determine_approval_mode(local_name) + # Create AIFunctions out of each tool + func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction( + func=partial(self.call_tool, tool.name), + name=local_name, + description=tool.description or "", + approval_mode=approval_mode, + input_model=input_model, + ) + self._functions.append(func) + existing_names.add(local_name) + + # Check if there are more pages + if not tool_list or not tool_list.nextCursor: + break + params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) async def close(self) -> None: """Disconnect from the MCP server. @@ -664,7 +683,28 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: """ pass - async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: + async def _ensure_connected(self) -> None: + """Ensure the connection is valid, reconnecting if necessary. + + This method proactively checks if the connection is valid and + reconnects if it's not, avoiding the need to catch ClosedResourceError. + + Raises: + ToolExecutionException: If reconnection fails. + """ + try: + await self.session.send_ping() # type: ignore[union-attr] + except Exception: + logger.info("MCP connection invalid or closed. Reconnecting...") + try: + await self.connect(reset=True) + except Exception as ex: + raise ToolExecutionException( + "Failed to establish MCP connection.", + inner_exception=ex, + ) from ex + + async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult: """Call a tool with the given arguments. Args: @@ -680,8 +720,6 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: ToolExecutionException: If the MCP server is not connected, tools are not loaded, or the tool call fails. 
""" - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") if not self.load_tools_flag: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." @@ -692,16 +730,44 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents]: filtered_kwargs = { k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"} } - try: - return _parse_contents_from_mcp_tool_result( - await self.session.call_tool(tool_name, arguments=filtered_kwargs) - ) - except McpError as mcp_exc: - raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc - except Exception as ex: - raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex - async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage]: + # Try the operation, reconnecting once if the connection is closed + for attempt in range(2): + try: + result = await self.session.call_tool(tool_name, arguments=filtered_kwargs) # type: ignore + if self.parse_tool_results is None: + return result + if self.parse_tool_results is True: + return _parse_contents_from_mcp_tool_result(result) + if callable(self.parse_tool_results): + return self.parse_tool_results(result) + return result + except ClosedResourceError as cl_ex: + if attempt == 0: + # First attempt failed, try reconnecting + logger.info("MCP connection closed unexpectedly. Reconnecting...") + try: + await self.connect(reset=True) + continue # Retry the operation + except Exception as reconn_ex: + raise ToolExecutionException( + "Failed to reconnect to MCP server.", + inner_exception=reconn_ex, + ) from reconn_ex + else: + # Second attempt also failed, give up + logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}") + raise ToolExecutionException( + f"Failed to call tool '{tool_name}' - connection lost.", + inner_exception=cl_ex, + ) from cl_ex + except McpError as mcp_exc: + raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc + except Exception as ex: + raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex + raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.") + + async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] | Any | types.GetPromptResult: """Call a prompt with the given arguments. Args: @@ -717,19 +783,46 @@ async def get_prompt(self, prompt_name: str, **kwargs: Any) -> list[ChatMessage] ToolExecutionException: If the MCP server is not connected, prompts are not loaded, or the prompt call fails. """ - if not self.session: - raise ToolExecutionException("MCP server not connected, please call connect() before using this method.") if not self.load_prompts_flag: raise ToolExecutionException( "Prompts are not loaded for this server, please set load_prompts=True in the constructor." 
) - try: - prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) - return [_parse_message_from_mcp(message) for message in prompt_result.messages] - except McpError as mcp_exc: - raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc - except Exception as ex: - raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex + + # Try the operation, reconnecting once if the connection is closed + for attempt in range(2): + try: + prompt_result = await self.session.get_prompt(prompt_name, arguments=kwargs) # type: ignore + if self.parse_prompt_results is None: + return prompt_result + if self.parse_prompt_results is True: + return [_parse_message_from_mcp(message) for message in prompt_result.messages] + if callable(self.parse_prompt_results): + return self.parse_prompt_results(prompt_result) + return prompt_result + except ClosedResourceError as cl_ex: + if attempt == 0: + # First attempt failed, try reconnecting + logger.info("MCP connection closed unexpectedly. Reconnecting...") + try: + await self.connect(reset=True) + continue # Retry the operation + except Exception as reconn_ex: + raise ToolExecutionException( + "Failed to reconnect to MCP server.", + inner_exception=reconn_ex, + ) from reconn_ex + else: + # Second attempt also failed, give up + logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}") + raise ToolExecutionException( + f"Failed to call prompt '{prompt_name}' - connection lost.", + inner_exception=cl_ex, + ) from cl_ex + except McpError as mcp_exc: + raise ToolExecutionException(mcp_exc.error.message, inner_exception=mcp_exc) from mcp_exc + except Exception as ex: + raise ToolExecutionException(f"Failed to call prompt '{prompt_name}'.", inner_exception=ex) from ex + raise ToolExecutionException(f"Failed to get prompt '{prompt_name}' after retries.") async def __aenter__(self) -> Self: """Enter the async context manager. @@ -804,7 +897,9 @@ def __init__( command: str, *, load_tools: bool = True, + parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True, load_prompts: bool = True, + parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True, request_timeout: int | None = None, session: ClientSession | None = None, description: str | None = None, @@ -830,7 +925,15 @@ def __init__( Keyword Args: load_tools: Whether to load tools from the MCP server. + parse_tool_results: How to parse tool results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP tool result. load_prompts: Whether to load prompts from the MCP server. + parse_prompt_results: How to parse prompt results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP prompt result. request_timeout: The default timeout in seconds for all requests. session: The session to use for the MCP connection. description: The description of the tool. 
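For context, the `parse_tool_results` / `parse_prompt_results` options documented above accept `True`, a callable, or `None`. A minimal usage sketch (the server command, tool name, and custom parser are placeholders, not part of this PR):

```python
from mcp import types

from agent_framework._mcp import MCPStdioTool


def text_only(result: types.CallToolResult) -> list[str]:
    """Custom parser: keep only the text blocks of a tool result."""
    return [block.text for block in result.content if isinstance(block, types.TextContent)]


# Default (True): call_tool() returns Agent Framework content types.
default_tool = MCPStdioTool(name="docs", command="uvx my-mcp-server")

# Callable: call_tool() returns whatever text_only() returns.
custom_tool = MCPStdioTool(name="docs", command="uvx my-mcp-server", parse_tool_results=text_only)

# None: call_tool() returns the raw types.CallToolResult.
raw_tool = MCPStdioTool(name="docs", command="uvx my-mcp-server", parse_tool_results=None)
```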
@@ -857,7 +960,9 @@ def __init__( session=session, chat_client=chat_client, load_tools=load_tools, + parse_tool_results=parse_tool_results, load_prompts=load_prompts, + parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, ) self.command = command @@ -913,7 +1018,9 @@ def __init__( url: str, *, load_tools: bool = True, + parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True, load_prompts: bool = True, + parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True, request_timeout: int | None = None, session: ClientSession | None = None, description: str | None = None, @@ -939,7 +1046,15 @@ def __init__( Keyword Args: load_tools: Whether to load tools from the MCP server. + parse_tool_results: How to parse tool results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP tool result. load_prompts: Whether to load prompts from the MCP server. + parse_prompt_results: How to parse prompt results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP prompt result. request_timeout: The default timeout in seconds for all requests. session: The session to use for the MCP connection. description: The description of the tool. @@ -968,7 +1083,9 @@ def __init__( session=session, chat_client=chat_client, load_tools=load_tools, + parse_tool_results=parse_tool_results, load_prompts=load_prompts, + parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, ) self.url = url @@ -1016,7 +1133,9 @@ def __init__( url: str, *, load_tools: bool = True, + parse_tool_results: Literal[True] | Callable[[types.CallToolResult], Any] | None = True, load_prompts: bool = True, + parse_prompt_results: Literal[True] | Callable[[types.GetPromptResult], Any] | None = True, request_timeout: int | None = None, session: ClientSession | None = None, description: str | None = None, @@ -1040,7 +1159,15 @@ def __init__( Keyword Args: load_tools: Whether to load tools from the MCP server. + parse_tool_results: How to parse tool results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP tool result. load_prompts: Whether to load prompts from the MCP server. + parse_prompt_results: How to parse prompt results from the MCP server. + Set to True, to use the default parser that converts to Agent Framework types. + Set to a callable to use a custom parser function. + Set to None to return the raw MCP prompt result. request_timeout: The default timeout in seconds for all requests. session: The session to use for the MCP connection. description: The description of the tool. 
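The same options apply to prompts on every transport. A hedged end-to-end sketch for the streamable HTTP client (the URL and prompt name are placeholders):

```python
import asyncio

from agent_framework._mcp import MCPStreamableHTTPTool


async def main() -> None:
    async with MCPStreamableHTTPTool(
        name="docs",
        url="http://localhost:8000/mcp",  # placeholder endpoint
        parse_prompt_results=None,  # get_prompt() returns the raw types.GetPromptResult
    ) as tool:
        raw = await tool.get_prompt("greeting", name="Ada")  # hypothetical prompt
        print(raw.messages)


asyncio.run(main())
```

If the underlying transport drops between calls, the retry logic added in this diff reconnects once via `connect(reset=True)` before surfacing a `ToolExecutionException`.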
@@ -1064,7 +1191,9 @@ def __init__( session=session, chat_client=chat_client, load_tools=load_tools, + parse_tool_results=parse_tool_results, load_prompts=load_prompts, + parse_prompt_results=parse_prompt_results, request_timeout=request_timeout, ) self.url = url diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index ba8fa50c0f..7a0b1672a7 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -4,7 +4,15 @@ import inspect import json import sys -from collections.abc import AsyncIterable, Awaitable, Callable, Collection, Mapping, MutableMapping, Sequence +from collections.abc import ( + AsyncIterable, + Awaitable, + Callable, + Collection, + Mapping, + MutableMapping, + Sequence, +) from functools import wraps from time import perf_counter, time_ns from typing import ( @@ -18,6 +26,7 @@ Protocol, TypedDict, TypeVar, + Union, cast, get_args, get_origin, @@ -121,7 +130,13 @@ def _parse_inputs( if inputs is None: return [] - from ._types import BaseContent, DataContent, HostedFileContent, HostedVectorStoreContent, UriContent + from ._types import ( + BaseContent, + DataContent, + HostedFileContent, + HostedVectorStoreContent, + UriContent, + ) parsed_inputs: list["Contents"] = [] if not isinstance(inputs, list): @@ -1010,6 +1025,27 @@ def _build_pydantic_model_from_json_schema( if not properties: return create_model(f"{model_name}_input") + def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None: + """Check if property should be a Literal type (const or enum). + + Args: + prop_details: The JSON Schema property details + + Returns: + Literal type if const or enum is present, None otherwise + """ + # const → Literal["value"] + if "const" in prop_details: + return Literal[prop_details["const"]] # type: ignore + + # enum → Literal["a", "b", ...] + if "enum" in prop_details and isinstance(prop_details["enum"], list): + enum_values = prop_details["enum"] + if enum_values: + return Literal[tuple(enum_values)] # type: ignore + + return None + def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: """Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays. 
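The `_resolve_literal_type` helper above maps JSON Schema `const` and `enum` keywords onto `Literal` annotations. A hand-written equivalent of what the generated input model then enforces (illustrative only, not the generated code):

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class StatusInput(BaseModel):
    action: Literal["create"]                # from a {"const": "create"} property
    level: Literal["low", "medium", "high"]  # from an {"enum": [...]} property


StatusInput(action="create", level="low")      # validates

try:
    StatusInput(action="delete", level="low")  # rejected: does not match the const
except ValidationError:
    print("action must be 'create'")
```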
@@ -1020,6 +1056,31 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: Returns: Python type annotation (could be int, str, list[str], or a nested Pydantic model) """ + # Handle oneOf + discriminator (polymorphic objects) + if "oneOf" in prop_details and "discriminator" in prop_details: + discriminator = prop_details["discriminator"] + disc_field = discriminator.get("propertyName") + + variants = [] + for variant in prop_details["oneOf"]: + if "$ref" in variant: + ref = variant["$ref"] + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + resolved = definitions.get(def_name) + if resolved: + variant_model = _resolve_type( + resolved, + parent_name=f"{parent_name}_{def_name}", + ) + variants.append(variant_model) + + if variants and disc_field: + return Annotated[ + Union[tuple(variants)], # type: ignore + Field(discriminator=disc_field), + ] + # Handle $ref by resolving the reference if "$ref" in prop_details: ref = prop_details["$ref"] @@ -1070,9 +1131,15 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: else nested_prop_details ) - nested_python_type = _resolve_type( - nested_prop_details, f"{nested_model_name}_{nested_prop_name}" - ) + # Check for Literal types first (const/enum) + literal_type = _resolve_literal_type(nested_prop_details) + if literal_type is not None: + nested_python_type = literal_type + else: + nested_python_type = _resolve_type( + nested_prop_details, + f"{nested_model_name}_{nested_prop_name}", + ) nested_description = nested_prop_details.get("description", "") # Build field kwargs for nested property @@ -1109,7 +1176,12 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: for prop_name, prop_details in properties.items(): prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details - python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") + # Check for Literal types first (const/enum) + literal_type = _resolve_literal_type(prop_details) + if literal_type is not None: + python_type = literal_type + else: + python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") description = prop_details.get("description", "") # Build field kwargs (description, etc.) diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 89c3c520fe..c4e8cb09df 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1248,6 +1248,75 @@ async def test_streamable_http_integration(): assert result[0].text is not None +@pytest.mark.flaky +@skip_if_mcp_integration_tests_disabled +async def test_mcp_connection_reset_integration(): + """Test that connection reset works correctly with a real MCP server. + + This integration test verifies: + 1. Initial connection and tool execution works + 2. Simulating connection failure triggers automatic reconnection + 3. Tool execution works after reconnection + 4. Exit stack cleanup happens properly during reconnection + """ + url = os.environ.get("LOCAL_MCP_URL") + + tool = MCPStreamableHTTPTool(name="integration_test", url=url) + + async with tool: + # Verify initial connection + assert tool.session is not None + assert tool.is_connected is True + assert len(tool.functions) > 0, "The MCP server should have at least one function." 
+ + # Get the first function and invoke it + func = tool.functions[0] + first_result = await func.invoke(query="What is Agent Framework?") + assert first_result is not None + assert len(first_result) > 0 + + # Store the original session and exit stack for comparison + original_session = tool.session + original_exit_stack = tool._exit_stack + original_call_tool = tool.session.call_tool + + # Simulate connection failure by making call_tool raise ClosedResourceError once + call_count = 0 + + async def call_tool_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call fails with connection error + from anyio.streams.memory import ClosedResourceError + + raise ClosedResourceError + # After reconnection, delegate to the original method + return await original_call_tool(*args, **kwargs) + + tool.session.call_tool = call_tool_with_error + + # Invoke the function again - this should trigger automatic reconnection on ClosedResourceError + second_result = await func.invoke(query="What is Agent Framework?") + assert second_result is not None + assert len(second_result) > 0 + + # Verify we have a new session and exit stack after reconnection + assert tool.session is not None + assert tool.session is not original_session, "Session should be replaced after reconnection" + assert tool._exit_stack is not original_exit_stack, "Exit stack should be replaced after reconnection" + assert tool.is_connected is True + + # Verify tools are still available after reconnection + assert len(tool.functions) > 0 + + # Both results should be valid (we don't compare content as it may vary) + if hasattr(first_result[0], "text"): + assert first_result[0].text is not None + if hasattr(second_result[0], "text"): + assert second_result[0].text is not None + + async def test_mcp_tool_message_handler_notification(): """Test that message_handler correctly processes tools/list_changed and prompts/list_changed notifications.""" @@ -1549,7 +1618,6 @@ def test_mcp_websocket_tool_get_mcp_client_with_kwargs(): ) -@pytest.mark.asyncio async def test_mcp_tool_deduplication(): """Test that MCP tools are not duplicated in MCPTool""" from agent_framework._mcp import MCPTool @@ -1611,7 +1679,6 @@ async def test_mcp_tool_deduplication(): assert added_count == 1 # Only 1 new function added -@pytest.mark.asyncio async def test_load_tools_prevents_multiple_calls(): """Test that connect() prevents calling load_tools() multiple times""" from unittest.mock import AsyncMock, MagicMock @@ -1627,6 +1694,7 @@ async def test_load_tools_prevents_multiple_calls(): mock_session = AsyncMock() mock_tool_list = MagicMock() mock_tool_list.tools = [] + mock_tool_list.nextCursor = None # No pagination mock_session.list_tools = AsyncMock(return_value=mock_tool_list) mock_session.initialize = AsyncMock() @@ -1650,7 +1718,6 @@ async def test_load_tools_prevents_multiple_calls(): assert mock_session.list_tools.call_count == 1 # Still 1, not incremented -@pytest.mark.asyncio async def test_load_prompts_prevents_multiple_calls(): """Test that connect() prevents calling load_prompts() multiple times""" from unittest.mock import AsyncMock, MagicMock @@ -1666,6 +1733,7 @@ async def test_load_prompts_prevents_multiple_calls(): mock_session = AsyncMock() mock_prompt_list = MagicMock() mock_prompt_list.prompts = [] + mock_prompt_list.nextCursor = None # No pagination mock_session.list_prompts = AsyncMock(return_value=mock_prompt_list) tool.session = mock_session @@ -1688,7 +1756,6 @@ async def 
test_load_prompts_prevents_multiple_calls(): assert mock_session.list_prompts.call_count == 1 # Still 1, not incremented -@pytest.mark.asyncio async def test_mcp_streamable_http_tool_httpx_client_cleanup(): """Test that MCPStreamableHTTPTool properly passes through httpx clients.""" from unittest.mock import AsyncMock, Mock, patch @@ -1744,3 +1811,556 @@ async def test_mcp_streamable_http_tool_httpx_client_cleanup(): # Get the last call (should be from tool2.connect()) call_args = mock_client.call_args assert call_args.kwargs["http_client"] is user_client, "User's client should be passed through" + + +async def test_load_tools_with_pagination(): + """Test that load_tools handles pagination correctly.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create paginated responses + page1 = MagicMock() + page1.tools = [ + types.Tool( + name="tool_1", + description="First tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_2", + description="Second tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.tools = [ + types.Tool( + name="tool_3", + description="Third tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page2.nextCursor = "cursor_page3" + + page3 = MagicMock() + page3.tools = [ + types.Tool( + name="tool_4", + description="Fourth tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page3.nextCursor = None # No more pages + + # Mock list_tools to return different pages based on params + async def mock_list_tools(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + if params.cursor == "cursor_page3": + return page3 + raise ValueError("Unexpected cursor value") + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + # Load tools with pagination + await tool.load_tools() + + # Verify all pages were fetched + assert mock_session.list_tools.call_count == 3 + assert len(tool._functions) == 4 + assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3", "tool_4"] + + +async def test_load_prompts_with_pagination(): + """Test that load_prompts handles pagination correctly.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create paginated responses + page1 = MagicMock() + page1.prompts = [ + types.Prompt( + name="prompt_1", + description="First prompt", + arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)], + ), + types.Prompt( + name="prompt_2", + description="Second prompt", + arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=True)], + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.prompts = [ + types.Prompt( + name="prompt_3", + description="Third prompt", + arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=False)], + ), + ] + page2.nextCursor = None # No more pages + + # Mock list_prompts to return different pages based on params + async def 
mock_list_prompts(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise ValueError("Unexpected cursor value") + + mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts) + + # Load prompts with pagination + await tool.load_prompts() + + # Verify all pages were fetched + assert mock_session.list_prompts.call_count == 2 + assert len(tool._functions) == 3 + assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2", "prompt_3"] + + +async def test_load_tools_pagination_with_duplicates(): + """Test that load_tools prevents duplicates across paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create paginated responses with duplicate tool names + page1 = MagicMock() + page1.tools = [ + types.Tool( + name="tool_1", + description="First tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_2", + description="Second tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.tools = [ + types.Tool( + name="tool_1", # Duplicate from page1 + description="Duplicate tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + types.Tool( + name="tool_3", + description="Third tool", + inputSchema={"type": "object", "properties": {"param": {"type": "string"}}}, + ), + ] + page2.nextCursor = None + + # Mock list_tools to return different pages + async def mock_list_tools(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise ValueError("Unexpected cursor value") + + mock_session.list_tools = AsyncMock(side_effect=mock_list_tools) + + # Load tools with pagination + await tool.load_tools() + + # Verify duplicates were skipped + assert mock_session.list_tools.call_count == 2 + assert len(tool._functions) == 3 + assert [f.name for f in tool._functions] == ["tool_1", "tool_2", "tool_3"] + + +async def test_load_prompts_pagination_with_duplicates(): + """Test that load_prompts prevents duplicates across paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create paginated responses with duplicate prompt names + page1 = MagicMock() + page1.prompts = [ + types.Prompt( + name="prompt_1", + description="First prompt", + arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)], + ), + ] + page1.nextCursor = "cursor_page2" + + page2 = MagicMock() + page2.prompts = [ + types.Prompt( + name="prompt_1", # Duplicate from page1 + description="Duplicate prompt", + arguments=[types.PromptArgument(name="arg2", description="Arg 2", required=False)], + ), + types.Prompt( + name="prompt_2", + description="Second prompt", + arguments=[types.PromptArgument(name="arg3", description="Arg 3", required=True)], + ), + ] + page2.nextCursor = None + + # Mock list_prompts to return different pages + async def mock_list_prompts(params=None): + if params is None: + return page1 + if params.cursor == "cursor_page2": + return page2 + raise 
ValueError("Unexpected cursor value") + + mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts) + + # Load prompts with pagination + await tool.load_prompts() + + # Verify duplicates were skipped + assert mock_session.list_prompts.call_count == 2 + assert len(tool._functions) == 2 + assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"] + + +async def test_load_tools_pagination_exception_handling(): + """Test that load_tools handles exceptions during pagination gracefully.""" + from unittest.mock import AsyncMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Mock list_tools to raise an exception on first call + mock_session.list_tools = AsyncMock(side_effect=RuntimeError("Connection error")) + + # Load tools should raise the exception (not handled gracefully) + with pytest.raises(RuntimeError, match="Connection error"): + await tool.load_tools() + + # Verify exception was raised on first call + assert mock_session.list_tools.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_prompts_pagination_exception_handling(): + """Test that load_prompts handles exceptions during pagination gracefully.""" + from unittest.mock import AsyncMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Mock list_prompts to raise an exception on first call + mock_session.list_prompts = AsyncMock(side_effect=RuntimeError("Connection error")) + + # Load prompts should raise the exception (not handled gracefully) + with pytest.raises(RuntimeError, match="Connection error"): + await tool.load_prompts() + + # Verify exception was raised on first call + assert mock_session.list_prompts.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_tools_empty_pagination(): + """Test that load_tools handles empty paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + # Create empty response + page1 = MagicMock() + page1.tools = [] + page1.nextCursor = None + + mock_session.list_tools = AsyncMock(return_value=page1) + + # Load tools + await tool.load_tools() + + # Verify + assert mock_session.list_tools.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_load_prompts_empty_pagination(): + """Test that load_prompts handles empty paginated results.""" + from unittest.mock import AsyncMock, MagicMock + + from agent_framework._mcp import MCPTool + + tool = MCPTool(name="test_tool") + + # Mock the session + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + # Create empty response + page1 = MagicMock() + page1.prompts = [] + page1.nextCursor = None + + mock_session.list_prompts = AsyncMock(return_value=page1) + + # Load prompts + await tool.load_prompts() + + # Verify + assert mock_session.list_prompts.call_count == 1 + assert len(tool._functions) == 0 + + +async def test_mcp_tool_connection_properly_invalidated_after_closed_resource_error(): + """Test that verifies reconnection on ClosedResourceError for issue #2884. 
+ + This test verifies the fix for issue #2884: the tool tries operations optimistically + and only reconnects when ClosedResourceError is encountered, avoiding extra latency. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from anyio.streams.memory import ClosedResourceError + + from agent_framework._mcp import MCPStdioTool + from agent_framework.exceptions import ToolExecutionException + + # Create a mock MCP tool + tool = MCPStdioTool( + name="test_server", + command="test_command", + args=["arg1"], + load_tools=True, + ) + + # Mock the session + mock_session = MagicMock() + mock_session._request_id = 1 + mock_session.call_tool = AsyncMock() + + # Mock _exit_stack.aclose to track cleanup calls + original_exit_stack = tool._exit_stack + tool._exit_stack.aclose = AsyncMock() + + # Mock connect() to avoid trying to start actual process + with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: + + async def restore_session(*, reset=False): + if reset: + await original_exit_stack.aclose() + tool.session = mock_session + tool.is_connected = True + tool._tools_loaded = True + + mock_connect.side_effect = restore_session + + # Simulate initial connection + tool.session = mock_session + tool.is_connected = True + tool._tools_loaded = True + + # First call should work - connection is valid + mock_session.call_tool.return_value = MagicMock(content=[]) + result = await tool.call_tool("test_tool", arg1="value1") + assert result is not None + + # Test Case 1: Connection closed unexpectedly, should reconnect and retry + # Simulate ClosedResourceError on first call, then succeed + call_count = 0 + + async def call_tool_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ClosedResourceError + return MagicMock(content=[]) + + mock_session.call_tool = call_tool_with_error + + # This call should trigger reconnection after ClosedResourceError + result = await tool.call_tool("test_tool", arg1="value2") + assert result is not None + # Verify reconnect was attempted with reset=True + assert mock_connect.call_count >= 1 + mock_connect.assert_called_with(reset=True) + # Verify _exit_stack.aclose was called during reconnection + original_exit_stack.aclose.assert_called() + + # Test Case 2: Reconnection failure + # Reset counters + call_count = 0 + mock_connect.reset_mock() + original_exit_stack.aclose.reset_mock() + + # Make call_tool always raise ClosedResourceError + async def always_fail(*args, **kwargs): + raise ClosedResourceError + + mock_session.call_tool = always_fail + + # Change mock_connect to simulate failed reconnection + mock_connect.side_effect = Exception("Failed to reconnect") + + # This should raise ToolExecutionException when reconnection fails + with pytest.raises(ToolExecutionException) as exc_info: + await tool.call_tool("test_tool", arg1="value3") + + # Verify reconnection was attempted + assert mock_connect.call_count >= 1 + # Verify error message indicates reconnection failure + assert "failed to reconnect" in str(exc_info.value).lower() + + +async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error(): + """Test that get_prompt also reconnects on ClosedResourceError. + + This verifies that the fix for issue #2884 applies to get_prompt as well, + and that _exit_stack.aclose() is properly called during reconnection. 
+ """ + from unittest.mock import AsyncMock, MagicMock, patch + + from anyio.streams.memory import ClosedResourceError + + from agent_framework._mcp import MCPStdioTool + from agent_framework.exceptions import ToolExecutionException + + # Create a mock MCP tool + tool = MCPStdioTool( + name="test_server", + command="test_command", + args=["arg1"], + load_prompts=True, + ) + + # Mock the session + mock_session = MagicMock() + mock_session._request_id = 1 + mock_session.get_prompt = AsyncMock() + + # Mock _exit_stack.aclose to track cleanup calls + original_exit_stack = tool._exit_stack + tool._exit_stack.aclose = AsyncMock() + + # Mock connect() to avoid trying to start actual process + with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: + + async def restore_session(*, reset=False): + if reset: + await original_exit_stack.aclose() + tool.session = mock_session + tool.is_connected = True + tool._prompts_loaded = True + + mock_connect.side_effect = restore_session + + # Simulate initial connection + tool.session = mock_session + tool.is_connected = True + tool._prompts_loaded = True + + # First call should work - connection is valid + mock_session.get_prompt.return_value = MagicMock(messages=[]) + result = await tool.get_prompt("test_prompt", arg1="value1") + assert result is not None + + # Test Case 1: Connection closed unexpectedly, should reconnect and retry + # Simulate ClosedResourceError on first call, then succeed + call_count = 0 + + async def get_prompt_with_error(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ClosedResourceError + return MagicMock(messages=[]) + + mock_session.get_prompt = get_prompt_with_error + + # This call should trigger reconnection after ClosedResourceError + result = await tool.get_prompt("test_prompt", arg1="value2") + assert result is not None + # Verify reconnect was attempted with reset=True + assert mock_connect.call_count >= 1 + mock_connect.assert_called_with(reset=True) + # Verify _exit_stack.aclose was called during reconnection + original_exit_stack.aclose.assert_called() + + # Test Case 2: Reconnection failure + # Reset counters + call_count = 0 + mock_connect.reset_mock() + original_exit_stack.aclose.reset_mock() + + # Make get_prompt always raise ClosedResourceError + async def always_fail(*args, **kwargs): + raise ClosedResourceError + + mock_session.get_prompt = always_fail + + # Change mock_connect to simulate failed reconnection + mock_connect.side_effect = Exception("Failed to reconnect") + + # This should raise ToolExecutionException when reconnection fails + with pytest.raises(ToolExecutionException) as exc_info: + await tool.get_prompt("test_prompt", arg1="value3") + + # Verify reconnection was attempted + assert mock_connect.call_count >= 1 + # Verify error message indicates reconnection failure + assert "failed to reconnect" in str(exc_info.value).lower() diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f70e6ddb56..73327b4c1f 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -5,7 +5,7 @@ import pytest from opentelemetry import trace from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from agent_framework import ( AIFunction, @@ -15,7 +15,11 @@ ToolProtocol, ai_function, ) -from agent_framework._tools import _parse_annotation, 
_parse_inputs +from agent_framework._tools import ( + _build_pydantic_model_from_json_schema, + _parse_annotation, + _parse_inputs, +) from agent_framework.exceptions import ToolException from agent_framework.observability import OtelAttr @@ -1548,4 +1552,467 @@ def test_parse_annotation_with_annotated_and_literal(): assert get_args(literal_type) == ("A", "B", "C") +def test_build_pydantic_model_from_json_schema_array_of_objects_issue(): + """Test for Tools with complex input schema (array of objects). + + This test verifies that JSON schemas with array properties containing nested objects + are properly parsed, ensuring that the nested object schema is preserved + and not reduced to a bare dict. + + Example from issue: + ``` + const SalesOrderItemSchema = z.object({ + customerMaterialNumber: z.string().optional(), + quantity: z.number(), + unitOfMeasure: z.string() + }); + + const CreateSalesOrderInputSchema = z.object({ + contract: z.string(), + items: z.array(SalesOrderItemSchema) + }); + ``` + + The issue was that agents only saw: + ``` + {"contract": "str", "items": "list[dict]"} + ``` + + Instead of the proper nested schema with all fields. + """ + # Schema matching the issue description + schema = { + "type": "object", + "properties": { + "contract": {"type": "string", "description": "Reference contract number"}, + "items": { + "type": "array", + "description": "Sales order line items", + "items": { + "type": "object", + "properties": { + "customerMaterialNumber": { + "type": "string", + "description": "Customer's material number", + }, + "quantity": {"type": "number", "description": "Order quantity"}, + "unitOfMeasure": { + "type": "string", + "description": "Unit of measure (e.g., 'ST', 'KG', 'TO')", + }, + }, + "required": ["quantity", "unitOfMeasure"], + }, + }, + }, + "required": ["contract", "items"], + } + + model = _build_pydantic_model_from_json_schema("create_sales_order", schema) + + # Test valid data + valid_data = { + "contract": "CONTRACT-123", + "items": [ + { + "customerMaterialNumber": "MAT-001", + "quantity": 10, + "unitOfMeasure": "ST", + }, + {"quantity": 5.5, "unitOfMeasure": "KG"}, + ], + } + + instance = model(**valid_data) + + # Verify the data was parsed correctly + assert instance.contract == "CONTRACT-123" + assert len(instance.items) == 2 + + # Verify first item + assert instance.items[0].customerMaterialNumber == "MAT-001" + assert instance.items[0].quantity == 10 + assert instance.items[0].unitOfMeasure == "ST" + + # Verify second item (optional field not provided) + assert instance.items[1].quantity == 5.5 + assert instance.items[1].unitOfMeasure == "KG" + + # Verify that items are proper BaseModel instances, not bare dicts + assert isinstance(instance.items[0], BaseModel) + assert isinstance(instance.items[1], BaseModel) + + # Verify that the nested object has the expected fields + assert hasattr(instance.items[0], "customerMaterialNumber") + assert hasattr(instance.items[0], "quantity") + assert hasattr(instance.items[0], "unitOfMeasure") + + # CRITICAL: Validate using the same methods that actual chat clients use + # This is what would actually be sent to the LLM + + # Create an AIFunction wrapper to access the client-facing APIs + def dummy_func(**kwargs): + return kwargs + + test_func = AIFunction( + func=dummy_func, + name="create_sales_order", + description="Create a sales order", + input_model=model, + ) + + # Test 1: Anthropic client uses tool.parameters() directly + anthropic_schema = test_func.parameters() + + # Verify contract property + 
assert "contract" in anthropic_schema["properties"] + assert anthropic_schema["properties"]["contract"]["type"] == "string" + + # Verify items array property exists + assert "items" in anthropic_schema["properties"] + items_prop = anthropic_schema["properties"]["items"] + assert items_prop["type"] == "array" + + # THE KEY TEST for Anthropic: array items must have proper object schema + assert "items" in items_prop, "Array should have 'items' schema definition" + array_items_schema = items_prop["items"] + + # Resolve schema if using $ref + if "$ref" in array_items_schema: + ref_path = array_items_schema["$ref"] + assert ref_path.startswith("#/$defs/") or ref_path.startswith("#/definitions/") + ref_name = ref_path.split("/")[-1] + defs = anthropic_schema.get("$defs", anthropic_schema.get("definitions", {})) + assert ref_name in defs, f"Referenced schema '{ref_name}' should exist" + item_schema = defs[ref_name] + else: + item_schema = array_items_schema + + # Verify the nested object has all properties defined + assert "properties" in item_schema, "Array items should have properties (not bare dict)" + item_properties = item_schema["properties"] + + # All three fields must be present in schema sent to LLM + assert "customerMaterialNumber" in item_properties, "customerMaterialNumber missing from LLM schema" + assert "quantity" in item_properties, "quantity missing from LLM schema" + assert "unitOfMeasure" in item_properties, "unitOfMeasure missing from LLM schema" + + # Verify types are correct + assert item_properties["customerMaterialNumber"]["type"] == "string" + assert item_properties["quantity"]["type"] in ["number", "integer"] + assert item_properties["unitOfMeasure"]["type"] == "string" + + # Test 2: OpenAI client uses tool.to_json_schema_spec() + openai_spec = test_func.to_json_schema_spec() + + assert openai_spec["type"] == "function" + assert "function" in openai_spec + openai_schema = openai_spec["function"]["parameters"] + + # Verify the same structure is present in OpenAI format + assert "items" in openai_schema["properties"] + openai_items_prop = openai_schema["properties"]["items"] + assert openai_items_prop["type"] == "array" + assert "items" in openai_items_prop + + openai_array_items = openai_items_prop["items"] + if "$ref" in openai_array_items: + ref_path = openai_array_items["$ref"] + ref_name = ref_path.split("/")[-1] + defs = openai_schema.get("$defs", openai_schema.get("definitions", {})) + openai_item_schema = defs[ref_name] + else: + openai_item_schema = openai_array_items + + assert "properties" in openai_item_schema + openai_props = openai_item_schema["properties"] + assert "customerMaterialNumber" in openai_props + assert "quantity" in openai_props + assert "unitOfMeasure" in openai_props + + # Test validation - missing required quantity + with pytest.raises(ValidationError): + model( + contract="CONTRACT-456", + items=[ + { + "customerMaterialNumber": "MAT-002", + "unitOfMeasure": "TO", + # Missing required 'quantity' + } + ], + ) + + # Test validation - missing required unitOfMeasure + with pytest.raises(ValidationError): + model( + contract="CONTRACT-789", + items=[ + { + "quantity": 20 + # Missing required 'unitOfMeasure' + } + ], + ) + + +def test_one_of_discriminator_polymorphism(): + """Test that oneOf with discriminator creates proper polymorphic union types. + + Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions. 
+ """ + schema = { + "$defs": { + "CreateProject": { + "description": "Action: Create an Azure DevOps project.", + "properties": { + "name": { + "const": "create_project", + "default": "create_project", + "type": "string", + }, + "params": {"$ref": "#/$defs/CreateProjectParams"}, + }, + "required": ["params"], + "type": "object", + }, + "CreateProjectParams": { + "description": "Parameters for the create_project action.", + "properties": { + "orgUrl": {"minLength": 1, "type": "string"}, + "projectName": {"minLength": 1, "type": "string"}, + "description": {"default": "", "type": "string"}, + "template": {"default": "Agile", "type": "string"}, + "sourceControl": { + "default": "Git", + "enum": ["Git", "Tfvc"], + "type": "string", + }, + "visibility": {"default": "private", "type": "string"}, + }, + "required": ["orgUrl", "projectName"], + "type": "object", + }, + "DeployRequest": { + "description": "Request to deploy Azure DevOps resources.", + "properties": { + "projectName": {"minLength": 1, "type": "string"}, + "organization": {"minLength": 1, "type": "string"}, + "actions": { + "items": { + "discriminator": { + "mapping": { + "create_project": "#/$defs/CreateProject", + "hello_world": "#/$defs/HelloWorld", + }, + "propertyName": "name", + }, + "oneOf": [ + {"$ref": "#/$defs/HelloWorld"}, + {"$ref": "#/$defs/CreateProject"}, + ], + }, + "type": "array", + }, + }, + "required": ["projectName", "organization"], + "type": "object", + }, + "HelloWorld": { + "description": "Action: Prints a greeting message.", + "properties": { + "name": { + "const": "hello_world", + "default": "hello_world", + "type": "string", + }, + "params": {"$ref": "#/$defs/HelloWorldParams"}, + }, + "required": ["params"], + "type": "object", + }, + "HelloWorldParams": { + "description": "Parameters for the hello_world action.", + "properties": { + "name": { + "description": "Name to greet", + "minLength": 1, + "type": "string", + } + }, + "required": ["name"], + "type": "object", + }, + }, + "properties": {"params": {"$ref": "#/$defs/DeployRequest"}}, + "required": ["params"], + "type": "object", + } + + # Build the model + model = _build_pydantic_model_from_json_schema("deploy_tool", schema) + + # Verify the model structure + assert model is not None + assert issubclass(model, BaseModel) + + # Test with HelloWorld action + hello_world_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + { + "name": "hello_world", + "params": {"name": "Alice"}, + } + ], + } + } + + instance = model(**hello_world_data) + assert instance.params.projectName == "MyProject" + assert instance.params.organization == "MyOrg" + assert len(instance.params.actions) == 1 + assert instance.params.actions[0].name == "hello_world" + assert instance.params.actions[0].params.name == "Alice" + + # Test with CreateProject action + create_project_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + { + "name": "create_project", + "params": { + "orgUrl": "https://dev.azure.com/myorg", + "projectName": "NewProject", + "sourceControl": "Git", + }, + } + ], + } + } + + instance2 = model(**create_project_data) + assert instance2.params.actions[0].name == "create_project" + assert instance2.params.actions[0].params.projectName == "NewProject" + assert instance2.params.actions[0].params.sourceControl == "Git" + + # Test with mixed actions + mixed_data = { + "params": { + "projectName": "MyProject", + "organization": "MyOrg", + "actions": [ + {"name": "hello_world", "params": 
{"name": "Bob"}}, + { + "name": "create_project", + "params": { + "orgUrl": "https://dev.azure.com/myorg", + "projectName": "AnotherProject", + }, + }, + ], + } + } + + instance3 = model(**mixed_data) + assert len(instance3.params.actions) == 2 + assert instance3.params.actions[0].name == "hello_world" + assert instance3.params.actions[1].name == "create_project" + + +def test_const_creates_literal(): + """Test that const in JSON Schema creates Literal type.""" + schema = { + "properties": { + "action": { + "const": "create", + "type": "string", + "description": "Action type", + }, + "value": {"type": "integer"}, + }, + "required": ["action", "value"], + } + + model = _build_pydantic_model_from_json_schema("test_const", schema) + + # Verify valid const value works + instance = model(action="create", value=42) + assert instance.action == "create" + assert instance.value == 42 + + # Verify incorrect const value fails + with pytest.raises(ValidationError): + model(action="delete", value=42) + + +def test_enum_creates_literal(): + """Test that enum in JSON Schema creates Literal type.""" + schema = { + "properties": { + "status": { + "enum": ["pending", "approved", "rejected"], + "type": "string", + "description": "Status", + }, + "priority": {"enum": [1, 2, 3], "type": "integer"}, + }, + "required": ["status"], + } + + model = _build_pydantic_model_from_json_schema("test_enum", schema) + + # Verify valid enum values work + instance = model(status="approved", priority=2) + assert instance.status == "approved" + assert instance.priority == 2 + + # Verify invalid enum value fails + with pytest.raises(ValidationError): + model(status="unknown") + + with pytest.raises(ValidationError): + model(status="pending", priority=5) + + +def test_nested_object_with_const_and_enum(): + """Test that const and enum work in nested objects.""" + schema = { + "properties": { + "config": { + "type": "object", + "properties": { + "type": { + "const": "production", + "default": "production", + "type": "string", + }, + "level": {"enum": ["low", "medium", "high"], "type": "string"}, + }, + "required": ["level"], + } + }, + "required": ["config"], + } + + model = _build_pydantic_model_from_json_schema("test_nested", schema) + + # Valid data + instance = model(config={"type": "production", "level": "high"}) + assert instance.config.type == "production" + assert instance.config.level == "high" + + # Invalid const in nested object + with pytest.raises(ValidationError): + model(config={"type": "development", "level": "low"}) + + # Invalid enum in nested object + with pytest.raises(ValidationError): + model(config={"type": "production", "level": "critical"}) + + # endregion