|
1 | 1 | # Copyright (c) Microsoft. All rights reserved. |
2 | 2 |
|
| 3 | +import os |
| 4 | +from collections.abc import AsyncIterator |
| 5 | +from contextlib import asynccontextmanager |
| 6 | +from typing import Annotated |
3 | 7 | from unittest.mock import AsyncMock, MagicMock, patch |
4 | 8 |
|
5 | 9 | import pytest |
6 | 10 | from agent_framework import ( |
| 11 | + AgentRunResponse, |
| 12 | + AgentRunResponseUpdate, |
| 13 | + ChatAgent, |
7 | 14 | ChatClientProtocol, |
8 | 15 | ChatMessage, |
9 | 16 | ChatOptions, |
10 | 17 | Role, |
11 | 18 | TextContent, |
12 | 19 | ) |
13 | 20 | from agent_framework.exceptions import ServiceInitializationError |
| 21 | +from azure.ai.projects.aio import AIProjectClient |
14 | 22 | from azure.ai.projects.models import ( |
15 | 23 | ResponseTextFormatConfigurationJsonSchema, |
16 | 24 | ) |
| 25 | +from azure.identity.aio import AzureCliCredential |
17 | 26 | from openai.types.responses.parsed_response import ParsedResponse |
18 | 27 | from openai.types.responses.response import Response as OpenAIResponse |
19 | | -from pydantic import BaseModel, ConfigDict, ValidationError |
| 28 | +from pydantic import BaseModel, ConfigDict, Field, ValidationError |
20 | 29 |
|
21 | 30 | from agent_framework_azure_ai import AzureAIClient, AzureAISettings |
22 | 31 |
|
| 32 | +skip_if_azure_ai_integration_tests_disabled = pytest.mark.skipif( |
| 33 | + os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true" |
| 34 | + or os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") in ("", "https://test-project.cognitiveservices.azure.com/") |
| 35 | + or os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME", "") == "", |
| 36 | + reason=( |
| 37 | + "No real AZURE_AI_PROJECT_ENDPOINT or AZURE_AI_MODEL_DEPLOYMENT_NAME provided; skipping integration tests." |
| 38 | + if os.getenv("RUN_INTEGRATION_TESTS", "false").lower() == "true" |
| 39 | + else "Integration tests are disabled." |
| 40 | + ), |
| 41 | +) |
| 42 | + |
| 43 | + |
| 44 | +@asynccontextmanager |
| 45 | +async def temporary_chat_client(agent_name: str) -> AsyncIterator[AzureAIClient]: |
| 46 | + """Async context manager that creates an Azure AI agent and yields an `AzureAIClient`. |
| 47 | +
|
| 48 | + The underlying agent version is cleaned up automatically after use. |
| 49 | + Tests can construct their own `ChatAgent` instances from the yielded client. |
| 50 | + """ |
| 51 | + endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"] |
| 52 | + async with ( |
| 53 | + AzureCliCredential() as credential, |
| 54 | + AIProjectClient(endpoint=endpoint, credential=credential) as project_client, |
| 55 | + ): |
| 56 | + chat_client = AzureAIClient( |
| 57 | + project_client=project_client, |
| 58 | + agent_name=agent_name, |
| 59 | + ) |
| 60 | + try: |
| 61 | + yield chat_client |
| 62 | + finally: |
| 63 | + await project_client.agents.delete(agent_name=agent_name) |
| 64 | + |
23 | 65 |
|
24 | 66 | def create_test_azure_ai_client( |
25 | 67 | mock_project_client: MagicMock, |
@@ -751,3 +793,64 @@ def mock_project_client() -> MagicMock: |
751 | 793 | mock_client.close = AsyncMock() |
752 | 794 |
|
753 | 795 | return mock_client |
| 796 | + |
| 797 | + |
| 798 | +def get_weather( |
| 799 | + location: Annotated[str, Field(description="The location to get the weather for.")], |
| 800 | +) -> str: |
| 801 | + """Get the weather for a given location.""" |
| 802 | + return f"The weather in {location} is sunny with a high of 25°C." |
| 803 | + |
| 804 | + |
| 805 | +@pytest.mark.flaky |
| 806 | +@skip_if_azure_ai_integration_tests_disabled |
| 807 | +async def test_azure_ai_chat_client_agent_basic_run() -> None: |
| 808 | + """Test ChatAgent basic run functionality with AzureAIClient.""" |
| 809 | + async with ( |
| 810 | + temporary_chat_client(agent_name="BasicRunAgent") as chat_client, |
| 811 | + ChatAgent(chat_client=chat_client) as agent, |
| 812 | + ): |
| 813 | + response = await agent.run("Hello! Please respond with 'Hello World' exactly.") |
| 814 | + |
| 815 | + # Validate response |
| 816 | + assert isinstance(response, AgentRunResponse) |
| 817 | + assert response.text is not None |
| 818 | + assert len(response.text) > 0 |
| 819 | + assert "Hello World" in response.text |
| 820 | + |
| 821 | + |
| 822 | +@pytest.mark.flaky |
| 823 | +@skip_if_azure_ai_integration_tests_disabled |
| 824 | +async def test_azure_ai_chat_client_agent_basic_run_streaming() -> None: |
| 825 | + """Test ChatAgent basic streaming functionality with AzureAIClient.""" |
| 826 | + async with ( |
| 827 | + temporary_chat_client(agent_name="BasicRunStreamingAgent") as chat_client, |
| 828 | + ChatAgent(chat_client=chat_client) as agent, |
| 829 | + ): |
| 830 | + full_message: str = "" |
| 831 | + async for chunk in agent.run_stream("Please respond with exactly: 'This is a streaming response test.'"): |
| 832 | + assert chunk is not None |
| 833 | + assert isinstance(chunk, AgentRunResponseUpdate) |
| 834 | + if chunk.text: |
| 835 | + full_message += chunk.text |
| 836 | + |
| 837 | + # Validate streaming response |
| 838 | + assert len(full_message) > 0 |
| 839 | + assert "streaming response test" in full_message.lower() |
| 840 | + |
| 841 | + |
| 842 | +@pytest.mark.flaky |
| 843 | +@skip_if_azure_ai_integration_tests_disabled |
| 844 | +async def test_azure_ai_chat_client_agent_with_tools() -> None: |
| 845 | + """Test ChatAgent tools with AzureAIClient.""" |
| 846 | + async with ( |
| 847 | + temporary_chat_client(agent_name="RunToolsAgent") as chat_client, |
| 848 | + ChatAgent(chat_client=chat_client, tools=[get_weather]) as agent, |
| 849 | + ): |
| 850 | + response = await agent.run("What's the weather like in Seattle?") |
| 851 | + |
| 852 | + # Validate response |
| 853 | + assert isinstance(response, AgentRunResponse) |
| 854 | + assert response.text is not None |
| 855 | + assert len(response.text) > 0 |
| 856 | + assert any(word in response.text.lower() for word in ["sunny", "25"]) |
0 commit comments