Skip to content

Commit 6eafdfa

Browse files
fix: Replace DefaultAzureCredential with ManagedIdentityCredential for production-safe authentication (#1876)
Co-authored-by: Prajwal D C <[email protected]>
1 parent 2ff456e commit 6eafdfa

33 files changed

+316
-132
lines changed

.env.sample

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ AZURE_SPEECH_SERVICE_REGION=
6363
AZURE_AUTH_TYPE=keys
6464
USE_KEY_VAULT=true
6565
AZURE_KEY_VAULT_ENDPOINT=
66+
# Application environment (e.g., dev, prod)
67+
APP_ENV="dev"
6668
# Chat conversation type to decide between custom or byod (bring your own data) conversation type
6769
CONVERSATION_FLOW=
6870
# Chat History CosmosDB Integration Settings

code/backend/batch/utilities/chat_history/database_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ..helpers.env_helper import EnvHelper
33
from .cosmosdb import CosmosConversationClient
44
from .postgresdbservice import PostgresConversationClient
5-
from azure.identity import DefaultAzureCredential
5+
from ..helpers.azure_credential_utils import get_azure_credential
66
from ..helpers.config.database_type import DatabaseType
77

88

@@ -25,7 +25,7 @@ def get_conversation_client():
2525
f"https://{env_helper.AZURE_COSMOSDB_ACCOUNT}.documents.azure.com:443/"
2626
)
2727
credential = (
28-
DefaultAzureCredential()
28+
get_azure_credential()
2929
if not env_helper.AZURE_COSMOSDB_ACCOUNT_KEY
3030
else env_helper.AZURE_COSMOSDB_ACCOUNT_KEY
3131
)

code/backend/batch/utilities/chat_history/postgresdbservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import asyncpg
33
from datetime import datetime, timezone
4-
from azure.identity import DefaultAzureCredential
4+
from ..helpers.azure_credential_utils import get_azure_credential
55

66
from .database_client_base import DatabaseClientBase
77

@@ -21,7 +21,7 @@ def __init__(
2121

2222
async def connect(self):
2323
try:
24-
credential = DefaultAzureCredential()
24+
credential = get_azure_credential()
2525
token = credential.get_token(
2626
"https://ossrdbms-aad.database.windows.net/.default"
2727
).token

code/backend/batch/utilities/helpers/azure_blob_storage_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from azure.storage.queue import QueueClient, BinaryBase64EncodePolicy
1313
import chardet
1414
from .env_helper import EnvHelper
15-
from azure.identity import DefaultAzureCredential
15+
from .azure_credential_utils import get_azure_credential
1616

1717

1818
def connection_string(account_name: str, account_key: str):
@@ -25,7 +25,7 @@ def create_queue_client():
2525
return QueueClient(
2626
account_url=f"https://{env_helper.AZURE_BLOB_ACCOUNT_NAME}.queue.core.windows.net/",
2727
queue_name=env_helper.DOCUMENT_PROCESSING_QUEUE_NAME,
28-
credential=DefaultAzureCredential(),
28+
credential=get_azure_credential(),
2929
message_encode_policy=BinaryBase64EncodePolicy(),
3030
)
3131

@@ -56,7 +56,7 @@ def __init__(
5656
if self.auth_type == "rbac":
5757
self.account_key = None
5858
self.blob_service_client = BlobServiceClient(
59-
account_url=self.endpoint, credential=DefaultAzureCredential()
59+
account_url=self.endpoint, credential=get_azure_credential()
6060
)
6161
self.user_delegation_key = self.request_user_delegation_key(
6262
blob_service_client=self.blob_service_client

code/backend/batch/utilities/helpers/azure_computer_vision_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from urllib.parse import urljoin
3-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
3+
from azure.identity import get_bearer_token_provider
4+
from .azure_credential_utils import get_azure_credential
45

56
import requests
67
from requests import Response
@@ -56,7 +57,7 @@ def __make_request(self, path: str, body) -> Response:
5657
headers["Ocp-Apim-Subscription-Key"] = self.key
5758
else:
5859
token_provider = get_bearer_token_provider(
59-
DefaultAzureCredential(), self.__TOKEN_SCOPE
60+
get_azure_credential(), self.__TOKEN_SCOPE
6061
)
6162
headers["Authorization"] = "Bearer " + token_provider()
6263

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
3+
from azure.identity.aio import (
4+
ManagedIdentityCredential as AioManagedIdentityCredential,
5+
DefaultAzureCredential as AioDefaultAzureCredential,
6+
)
7+
8+
9+
async def get_azure_credential_async(client_id=None):
10+
"""
11+
Returns an Azure credential asynchronously based on the application environment.
12+
13+
If the environment is 'dev', it uses AioDefaultAzureCredential.
14+
Otherwise, it uses AioManagedIdentityCredential.
15+
16+
Args:
17+
client_id (str, optional): The client ID for the Managed Identity Credential.
18+
19+
Returns:
20+
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
21+
"""
22+
if os.getenv("APP_ENV", "prod").lower() == "dev":
23+
return (
24+
AioDefaultAzureCredential()
25+
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
26+
else:
27+
return AioManagedIdentityCredential(client_id=client_id)
28+
29+
30+
def get_azure_credential(client_id=None):
31+
"""
32+
Returns an Azure credential based on the application environment.
33+
34+
If the environment is 'dev', it uses DefaultAzureCredential.
35+
Otherwise, it uses ManagedIdentityCredential.
36+
37+
Args:
38+
client_id (str, optional): The client ID for the Managed Identity Credential.
39+
40+
Returns:
41+
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
42+
"""
43+
if os.getenv("APP_ENV", "prod").lower() == "dev":
44+
return (
45+
DefaultAzureCredential()
46+
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
47+
else:
48+
return ManagedIdentityCredential(client_id=client_id)

code/backend/batch/utilities/helpers/azure_form_recognizer_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from azure.core.credentials import AzureKeyCredential
33
from azure.ai.formrecognizer import DocumentAnalysisClient
4-
from azure.identity import DefaultAzureCredential
4+
from .azure_credential_utils import get_azure_credential
55
import html
66
import traceback
77
from .env_helper import EnvHelper
@@ -19,7 +19,7 @@ def __init__(self) -> None:
1919
if env_helper.AZURE_AUTH_TYPE == "rbac":
2020
self.document_analysis_client = DocumentAnalysisClient(
2121
endpoint=self.AZURE_FORM_RECOGNIZER_ENDPOINT,
22-
credential=DefaultAzureCredential(),
22+
credential=get_azure_credential(),
2323
headers={
2424
"x-ms-useragent": "chat-with-your-data-solution-accelerator/1.0.0"
2525
},

code/backend/batch/utilities/helpers/azure_postgres_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import psycopg2
33
from psycopg2.extras import execute_values, RealDictCursor
4-
from azure.identity import DefaultAzureCredential
4+
from .azure_credential_utils import get_azure_credential
55
from .llm_helper import LLMHelper
66
from .env_helper import EnvHelper
77

@@ -24,7 +24,7 @@ def _create_search_client(self):
2424
dbname = self.env_helper.POSTGRESQL_DATABASE
2525

2626
# Acquire the access token
27-
credential = DefaultAzureCredential()
27+
credential = get_azure_credential()
2828
access_token = credential.get_token(
2929
"https://ossrdbms-aad.database.windows.net/.default"
3030
)

code/backend/batch/utilities/helpers/azure_search_helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Union
33
from langchain_community.vectorstores import AzureSearch
44
from azure.core.credentials import AzureKeyCredential
5-
from azure.identity import DefaultAzureCredential
5+
from .azure_credential_utils import get_azure_credential
66
from azure.search.documents import SearchClient
77
from azure.search.documents.indexes import SearchIndexClient
88
from azure.search.documents.indexes.models import (
@@ -49,10 +49,10 @@ def _search_credential(self):
4949
if self.env_helper.is_auth_type_keys():
5050
return AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
5151
else:
52-
return DefaultAzureCredential()
52+
return get_azure_credential()
5353

5454
def _create_search_client(
55-
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
55+
self, search_credential: Union[AzureKeyCredential, get_azure_credential]
5656
) -> SearchClient:
5757
return SearchClient(
5858
endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
@@ -61,7 +61,7 @@ def _create_search_client(
6161
)
6262

6363
def _create_search_index_client(
64-
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
64+
self, search_credential: Union[AzureKeyCredential, get_azure_credential]
6565
):
6666
return SearchIndexClient(
6767
endpoint=self.env_helper.AZURE_SEARCH_SERVICE, credential=search_credential
@@ -285,7 +285,7 @@ def get_conversation_logger(self):
285285
]
286286

287287
if self.env_helper.AZURE_AUTH_TYPE == "rbac":
288-
credential = DefaultAzureCredential()
288+
credential = get_azure_credential()
289289
return AzureSearch(
290290
azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
291291
azure_search_key=None, # Remove API key

code/backend/batch/utilities/helpers/env_helper.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import logging
44
import threading
55
from dotenv import load_dotenv
6-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
6+
from azure.identity import get_bearer_token_provider
7+
from .azure_credential_utils import get_azure_credential
78
from azure.keyvault.secrets import SecretClient
89

910
from ..orchestrator.orchestration_strategy import OrchestrationStrategy
@@ -216,7 +217,7 @@ def __load_config(self, **kwargs) -> None:
216217
)
217218

218219
self.AZURE_TOKEN_PROVIDER = get_bearer_token_provider(
219-
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
220+
get_azure_credential(), "https://cognitiveservices.azure.com/.default"
220221
)
221222
self.ADVANCED_IMAGE_PROCESSING_MAX_IMAGES = self.get_env_var_int(
222223
"ADVANCED_IMAGE_PROCESSING_MAX_IMAGES", 1
@@ -361,8 +362,8 @@ def __load_config(self, **kwargs) -> None:
361362
self.OPEN_AI_FUNCTIONS_SYSTEM_PROMPT = os.getenv(
362363
"OPEN_AI_FUNCTIONS_SYSTEM_PROMPT", ""
363364
)
364-
self.SEMENTIC_KERNEL_SYSTEM_PROMPT = os.getenv(
365-
"SEMENTIC_KERNEL_SYSTEM_PROMPT", ""
365+
self.SEMANTIC_KERNEL_SYSTEM_PROMPT = os.getenv(
366+
"SEMANTIC_KERNEL_SYSTEM_PROMPT", ""
366367
)
367368

368369
self.ENFORCE_AUTH = self.get_env_var_bool("ENFORCE_AUTH", "True")
@@ -416,7 +417,7 @@ def __init__(self) -> None:
416417
417418
The constructor sets the USE_KEY_VAULT attribute based on the value of the USE_KEY_VAULT environment variable.
418419
If USE_KEY_VAULT is set to "true" (case-insensitive), it initializes a SecretClient object using the
419-
AZURE_KEY_VAULT_ENDPOINT environment variable and the DefaultAzureCredential.
420+
AZURE_KEY_VAULT_ENDPOINT environment variable and the get_azure_credential.
420421
421422
Args:
422423
None
@@ -428,7 +429,7 @@ def __init__(self) -> None:
428429
self.secret_client = None
429430
if self.USE_KEY_VAULT:
430431
self.secret_client = SecretClient(
431-
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), DefaultAzureCredential()
432+
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), get_azure_credential()
432433
)
433434

434435
def get_secret(self, secret_name: str) -> str:

0 commit comments

Comments
 (0)