Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,10 @@ module containerApp 'br/public:avm/res/app/container-app:0.14.2' = if (container
name: 'AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME'
value: aiFoundryAiServicesModelDeployment.name
}
{
name: 'APP_ENV'
value: 'Prod'
}
]
}
]
Expand Down Expand Up @@ -1087,6 +1091,7 @@ module webSite 'br/public:avm/res/web/site:0.15.1' = if (webSiteEnabled) {
WEBSITES_CONTAINER_START_TIME_LIMIT: '1800' // 30 minutes, adjust as needed
BACKEND_API_URL: 'https://${containerApp.outputs.fqdn}'
AUTH_ENABLED: 'false'
APP_ENV: 'Prod'
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/backend/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ AZURE_AI_MODEL_DEPLOYMENT_NAME=gpt-4o
APPLICATIONINSIGHTS_CONNECTION_STRING=
AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME=gpt-4o
AZURE_AI_AGENT_ENDPOINT=
APP_ENV="dev"

BACKEND_API_URL=http://localhost:8000
FRONTEND_SITE_NAME=http://127.0.0.1:3000
25 changes: 4 additions & 21 deletions src/backend/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from azure.ai.projects.aio import AIProjectClient
from azure.cosmos.aio import CosmosClient
from azure.identity import DefaultAzureCredential
from helpers.azure_credential_utils import get_azure_credential
from dotenv import load_dotenv
from semantic_kernel.kernel import Kernel

Expand Down Expand Up @@ -106,23 +106,6 @@ def _get_bool(self, name: str) -> bool:
"""
return name in os.environ and os.environ[name].lower() in ["true", "1"]

def get_azure_credentials(self):
"""Get Azure credentials using DefaultAzureCredential.

Returns:
DefaultAzureCredential instance for Azure authentication
"""
# Cache the credentials object
if self._azure_credentials is not None:
return self._azure_credentials

try:
self._azure_credentials = DefaultAzureCredential()
return self._azure_credentials
except Exception as exc:
logging.warning("Failed to create DefaultAzureCredential: %s", exc)
return None

def get_cosmos_database_client(self):
"""Get a Cosmos DB client for the configured database.

Expand All @@ -132,7 +115,7 @@ def get_cosmos_database_client(self):
try:
if self._cosmos_client is None:
self._cosmos_client = CosmosClient(
self.COSMOSDB_ENDPOINT, credential=self.get_azure_credentials()
self.COSMOSDB_ENDPOINT, credential=get_azure_credential()
)

if self._cosmos_database is None:
Expand Down Expand Up @@ -169,10 +152,10 @@ def get_ai_project_client(self):
return self._ai_project_client

try:
credential = self.get_azure_credentials()
credential = get_azure_credential()
if credential is None:
raise RuntimeError(
"Unable to acquire Azure credentials; ensure DefaultAzureCredential is configured"
"Unable to acquire Azure credentials; ensure Managed Identity is configured"
)

endpoint = self.AZURE_AI_AGENT_ENDPOINT
Expand Down
3 changes: 2 additions & 1 deletion src/backend/config_kernel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Import AppConfig from app_config
from app_config import config
from helpers.azure_credential_utils import get_azure_credential


# This file is left as a lightweight wrapper around AppConfig for backward compatibility
Expand Down Expand Up @@ -31,7 +32,7 @@ class Config:
@staticmethod
def GetAzureCredentials():
"""Get Azure credentials using the AppConfig implementation."""
return config.get_azure_credentials()
return get_azure_credential()

@staticmethod
def GetCosmosDatabaseClient():
Expand Down
4 changes: 2 additions & 2 deletions src/backend/context/cosmos_memory_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from azure.cosmos.partition_key import PartitionKey
from azure.cosmos.aio import CosmosClient
from azure.identity import DefaultAzureCredential
from helpers.azure_credential_utils import get_azure_credential
from semantic_kernel.memory.memory_record import MemoryRecord
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
from semantic_kernel.contents import ChatMessageContent, ChatHistory, AuthorRole
Expand Down Expand Up @@ -73,7 +73,7 @@ async def initialize(self):
if not self._database:
# Create Cosmos client
cosmos_client = CosmosClient(
self._cosmos_endpoint, credential=DefaultAzureCredential()
self._cosmos_endpoint, credential=get_azure_credential()
)
self._database = cosmos_client.get_database_client(
self._cosmos_database
Expand Down
41 changes: 41 additions & 0 deletions src/backend/helpers/azure_credential_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
from azure.identity.aio import ManagedIdentityCredential as AioManagedIdentityCredential, DefaultAzureCredential as AioDefaultAzureCredential


async def get_azure_credential_async(client_id=None):
"""
Returns an Azure credential asynchronously based on the application environment.

If the environment is 'dev', it uses AioDefaultAzureCredential.
Otherwise, it uses AioManagedIdentityCredential.

Args:
client_id (str, optional): The client ID for the Managed Identity Credential.

Returns:
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == 'dev':
return AioDefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return AioManagedIdentityCredential(client_id=client_id)


def get_azure_credential(client_id=None):
"""
Returns an Azure credential based on the application environment.

If the environment is 'dev', it uses DefaultAzureCredential.
Otherwise, it uses ManagedIdentityCredential.

Args:
client_id (str, optional): The client ID for the Managed Identity Credential.

Returns:
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == 'dev':
return DefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return ManagedIdentityCredential(client_id=client_id)
78 changes: 78 additions & 0 deletions src/backend/tests/helpers/test_azure_credential_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
import sys
import os
from unittest.mock import patch, MagicMock

# Ensure src/backend is on the Python path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

import helpers.azure_credential_utils as azure_credential_utils

# Synchronous tests

@patch("helpers.azure_credential_utils.os.getenv")
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
def test_get_azure_credential_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
"""Test get_azure_credential in dev environment."""
mock_getenv.return_value = "dev"
mock_default_credential = MagicMock()
mock_default_azure_credential.return_value = mock_default_credential

credential = azure_credential_utils.get_azure_credential()

mock_getenv.assert_called_once_with("APP_ENV", "prod")
mock_default_azure_credential.assert_called_once()
mock_managed_identity_credential.assert_not_called()
assert credential == mock_default_credential

@patch("helpers.azure_credential_utils.os.getenv")
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
def test_get_azure_credential_non_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
"""Test get_azure_credential in non-dev environment."""
mock_getenv.return_value = "prod"
mock_managed_credential = MagicMock()
mock_managed_identity_credential.return_value = mock_managed_credential
credential = azure_credential_utils.get_azure_credential(client_id="test-client-id")

mock_getenv.assert_called_once_with("APP_ENV", "prod")
mock_managed_identity_credential.assert_called_once_with(client_id="test-client-id")
mock_default_azure_credential.assert_not_called()
assert credential == mock_managed_credential

# Asynchronous tests

@pytest.mark.asyncio
@patch("helpers.azure_credential_utils.os.getenv")
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
async def test_get_azure_credential_async_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
"""Test get_azure_credential_async in dev environment."""
mock_getenv.return_value = "dev"
mock_aio_default_credential = MagicMock()
mock_aio_default_azure_credential.return_value = mock_aio_default_credential

credential = await azure_credential_utils.get_azure_credential_async()

mock_getenv.assert_called_once_with("APP_ENV", "prod")
mock_aio_default_azure_credential.assert_called_once()
mock_aio_managed_identity_credential.assert_not_called()
assert credential == mock_aio_default_credential

@pytest.mark.asyncio
@patch("helpers.azure_credential_utils.os.getenv")
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
async def test_get_azure_credential_async_non_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
"""Test get_azure_credential_async in non-dev environment."""
mock_getenv.return_value = "prod"
mock_aio_managed_credential = MagicMock()
mock_aio_managed_identity_credential.return_value = mock_aio_managed_credential

credential = await azure_credential_utils.get_azure_credential_async(client_id="test-client-id")

mock_getenv.assert_called_once_with("APP_ENV", "prod")
mock_aio_managed_identity_credential.assert_called_once_with(client_id="test-client-id")
mock_aio_default_azure_credential.assert_not_called()
assert credential == mock_aio_managed_credential
8 changes: 0 additions & 8 deletions src/backend/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,3 @@ def test_get_bool_config():
assert GetBoolConfig("FEATURE_ENABLED") is True
with patch.dict("os.environ", {"FEATURE_ENABLED": "0"}):
assert GetBoolConfig("FEATURE_ENABLED") is False


@patch("config.DefaultAzureCredential")
def test_get_azure_credentials_with_env_vars(mock_default_cred):
"""Test Config.GetAzureCredentials with explicit credentials."""
with patch.dict(os.environ, MOCK_ENV_VARS):
creds = Config.GetAzureCredentials()
assert creds is not None
8 changes: 5 additions & 3 deletions src/backend/utils_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

# Import AppConfig from app_config
from app_config import config
from azure.identity import DefaultAzureCredential
from context.cosmos_memory_kernel import CosmosMemoryContext

# Import the credential utility
from helpers.azure_credential_utils import get_azure_credential

# Import agent factory and the new AppConfig
from kernel_agents.agent_factory import AgentFactory
from kernel_agents.group_chat_manager import GroupChatManager
Expand Down Expand Up @@ -169,8 +171,8 @@ async def rai_success(description: str) -> bool:
True if it passes, False otherwise
"""
try:
# Use DefaultAzureCredential for authentication to Azure OpenAI
credential = DefaultAzureCredential()
# Use managed identity for authentication to Azure OpenAI
credential = get_azure_credential()
access_token = credential.get_token(
"https://cognitiveservices.azure.com/.default"
).token
Expand Down
1 change: 1 addition & 0 deletions src/frontend/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

API_URL=http://localhost:8000
ENABLE_AUTH=false
APP_ENV="dev"
# VITE_APP_MSAL_AUTH_CLIENTID=""
# VITE_APP_MSAL_AUTH_AUTHORITY=""
# VITE_APP_MSAL_REDIRECT_URL="/"
Expand Down
Loading