diff --git a/infra/deploy_app_service.bicep b/infra/deploy_app_service.bicep index 648fbf16f..fc0967ef6 100644 --- a/infra/deploy_app_service.bicep +++ b/infra/deploy_app_service.bicep @@ -184,6 +184,10 @@ resource Website 'Microsoft.Web/sites@2020-06-01' = { serverFarmId: HostingPlanName siteConfig: { appSettings: [ + { + name: 'APP_ENV' + value: 'Prod' + } { name: 'APPINSIGHTS_INSTRUMENTATIONKEY' value: reference(applicationInsightsId, '2015-05-01').InstrumentationKey diff --git a/infra/main.bicep b/infra/main.bicep index 7c225c1f8..705259f56 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -74,8 +74,7 @@ param aiDeploymentsLocation string param AZURE_LOCATION string = '' var solutionLocation = empty(AZURE_LOCATION) ? resourceGroup().location : AZURE_LOCATION -var uniqueId = toLower(uniqueString(environmentName, subscription().id, solutionLocation)) - +var uniqueId = toLower(uniqueString(environmentName, subscription().id, solutionLocation, resourceGroup().name)) var solutionPrefix = 'ca${padLeft(take(uniqueId, 12), 12, '0')}' // Load the abbrevations file required to name the azure resources. diff --git a/infra/main.json b/infra/main.json index 647023052..02d2d31ea 100644 --- a/infra/main.json +++ b/infra/main.json @@ -5,7 +5,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "2238194529646818649" + "templateHash": "3253509031453285119" } }, "parameters": { @@ -359,7 +359,7 @@ } }, "solutionLocation": "[if(empty(parameters('AZURE_LOCATION')), resourceGroup().location, parameters('AZURE_LOCATION'))]", - "uniqueId": "[toLower(uniqueString(parameters('environmentName'), subscription().id, variables('solutionLocation')))]", + "uniqueId": "[toLower(uniqueString(parameters('environmentName'), subscription().id, variables('solutionLocation'), resourceGroup().name))]", "solutionPrefix": "[format('ca{0}', padLeft(take(variables('uniqueId'), 12), 12, '0'))]", "abbrs": "[variables('$fxv#0')]", "functionAppSqlPrompt": "Generate a valid T-SQL query to find {query} for tables and columns provided below:\r\n 1. Table: Clients\r\n Columns: ClientId, Client, Email, Occupation, MaritalStatus, Dependents\r\n 2. Table: InvestmentGoals\r\n Columns: ClientId, InvestmentGoal\r\n 3. Table: Assets\r\n Columns: ClientId, AssetDate, Investment, ROI, Revenue, AssetType\r\n 4. Table: ClientSummaries\r\n Columns: ClientId, ClientSummary\r\n 5. Table: InvestmentGoalsDetails\r\n Columns: ClientId, InvestmentGoal, TargetAmount, Contribution\r\n 6. Table: Retirement\r\n Columns: ClientId, StatusDate, RetirementGoalProgress, EducationGoalProgress\r\n 7. Table: ClientMeetings\r\n Columns: ClientId, ConversationId, Title, StartTime, EndTime, Advisor, ClientEmail\r\n Always use the Investment column from the Assets table as the value.\r\n Assets table has snapshots of values by date. Do not add numbers across different dates for total values.\r\n Do not use client name in filters.\r\n Do not include assets values unless asked for.\r\n ALWAYS use ClientId = {clientid} in the query filter.\r\n ALWAYS select Client Name (Column: Client) in the query.\r\n Query filters are IMPORTANT. Add filters like AssetType, AssetDate, etc. if needed.\r\n When answering scheduling or time-based meeting questions, always use the StartTime column from ClientMeetings table. Use correct logic to return the most recent past meeting (last/previous) or the nearest future meeting (next/upcoming), and ensure only StartTime column is used for meeting timing comparisons.\r\n Only return the generated SQL query. Do not return anything else.", @@ -744,7 +744,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "1124249040831466979" + "templateHash": "1343961496887433815" } }, "parameters": { @@ -1458,7 +1458,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "11899270249637077405" + "templateHash": "10199364008784095733" } }, "parameters": { @@ -2244,7 +2244,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "10507186896960913919" + "templateHash": "7899619253922538038" } }, "parameters": { @@ -2615,6 +2615,10 @@ "serverFarmId": "[parameters('HostingPlanName')]", "siteConfig": { "appSettings": [ + { + "name": "APP_ENV", + "value": "Prod" + }, { "name": "APPINSIGHTS_INSTRUMENTATIONKEY", "value": "[reference(parameters('applicationInsightsId'), '2015-05-01').InstrumentationKey]" @@ -2894,7 +2898,7 @@ "_generator": { "name": "bicep", "version": "0.36.177.2456", - "templateHash": "11899270249637077405" + "templateHash": "10199364008784095733" } }, "parameters": { diff --git a/infra/scripts/fabric_scripts/create_fabric_items.py b/infra/scripts/fabric_scripts/create_fabric_items.py index 9a718a425..9e32fb2b0 100644 --- a/infra/scripts/fabric_scripts/create_fabric_items.py +++ b/infra/scripts/fabric_scripts/create_fabric_items.py @@ -1,4 +1,3 @@ -from azure.identity import DefaultAzureCredential import base64 import json import requests @@ -8,7 +7,6 @@ import time -# credential = DefaultAzureCredential() from azure.identity import AzureCliCredential credential = AzureCliCredential() diff --git a/infra/scripts/index_scripts/create_search_index.py b/infra/scripts/index_scripts/create_search_index.py index b429a6456..b0a56f118 100644 --- a/infra/scripts/index_scripts/create_search_index.py +++ b/infra/scripts/index_scripts/create_search_index.py @@ -5,7 +5,8 @@ import time import pandas as pd -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from azure.identity import get_bearer_token_provider +from azure.identity import AzureCliCredential from azure.keyvault.secrets import SecretClient from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient @@ -44,9 +45,7 @@ "clienttranscripts/meeting_transcripts_metadata/transcripts_metadata.csv" ) -credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id -) +credential = AzureCliCredential() token_provider = get_bearer_token_provider( credential, "https://cognitiveservices.azure.com/.default" diff --git a/infra/scripts/index_scripts/create_sql_tables.py b/infra/scripts/index_scripts/create_sql_tables.py index be04dbc7a..73bcc33cd 100644 --- a/infra/scripts/index_scripts/create_sql_tables.py +++ b/infra/scripts/index_scripts/create_sql_tables.py @@ -7,15 +7,13 @@ import pandas as pd import pyodbc -from azure.identity import DefaultAzureCredential +from azure.identity import AzureCliCredential from azure.keyvault.secrets import SecretClient def get_secrets_from_kv(kv_name, secret_name): key_vault_name = kv_name # Set the name of the Azure Key Vault - credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id - ) + credential = AzureCliCredential() # Use Azure CLI Credential secret_client = SecretClient( vault_url=f"https://{key_vault_name}.vault.azure.net/", credential=credential ) # Create a secret client object using the credential and Key Vault name @@ -27,9 +25,7 @@ def get_secrets_from_kv(kv_name, secret_name): driver = "{ODBC Driver 18 for SQL Server}" -credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id -) +credential = AzureCliCredential() # Use Azure CLI Credential token_bytes = credential.get_token( "https://database.windows.net/.default" @@ -49,9 +45,7 @@ def get_secrets_from_kv(kv_name, secret_name): from azure.storage.filedatalake import DataLakeServiceClient account_name = get_secrets_from_kv(key_vault_name, "ADLS-ACCOUNT-NAME") -credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id -) +credential = AzureCliCredential() # Use Azure CLI Credential account_url = f"https://{account_name}.dfs.core.windows.net" diff --git a/infra/scripts/index_scripts/create_update_sql_dates.py b/infra/scripts/index_scripts/create_update_sql_dates.py index 47e1a6364..4b75c3328 100644 --- a/infra/scripts/index_scripts/create_update_sql_dates.py +++ b/infra/scripts/index_scripts/create_update_sql_dates.py @@ -6,15 +6,13 @@ import pandas as pd import pymssql -from azure.identity import DefaultAzureCredential +from azure.identity import AzureCliCredential from azure.keyvault.secrets import SecretClient def get_secrets_from_kv(kv_name, secret_name): key_vault_name = kv_name # Set the name of the Azure Key Vault - credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id - ) + credential = AzureCliCredential() # Use Azure CLI Credential secret_client = SecretClient( vault_url=f"https://{key_vault_name}.vault.azure.net/", credential=credential ) # Create a secret client object using the credential and Key Vault name @@ -32,9 +30,7 @@ def get_secrets_from_kv(kv_name, secret_name): from azure.storage.filedatalake import DataLakeServiceClient account_name = get_secrets_from_kv(key_vault_name, "ADLS-ACCOUNT-NAME") -credential = DefaultAzureCredential( - managed_identity_client_id=managed_identity_client_id -) +credential = AzureCliCredential() # Use Azure CLI Credential account_url = f"https://{account_name}.dfs.core.windows.net" diff --git a/src/App/.env.sample b/src/App/.env.sample index 0f69d5442..53ae69bf2 100644 --- a/src/App/.env.sample +++ b/src/App/.env.sample @@ -67,4 +67,5 @@ AZURE_SQL_SYSTEM_PROMPT="Generate a valid T-SQL query to find {query} for tables # Misc APPINSIGHTS_INSTRUMENTATIONKEY= AUTH_ENABLED="false" -USE_INTERNAL_STREAM="True" \ No newline at end of file +USE_INTERNAL_STREAM="True" +APP_ENV="dev" \ No newline at end of file diff --git a/src/App/WebApp.Dockerfile b/src/App/WebApp.Dockerfile index 48bcd5ff5..82a362152 100644 --- a/src/App/WebApp.Dockerfile +++ b/src/App/WebApp.Dockerfile @@ -36,4 +36,4 @@ COPY --from=frontend /home/node/app/static /usr/src/app/static/ WORKDIR /usr/src/app EXPOSE 80 -CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80", "--workers", "4", "--log-level", "info", "--access-log"] +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80", "--workers", "1", "--log-level", "info", "--access-log"] diff --git a/src/App/app.py b/src/App/app.py index 901494b2b..3eaa8b7c1 100644 --- a/src/App/app.py +++ b/src/App/app.py @@ -6,7 +6,8 @@ import uuid from types import SimpleNamespace -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from azure.identity import get_bearer_token_provider +from backend.helpers.azure_credential_utils import get_azure_credential from azure.monitor.opentelemetry import configure_azure_monitor # from quart.sessions import SecureCookieSessionInterface @@ -90,6 +91,7 @@ async def shutdown(): await AgentFactory.delete_all_agent_instance() app.wealth_advisor_agent = None app.search_agent = None + app.sql_agent = None logging.info("Agents cleaned up during application shutdown") # app.secret_key = secrets.token_hex(16) @@ -185,7 +187,7 @@ def init_openai_client(use_data=SHOULD_USE_DATA): if not aoai_api_key: logging.debug("No AZURE_OPENAI_KEY found, using Azure AD auth") ad_token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + get_azure_credential(), "https://cognitiveservices.azure.com/.default" ) # Deployment @@ -233,7 +235,7 @@ def init_cosmosdb_client(): ) if not config.AZURE_COSMOSDB_ACCOUNT_KEY: - credential = DefaultAzureCredential() + credential = get_azure_credential() else: credential = config.AZURE_COSMOSDB_ACCOUNT_KEY diff --git a/src/App/backend/agents/agent_factory.py b/src/App/backend/agents/agent_factory.py index df81a2caf..6bae33808 100644 --- a/src/App/backend/agents/agent_factory.py +++ b/src/App/backend/agents/agent_factory.py @@ -10,8 +10,8 @@ from typing import Optional from azure.ai.projects import AIProjectClient -from azure.identity import DefaultAzureCredential as DefaultAzureCredentialSync -from azure.identity.aio import DefaultAzureCredential +from backend.helpers.azure_credential_utils import get_azure_credential +from backend.helpers.azure_credential_utils import get_azure_credential_async from semantic_kernel.agents import AzureAIAgent, AzureAIAgentSettings from backend.common.config import config @@ -35,7 +35,7 @@ async def get_wealth_advisor_agent(cls): async with cls._lock: if cls._wealth_advisor_agent is None: ai_agent_settings = AzureAIAgentSettings() - creds = DefaultAzureCredential() + creds = await get_azure_credential_async() client = AzureAIAgent.create_client( credential=creds, endpoint=ai_agent_settings.endpoint ) @@ -76,7 +76,7 @@ async def get_search_agent(cls): project_client = AIProjectClient( endpoint=config.AI_PROJECT_ENDPOINT, - credential=DefaultAzureCredentialSync(), + credential=get_azure_credential(), api_version="2025-05-01", ) @@ -137,7 +137,7 @@ async def get_sql_agent(cls) -> dict: project_client = AIProjectClient( endpoint=config.AI_PROJECT_ENDPOINT, - credential=DefaultAzureCredentialSync(), + credential=get_azure_credential(), api_version="2025-05-01", ) diff --git a/src/App/backend/helpers/azure_credential_utils.py b/src/App/backend/helpers/azure_credential_utils.py new file mode 100644 index 000000000..646efb444 --- /dev/null +++ b/src/App/backend/helpers/azure_credential_utils.py @@ -0,0 +1,41 @@ +import os +from azure.identity import ManagedIdentityCredential, DefaultAzureCredential +from azure.identity.aio import ManagedIdentityCredential as AioManagedIdentityCredential, DefaultAzureCredential as AioDefaultAzureCredential + + +async def get_azure_credential_async(client_id=None): + """ + Returns an Azure credential asynchronously based on the application environment. + + If the environment is 'dev', it uses AioDefaultAzureCredential. + Otherwise, it uses AioManagedIdentityCredential. + + Args: + client_id (str, optional): The client ID for the Managed Identity Credential. + + Returns: + Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential. + """ + if os.getenv("APP_ENV", "prod").lower() == 'dev': + return AioDefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development + else: + return AioManagedIdentityCredential(client_id=client_id) + + +def get_azure_credential(client_id=None): + """ + Returns an Azure credential based on the application environment. + + If the environment is 'dev', it uses DefaultAzureCredential. + Otherwise, it uses ManagedIdentityCredential. + + Args: + client_id (str, optional): The client ID for the Managed Identity Credential. + + Returns: + Credential object: Either DefaultAzureCredential or ManagedIdentityCredential. + """ + if os.getenv("APP_ENV", "prod").lower() == 'dev': + return DefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development + else: + return ManagedIdentityCredential(client_id=client_id) diff --git a/src/App/backend/plugins/chat_with_data_plugin.py b/src/App/backend/plugins/chat_with_data_plugin.py index f421af7ef..86b064aac 100644 --- a/src/App/backend/plugins/chat_with_data_plugin.py +++ b/src/App/backend/plugins/chat_with_data_plugin.py @@ -9,7 +9,8 @@ MessageRole, ) from azure.ai.projects import AIProjectClient -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from azure.identity import get_bearer_token_provider +from backend.helpers.azure_credential_utils import get_azure_credential from semantic_kernel.functions.kernel_function_decorator import kernel_function from backend.common.config import config @@ -202,7 +203,7 @@ async def get_answers_from_calltranscripts( def get_openai_client(self): token_provider = get_bearer_token_provider( - DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + get_azure_credential(), "https://cognitiveservices.azure.com/.default" ) openai_client = openai.AzureOpenAI( azure_endpoint=config.AZURE_OPENAI_ENDPOINT, @@ -213,7 +214,7 @@ def get_openai_client(self): def get_project_openai_client(self): project = AIProjectClient( - endpoint=config.AI_PROJECT_ENDPOINT, credential=DefaultAzureCredential() + endpoint=config.AI_PROJECT_ENDPOINT, credential=get_azure_credential() ) openai_client = project.inference.get_azure_openai_client( api_version=config.AZURE_OPENAI_PREVIEW_API_VERSION diff --git a/src/App/backend/services/sqldb_service.py b/src/App/backend/services/sqldb_service.py index be1c7b358..063c08bd1 100644 --- a/src/App/backend/services/sqldb_service.py +++ b/src/App/backend/services/sqldb_service.py @@ -3,11 +3,13 @@ import struct import pyodbc -from azure.identity import DefaultAzureCredential +from backend.helpers.azure_credential_utils import get_azure_credential from dotenv import load_dotenv from backend.common.config import config +import time + load_dotenv() driver = config.ODBC_DRIVER @@ -33,33 +35,47 @@ def dict_cursor(cursor): def get_connection(): - try: - credential = DefaultAzureCredential(managed_identity_client_id=mid_id) + max_retries = 5 + retry_delay = 2 - token_bytes = credential.get_token( - "https://database.windows.net/.default" - ).token.encode("utf-16-LE") - token_struct = struct.pack( - f" str: diff --git a/src/App/frontend/src/api/api.ts b/src/App/frontend/src/api/api.ts index 1cd6442c3..6245ba819 100644 --- a/src/App/frontend/src/api/api.ts +++ b/src/App/frontend/src/api/api.ts @@ -42,20 +42,33 @@ export const getpbi = async (): Promise => { return ''; } +const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + export const getUsers = async (): Promise => { - try { - const response = await fetch('/api/users'); - if (!response.ok) { - throw new Error(`Failed to fetch users: ${response.statusText}`); + const maxRetries = 1; + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + const response = await fetch('/api/users', { + signal: AbortSignal.timeout(60000) + }); + if (!response.ok) { + throw new Error(`Failed to fetch users: ${response.statusText}`); + } + const data: User[] = await response.json(); + console.log('Fetched users:', data); + return data; + } catch (error) { + if (attempt < maxRetries && + error instanceof Error) { + console.warn(`Retrying fetch users... (retry ${attempt + 1}/${maxRetries})`); + await sleep(5000); // Simple 5 second delay + } else { + console.error('Error fetching users:', error); + return []; + } } - const data: User[] = await response.json(); - console.log('Fetched users:', data); - return data; - } catch (error) { - console.error('Error fetching users:', error); - return []; - // throw error; } + return []; }; // export const fetchChatHistoryInit = async (): Promise => { diff --git a/src/App/tests/backend/agents/test_agent_factory.py b/src/App/tests/backend/agents/test_agent_factory.py index 775cdab4c..676bba7be 100644 --- a/src/App/tests/backend/agents/test_agent_factory.py +++ b/src/App/tests/backend/agents/test_agent_factory.py @@ -21,7 +21,7 @@ def reset_singleton(self): @pytest.mark.asyncio @patch("backend.agents.agent_factory.AzureAIAgent") - @patch("backend.agents.agent_factory.DefaultAzureCredential") + @patch("backend.agents.agent_factory.get_azure_credential_async") @patch("backend.agents.agent_factory.AzureAIAgentSettings") @patch("backend.agents.agent_factory.ChatWithDataPlugin") async def test_get_wealth_advisor_agent_creates_agent_when_none_exists( @@ -74,7 +74,7 @@ async def test_get_wealth_advisor_agent_returns_existing_agent( @pytest.mark.asyncio @patch("backend.agents.agent_factory.config") @patch("backend.agents.agent_factory.AIProjectClient") - @patch("backend.agents.agent_factory.DefaultAzureCredentialSync") + @patch("backend.agents.agent_factory.get_azure_credential") async def test_get_search_agent_creates_agent_when_none_exists( self, mock_credential_sync, mock_ai_project_client, mock_config, reset_singleton ): @@ -112,7 +112,7 @@ async def test_get_search_agent_creates_agent_when_none_exists( @pytest.mark.asyncio @patch("backend.agents.agent_factory.config") @patch("backend.agents.agent_factory.AIProjectClient") - @patch("backend.agents.agent_factory.DefaultAzureCredentialSync") + @patch("backend.agents.agent_factory.get_azure_credential") async def test_get_search_agent_with_default_instructions( self, mock_credential_sync, mock_ai_project_client, mock_config, reset_singleton ): @@ -174,7 +174,7 @@ async def test_multiple_calls_return_same_wealth_advisor_instance( ) mock_agent_class.return_value = mock_agent_instance - with patch("backend.agents.agent_factory.DefaultAzureCredential"): + with patch("backend.agents.agent_factory.get_azure_credential_async"): with patch("backend.agents.agent_factory.AzureAIAgentSettings"): with patch("backend.agents.agent_factory.ChatWithDataPlugin"): # Act @@ -193,7 +193,7 @@ async def test_multiple_calls_return_same_search_agent_instance( with patch( "backend.agents.agent_factory.AIProjectClient" ) as mock_ai_project_client: - with patch("backend.agents.agent_factory.DefaultAzureCredentialSync"): + with patch("backend.agents.agent_factory.get_azure_credential"): mock_config.CALL_TRANSCRIPT_SYSTEM_PROMPT = "Test instructions" mock_config.AI_PROJECT_ENDPOINT = "https://test.endpoint.com" mock_config.AZURE_OPENAI_MODEL = "test-model" diff --git a/src/App/tests/backend/helpers/test_azure_credential_utils.py b/src/App/tests/backend/helpers/test_azure_credential_utils.py new file mode 100644 index 000000000..4eb79c332 --- /dev/null +++ b/src/App/tests/backend/helpers/test_azure_credential_utils.py @@ -0,0 +1,81 @@ +import pytest +import sys +import os +from unittest.mock import patch, MagicMock +import backend.helpers.azure_credential_utils as azure_credential_utils + +# Ensure src/backend is on the Python path for imports +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +# Synchronous tests + + +@patch("backend.helpers.azure_credential_utils.os.getenv") +@patch("backend.helpers.azure_credential_utils.DefaultAzureCredential") +@patch("backend.helpers.azure_credential_utils.ManagedIdentityCredential") +def test_get_azure_credential_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv): + """Test get_azure_credential in dev environment.""" + mock_getenv.return_value = "dev" + mock_default_credential = MagicMock() + mock_default_azure_credential.return_value = mock_default_credential + + credential = azure_credential_utils.get_azure_credential() + + mock_getenv.assert_called_once_with("APP_ENV", "prod") + mock_default_azure_credential.assert_called_once() + mock_managed_identity_credential.assert_not_called() + assert credential == mock_default_credential + + +@patch("backend.helpers.azure_credential_utils.os.getenv") +@patch("backend.helpers.azure_credential_utils.DefaultAzureCredential") +@patch("backend.helpers.azure_credential_utils.ManagedIdentityCredential") +def test_get_azure_credential_non_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv): + """Test get_azure_credential in non-dev environment.""" + mock_getenv.return_value = "prod" + mock_managed_credential = MagicMock() + mock_managed_identity_credential.return_value = mock_managed_credential + credential = azure_credential_utils.get_azure_credential(client_id="test-client-id") + + mock_getenv.assert_called_once_with("APP_ENV", "prod") + mock_managed_identity_credential.assert_called_once_with(client_id="test-client-id") + mock_default_azure_credential.assert_not_called() + assert credential == mock_managed_credential + +# Asynchronous tests + + +@pytest.mark.asyncio +@patch("backend.helpers.azure_credential_utils.os.getenv") +@patch("backend.helpers.azure_credential_utils.AioDefaultAzureCredential") +@patch("backend.helpers.azure_credential_utils.AioManagedIdentityCredential") +async def test_get_azure_credential_async_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv): + """Test get_azure_credential_async in dev environment.""" + mock_getenv.return_value = "dev" + mock_aio_default_credential = MagicMock() + mock_aio_default_azure_credential.return_value = mock_aio_default_credential + + credential = await azure_credential_utils.get_azure_credential_async() + + mock_getenv.assert_called_once_with("APP_ENV", "prod") + mock_aio_default_azure_credential.assert_called_once() + mock_aio_managed_identity_credential.assert_not_called() + assert credential == mock_aio_default_credential + + +@pytest.mark.asyncio +@patch("backend.helpers.azure_credential_utils.os.getenv") +@patch("backend.helpers.azure_credential_utils.AioDefaultAzureCredential") +@patch("backend.helpers.azure_credential_utils.AioManagedIdentityCredential") +async def test_get_azure_credential_async_non_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv): + """Test get_azure_credential_async in non-dev environment.""" + mock_getenv.return_value = "prod" + mock_aio_managed_credential = MagicMock() + mock_aio_managed_identity_credential.return_value = mock_aio_managed_credential + + credential = await azure_credential_utils.get_azure_credential_async(client_id="test-client-id") + + mock_getenv.assert_called_once_with("APP_ENV", "prod") + mock_aio_managed_identity_credential.assert_called_once_with(client_id="test-client-id") + mock_aio_default_azure_credential.assert_not_called() + assert credential == mock_aio_managed_credential diff --git a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py index 684c947ae..cc1e91912 100644 --- a/src/App/tests/backend/plugins/test_chat_with_data_plugin.py +++ b/src/App/tests/backend/plugins/test_chat_with_data_plugin.py @@ -15,7 +15,7 @@ def setup_method(self): @patch("backend.plugins.chat_with_data_plugin.config") @patch("backend.plugins.chat_with_data_plugin.openai.AzureOpenAI") @patch("backend.plugins.chat_with_data_plugin.get_bearer_token_provider") - @patch("backend.plugins.chat_with_data_plugin.DefaultAzureCredential") + @patch("backend.plugins.chat_with_data_plugin.get_azure_credential") def test_get_openai_client_success( self, mock_default_credential, @@ -50,7 +50,7 @@ def test_get_openai_client_success( @patch("backend.plugins.chat_with_data_plugin.config") @patch("backend.plugins.chat_with_data_plugin.AIProjectClient") - @patch("backend.plugins.chat_with_data_plugin.DefaultAzureCredential") + @patch("backend.plugins.chat_with_data_plugin.get_azure_credential") def test_get_project_openai_client_success( self, mock_default_credential, mock_ai_project_client, mock_config ): diff --git a/src/App/tests/backend/services/test_sqldb_service.py b/src/App/tests/backend/services/test_sqldb_service.py index 3a3745c3f..fc61325e8 100644 --- a/src/App/tests/backend/services/test_sqldb_service.py +++ b/src/App/tests/backend/services/test_sqldb_service.py @@ -16,10 +16,10 @@ @patch("backend.services.sqldb_service.pyodbc.connect") # Mock pyodbc.connect @patch( - "backend.services.sqldb_service.DefaultAzureCredential" -) # Mock DefaultAzureCredential + "backend.services.sqldb_service.get_azure_credential" +) # Mock AzureCliCredential def test_get_connection(mock_credential_class, mock_connect): - # Mock the DefaultAzureCredential and get_token method + # Mock the AzureCliCredential and get_token method mock_credential = MagicMock() mock_credential_class.return_value = mock_credential mock_token = MagicMock() @@ -32,9 +32,9 @@ def test_get_connection(mock_credential_class, mock_connect): # Call the function conn = sql_db.get_connection() - # Assert that DefaultAzureCredential and get_token were called correctly + # Assert that AzureCliCredential and get_token were called correctly mock_credential_class.assert_called_once_with( - managed_identity_client_id=sql_db.mid_id + client_id=sql_db.mid_id ) mock_credential.get_token.assert_called_once_with( "https://database.windows.net/.default" @@ -59,10 +59,10 @@ def test_get_connection(mock_credential_class, mock_connect): @patch("backend.services.sqldb_service.pyodbc.connect") # Mock pyodbc.connect @patch( - "backend.services.sqldb_service.DefaultAzureCredential" -) # Mock DefaultAzureCredential + "backend.services.sqldb_service.get_azure_credential" +) # Mock AzureCliCredential def test_get_connection_token_failure(mock_credential_class, mock_connect): - # Mock the DefaultAzureCredential and get_token method + # Mock the AzureCliCredential and get_token method mock_credential = MagicMock() mock_credential_class.return_value = mock_credential mock_token = MagicMock()