Skip to content

Commit 902a656

Browse files
Merge pull request microsoft#350 from microsoft/rc-psl-sfi-dev
feat: replacing DefaultAzureCredential with ManagedIdentityCredential
2 parents ec23bc8 + 0ba8780 commit 902a656

File tree

10 files changed

+139
-35
lines changed

10 files changed

+139
-35
lines changed

infra/main.bicep

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,10 @@ module containerApp 'br/public:avm/res/app/container-app:0.14.2' = if (container
10341034
name: 'AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME'
10351035
value: aiFoundryAiServicesModelDeployment.name
10361036
}
1037+
{
1038+
name: 'APP_ENV'
1039+
value: 'Prod'
1040+
}
10371041
]
10381042
}
10391043
]
@@ -1087,6 +1091,7 @@ module webSite 'br/public:avm/res/web/site:0.15.1' = if (webSiteEnabled) {
10871091
WEBSITES_CONTAINER_START_TIME_LIMIT: '1800' // 30 minutes, adjust as needed
10881092
BACKEND_API_URL: 'https://${containerApp.outputs.fqdn}'
10891093
AUTH_ENABLED: 'false'
1094+
APP_ENV: 'Prod'
10901095
}
10911096
}
10921097
}

src/backend/.env.sample

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ AZURE_AI_MODEL_DEPLOYMENT_NAME=gpt-4o
1616
APPLICATIONINSIGHTS_CONNECTION_STRING=
1717
AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME=gpt-4o
1818
AZURE_AI_AGENT_ENDPOINT=
19+
APP_ENV="dev"
1920

2021
BACKEND_API_URL=http://localhost:8000
2122
FRONTEND_SITE_NAME=http://127.0.0.1:3000

src/backend/app_config.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from azure.ai.projects.aio import AIProjectClient
77
from azure.cosmos.aio import CosmosClient
8-
from azure.identity import DefaultAzureCredential
8+
from helpers.azure_credential_utils import get_azure_credential
99
from dotenv import load_dotenv
1010
from semantic_kernel.kernel import Kernel
1111

@@ -106,23 +106,6 @@ def _get_bool(self, name: str) -> bool:
106106
"""
107107
return name in os.environ and os.environ[name].lower() in ["true", "1"]
108108

109-
def get_azure_credentials(self):
110-
"""Get Azure credentials using DefaultAzureCredential.
111-
112-
Returns:
113-
DefaultAzureCredential instance for Azure authentication
114-
"""
115-
# Cache the credentials object
116-
if self._azure_credentials is not None:
117-
return self._azure_credentials
118-
119-
try:
120-
self._azure_credentials = DefaultAzureCredential()
121-
return self._azure_credentials
122-
except Exception as exc:
123-
logging.warning("Failed to create DefaultAzureCredential: %s", exc)
124-
return None
125-
126109
def get_cosmos_database_client(self):
127110
"""Get a Cosmos DB client for the configured database.
128111
@@ -132,7 +115,7 @@ def get_cosmos_database_client(self):
132115
try:
133116
if self._cosmos_client is None:
134117
self._cosmos_client = CosmosClient(
135-
self.COSMOSDB_ENDPOINT, credential=self.get_azure_credentials()
118+
self.COSMOSDB_ENDPOINT, credential=get_azure_credential()
136119
)
137120

138121
if self._cosmos_database is None:
@@ -169,10 +152,10 @@ def get_ai_project_client(self):
169152
return self._ai_project_client
170153

171154
try:
172-
credential = self.get_azure_credentials()
155+
credential = get_azure_credential()
173156
if credential is None:
174157
raise RuntimeError(
175-
"Unable to acquire Azure credentials; ensure DefaultAzureCredential is configured"
158+
"Unable to acquire Azure credentials; ensure Managed Identity is configured"
176159
)
177160

178161
endpoint = self.AZURE_AI_AGENT_ENDPOINT

src/backend/config_kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Import AppConfig from app_config
22
from app_config import config
3+
from helpers.azure_credential_utils import get_azure_credential
34

45

56
# This file is left as a lightweight wrapper around AppConfig for backward compatibility
@@ -31,7 +32,7 @@ class Config:
3132
@staticmethod
3233
def GetAzureCredentials():
3334
"""Get Azure credentials using the AppConfig implementation."""
34-
return config.get_azure_credentials()
35+
return get_azure_credential()
3536

3637
@staticmethod
3738
def GetCosmosDatabaseClient():

src/backend/context/cosmos_memory_kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from azure.cosmos.partition_key import PartitionKey
1212
from azure.cosmos.aio import CosmosClient
13-
from azure.identity import DefaultAzureCredential
13+
from helpers.azure_credential_utils import get_azure_credential
1414
from semantic_kernel.memory.memory_record import MemoryRecord
1515
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
1616
from semantic_kernel.contents import ChatMessageContent, ChatHistory, AuthorRole
@@ -73,7 +73,7 @@ async def initialize(self):
7373
if not self._database:
7474
# Create Cosmos client
7575
cosmos_client = CosmosClient(
76-
self._cosmos_endpoint, credential=DefaultAzureCredential()
76+
self._cosmos_endpoint, credential=get_azure_credential()
7777
)
7878
self._database = cosmos_client.get_database_client(
7979
self._cosmos_database
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
3+
from azure.identity.aio import ManagedIdentityCredential as AioManagedIdentityCredential, DefaultAzureCredential as AioDefaultAzureCredential
4+
5+
6+
async def get_azure_credential_async(client_id=None):
7+
"""
8+
Returns an Azure credential asynchronously based on the application environment.
9+
10+
If the environment is 'dev', it uses AioDefaultAzureCredential.
11+
Otherwise, it uses AioManagedIdentityCredential.
12+
13+
Args:
14+
client_id (str, optional): The client ID for the Managed Identity Credential.
15+
16+
Returns:
17+
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
18+
"""
19+
if os.getenv("APP_ENV", "prod").lower() == 'dev':
20+
return AioDefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
21+
else:
22+
return AioManagedIdentityCredential(client_id=client_id)
23+
24+
25+
def get_azure_credential(client_id=None):
26+
"""
27+
Returns an Azure credential based on the application environment.
28+
29+
If the environment is 'dev', it uses DefaultAzureCredential.
30+
Otherwise, it uses ManagedIdentityCredential.
31+
32+
Args:
33+
client_id (str, optional): The client ID for the Managed Identity Credential.
34+
35+
Returns:
36+
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
37+
"""
38+
if os.getenv("APP_ENV", "prod").lower() == 'dev':
39+
return DefaultAzureCredential() # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
40+
else:
41+
return ManagedIdentityCredential(client_id=client_id)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
import sys
3+
import os
4+
from unittest.mock import patch, MagicMock
5+
6+
# Ensure src/backend is on the Python path for imports
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
8+
9+
import helpers.azure_credential_utils as azure_credential_utils
10+
11+
# Synchronous tests
12+
13+
@patch("helpers.azure_credential_utils.os.getenv")
14+
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
15+
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
16+
def test_get_azure_credential_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
17+
"""Test get_azure_credential in dev environment."""
18+
mock_getenv.return_value = "dev"
19+
mock_default_credential = MagicMock()
20+
mock_default_azure_credential.return_value = mock_default_credential
21+
22+
credential = azure_credential_utils.get_azure_credential()
23+
24+
mock_getenv.assert_called_once_with("APP_ENV", "prod")
25+
mock_default_azure_credential.assert_called_once()
26+
mock_managed_identity_credential.assert_not_called()
27+
assert credential == mock_default_credential
28+
29+
@patch("helpers.azure_credential_utils.os.getenv")
30+
@patch("helpers.azure_credential_utils.DefaultAzureCredential")
31+
@patch("helpers.azure_credential_utils.ManagedIdentityCredential")
32+
def test_get_azure_credential_non_dev_env(mock_managed_identity_credential, mock_default_azure_credential, mock_getenv):
33+
"""Test get_azure_credential in non-dev environment."""
34+
mock_getenv.return_value = "prod"
35+
mock_managed_credential = MagicMock()
36+
mock_managed_identity_credential.return_value = mock_managed_credential
37+
credential = azure_credential_utils.get_azure_credential(client_id="test-client-id")
38+
39+
mock_getenv.assert_called_once_with("APP_ENV", "prod")
40+
mock_managed_identity_credential.assert_called_once_with(client_id="test-client-id")
41+
mock_default_azure_credential.assert_not_called()
42+
assert credential == mock_managed_credential
43+
44+
# Asynchronous tests
45+
46+
@pytest.mark.asyncio
47+
@patch("helpers.azure_credential_utils.os.getenv")
48+
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
49+
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
50+
async def test_get_azure_credential_async_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
51+
"""Test get_azure_credential_async in dev environment."""
52+
mock_getenv.return_value = "dev"
53+
mock_aio_default_credential = MagicMock()
54+
mock_aio_default_azure_credential.return_value = mock_aio_default_credential
55+
56+
credential = await azure_credential_utils.get_azure_credential_async()
57+
58+
mock_getenv.assert_called_once_with("APP_ENV", "prod")
59+
mock_aio_default_azure_credential.assert_called_once()
60+
mock_aio_managed_identity_credential.assert_not_called()
61+
assert credential == mock_aio_default_credential
62+
63+
@pytest.mark.asyncio
64+
@patch("helpers.azure_credential_utils.os.getenv")
65+
@patch("helpers.azure_credential_utils.AioDefaultAzureCredential")
66+
@patch("helpers.azure_credential_utils.AioManagedIdentityCredential")
67+
async def test_get_azure_credential_async_non_dev_env(mock_aio_managed_identity_credential, mock_aio_default_azure_credential, mock_getenv):
68+
"""Test get_azure_credential_async in non-dev environment."""
69+
mock_getenv.return_value = "prod"
70+
mock_aio_managed_credential = MagicMock()
71+
mock_aio_managed_identity_credential.return_value = mock_aio_managed_credential
72+
73+
credential = await azure_credential_utils.get_azure_credential_async(client_id="test-client-id")
74+
75+
mock_getenv.assert_called_once_with("APP_ENV", "prod")
76+
mock_aio_managed_identity_credential.assert_called_once_with(client_id="test-client-id")
77+
mock_aio_default_azure_credential.assert_not_called()
78+
assert credential == mock_aio_managed_credential

src/backend/tests/test_config.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,3 @@ def test_get_bool_config():
5252
assert GetBoolConfig("FEATURE_ENABLED") is True
5353
with patch.dict("os.environ", {"FEATURE_ENABLED": "0"}):
5454
assert GetBoolConfig("FEATURE_ENABLED") is False
55-
56-
57-
@patch("config.DefaultAzureCredential")
58-
def test_get_azure_credentials_with_env_vars(mock_default_cred):
59-
"""Test Config.GetAzureCredentials with explicit credentials."""
60-
with patch.dict(os.environ, MOCK_ENV_VARS):
61-
creds = Config.GetAzureCredentials()
62-
assert creds is not None

src/backend/utils_kernel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111

1212
# Import AppConfig from app_config
1313
from app_config import config
14-
from azure.identity import DefaultAzureCredential
1514
from context.cosmos_memory_kernel import CosmosMemoryContext
1615

16+
# Import the credential utility
17+
from helpers.azure_credential_utils import get_azure_credential
18+
1719
# Import agent factory and the new AppConfig
1820
from kernel_agents.agent_factory import AgentFactory
1921
from kernel_agents.group_chat_manager import GroupChatManager
@@ -169,8 +171,8 @@ async def rai_success(description: str) -> bool:
169171
True if it passes, False otherwise
170172
"""
171173
try:
172-
# Use DefaultAzureCredential for authentication to Azure OpenAI
173-
credential = DefaultAzureCredential()
174+
# Use managed identity for authentication to Azure OpenAI
175+
credential = get_azure_credential()
174176
access_token = credential.get_token(
175177
"https://cognitiveservices.azure.com/.default"
176178
).token

src/frontend/.env.sample

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
API_URL=http://localhost:8000
44
ENABLE_AUTH=false
5+
APP_ENV="dev"
56
# VITE_APP_MSAL_AUTH_CLIENTID=""
67
# VITE_APP_MSAL_AUTH_AUTHORITY=""
78
# VITE_APP_MSAL_REDIRECT_URL="/"

0 commit comments

Comments
 (0)