Skip to content

fix: Replace DefaultAzureCredential with ManagedIdentityCredential for production-safe authentication #1871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ AZURE_SPEECH_SERVICE_REGION=
AZURE_AUTH_TYPE=keys
USE_KEY_VAULT=true
AZURE_KEY_VAULT_ENDPOINT=
# Application environment (e.g., dev, prod)
APP_ENV="dev"
# Chat conversation type to decide between custom or byod (bring your own data) conversation type
CONVERSATION_FLOW=
# Chat History CosmosDB Integration Settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..helpers.env_helper import EnvHelper
from .cosmosdb import CosmosConversationClient
from .postgresdbservice import PostgresConversationClient
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from ..helpers.config.database_type import DatabaseType


Expand All @@ -25,7 +25,7 @@ def get_conversation_client():
f"https://{env_helper.AZURE_COSMOSDB_ACCOUNT}.documents.azure.com:443/"
)
credential = (
DefaultAzureCredential()
get_azure_credential()
if not env_helper.AZURE_COSMOSDB_ACCOUNT_KEY
else env_helper.AZURE_COSMOSDB_ACCOUNT_KEY
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import asyncpg
from datetime import datetime, timezone
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential

from .database_client_base import DatabaseClientBase

Expand All @@ -21,7 +21,7 @@ def __init__(

async def connect(self):
try:
credential = DefaultAzureCredential()
credential = get_azure_credential()
token = credential.get_token(
"https://ossrdbms-aad.database.windows.net/.default"
).token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from azure.storage.queue import QueueClient, BinaryBase64EncodePolicy
import chardet
from .env_helper import EnvHelper
from azure.identity import DefaultAzureCredential
from .azure_credential_utils import get_azure_credential


def connection_string(account_name: str, account_key: str):
Expand All @@ -25,7 +25,7 @@ def create_queue_client():
return QueueClient(
account_url=f"https://{env_helper.AZURE_BLOB_ACCOUNT_NAME}.queue.core.windows.net/",
queue_name=env_helper.DOCUMENT_PROCESSING_QUEUE_NAME,
credential=DefaultAzureCredential(),
credential=get_azure_credential(),
message_encode_policy=BinaryBase64EncodePolicy(),
)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
if self.auth_type == "rbac":
self.account_key = None
self.blob_service_client = BlobServiceClient(
account_url=self.endpoint, credential=DefaultAzureCredential()
account_url=self.endpoint, credential=get_azure_credential()
)
self.user_delegation_key = self.request_user_delegation_key(
blob_service_client=self.blob_service_client
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from urllib.parse import urljoin
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from azure.identity import get_bearer_token_provider
from .azure_credential_utils import get_azure_credential

import requests
from requests import Response
Expand Down Expand Up @@ -56,7 +57,7 @@ def __make_request(self, path: str, body) -> Response:
headers["Ocp-Apim-Subscription-Key"] = self.key
else:
token_provider = get_bearer_token_provider(
DefaultAzureCredential(), self.__TOKEN_SCOPE
get_azure_credential(), self.__TOKEN_SCOPE
)
headers["Authorization"] = "Bearer " + token_provider()

Expand Down
48 changes: 48 additions & 0 deletions code/backend/batch/utilities/helpers/azure_credential_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
from azure.identity.aio import (
ManagedIdentityCredential as AioManagedIdentityCredential,
DefaultAzureCredential as AioDefaultAzureCredential,
)


async def get_azure_credential_async(client_id=None):
"""
Returns an Azure credential asynchronously based on the application environment.
If the environment is 'dev', it uses AioDefaultAzureCredential.
Otherwise, it uses AioManagedIdentityCredential.
Args:
client_id (str, optional): The client ID for the Managed Identity Credential.
Returns:
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == "dev":
return (
AioDefaultAzureCredential()
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return AioManagedIdentityCredential(client_id=client_id)


def get_azure_credential(client_id=None):
"""
Returns an Azure credential based on the application environment.
If the environment is 'dev', it uses DefaultAzureCredential.
Otherwise, it uses ManagedIdentityCredential.
Args:
client_id (str, optional): The client ID for the Managed Identity Credential.
Returns:
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == "dev":
return (
DefaultAzureCredential()
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return ManagedIdentityCredential(client_id=client_id)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from azure.core.credentials import AzureKeyCredential
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.identity import DefaultAzureCredential
from .azure_credential_utils import get_azure_credential
import html
import traceback
from .env_helper import EnvHelper
Expand All @@ -19,7 +19,7 @@ def __init__(self) -> None:
if env_helper.AZURE_AUTH_TYPE == "rbac":
self.document_analysis_client = DocumentAnalysisClient(
endpoint=self.AZURE_FORM_RECOGNIZER_ENDPOINT,
credential=DefaultAzureCredential(),
credential=get_azure_credential(),
headers={
"x-ms-useragent": "chat-with-your-data-solution-accelerator/1.0.0"
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import psycopg2
from psycopg2.extras import execute_values, RealDictCursor
from azure.identity import DefaultAzureCredential
from .azure_credential_utils import get_azure_credential
from .llm_helper import LLMHelper
from .env_helper import EnvHelper

Expand All @@ -24,7 +24,7 @@ def _create_search_client(self):
dbname = self.env_helper.POSTGRESQL_DATABASE

# Acquire the access token
credential = DefaultAzureCredential()
credential = get_azure_credential()
access_token = credential.get_token(
"https://ossrdbms-aad.database.windows.net/.default"
)
Expand Down
10 changes: 5 additions & 5 deletions code/backend/batch/utilities/helpers/azure_search_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union
from langchain_community.vectorstores import AzureSearch
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from .azure_credential_utils import get_azure_credential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
Expand Down Expand Up @@ -49,10 +49,10 @@ def _search_credential(self):
if self.env_helper.is_auth_type_keys():
return AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
else:
return DefaultAzureCredential()
return get_azure_credential()

def _create_search_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
self, search_credential: Union[AzureKeyCredential, get_azure_credential]
) -> SearchClient:
return SearchClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
Expand All @@ -61,7 +61,7 @@ def _create_search_client(
)

def _create_search_index_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
self, search_credential: Union[AzureKeyCredential, get_azure_credential]
):
return SearchIndexClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE, credential=search_credential
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_conversation_logger(self):
]

if self.env_helper.AZURE_AUTH_TYPE == "rbac":
credential = DefaultAzureCredential()
credential = get_azure_credential()
return AzureSearch(
azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=None, # Remove API key
Expand Down
12 changes: 6 additions & 6 deletions code/backend/batch/utilities/helpers/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import logging
import threading
from dotenv import load_dotenv
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from azure.identity import get_bearer_token_provider
from .azure_credential_utils import get_azure_credential
from azure.keyvault.secrets import SecretClient

from ..orchestrator.orchestration_strategy import OrchestrationStrategy
from ..helpers.config.conversation_flow import ConversationFlow
from ..helpers.config.database_type import DatabaseType
Expand Down Expand Up @@ -216,7 +216,7 @@ def __load_config(self, **kwargs) -> None:
)

self.AZURE_TOKEN_PROVIDER = get_bearer_token_provider(
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
get_azure_credential(), "https://cognitiveservices.azure.com/.default"
)
self.ADVANCED_IMAGE_PROCESSING_MAX_IMAGES = self.get_env_var_int(
"ADVANCED_IMAGE_PROCESSING_MAX_IMAGES", 1
Expand Down Expand Up @@ -362,8 +362,8 @@ def __load_config(self, **kwargs) -> None:
self.OPEN_AI_FUNCTIONS_SYSTEM_PROMPT = os.getenv(
"OPEN_AI_FUNCTIONS_SYSTEM_PROMPT", ""
)
self.SEMENTIC_KERNEL_SYSTEM_PROMPT = os.getenv(
"SEMENTIC_KERNEL_SYSTEM_PROMPT", ""
self.SEMANTIC_KERNEL_SYSTEM_PROMPT = os.getenv(
"SEMANTIC_KERNEL_SYSTEM_PROMPT", ""
)

self.ENFORCE_AUTH = self.get_env_var_bool("ENFORCE_AUTH", "True")
Expand Down Expand Up @@ -429,7 +429,7 @@ def __init__(self) -> None:
self.secret_client = None
if self.USE_KEY_VAULT:
self.secret_client = SecretClient(
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), DefaultAzureCredential()
os.environ.get("AZURE_KEY_VAULT_ENDPOINT"), get_azure_credential()
)

def get_secret(self, secret_name: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions code/backend/batch/utilities/helpers/llm_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AzureChatPromptExecutionSettings,
)
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from .azure_credential_utils import get_azure_credential
from .env_helper import EnvHelper

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -166,7 +166,7 @@ def get_sk_service_settings(self, service: AzureChatCompletion):
def get_ml_client(self):
if not hasattr(self, "_ml_client"):
self._ml_client = MLClient(
DefaultAzureCredential(),
get_azure_credential(),
self.env_helper.AZURE_SUBSCRIPTION_ID,
self.env_helper.AZURE_RESOURCE_GROUP,
self.env_helper.AZURE_ML_WORKSPACE_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from azure.search.documents.indexes import SearchIndexerClient
from ..helpers.env_helper import EnvHelper
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from azure.core.credentials import AzureKeyCredential


Expand All @@ -19,7 +19,7 @@ def __init__(self, env_helper: EnvHelper):
(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SearchIndex,
)
from ..helpers.env_helper import EnvHelper
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from azure.core.credentials import AzureKeyCredential
from ..helpers.llm_helper import LLMHelper

Expand All @@ -39,7 +39,7 @@ def __init__(self, env_helper: EnvHelper, llm_helper: LLMHelper):
(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from azure.search.documents.indexes.models import SearchIndexer, FieldMapping
from azure.search.documents.indexes import SearchIndexerClient
from ..helpers.env_helper import EnvHelper
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from azure.core.credentials import AzureKeyCredential

logger = logging.getLogger(__name__)
Expand All @@ -16,7 +16,7 @@ def __init__(self, env_helper: EnvHelper):
(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from azure.search.documents.indexes import SearchIndexerClient
from ..helpers.config.config_helper import IntegratedVectorizationConfig
from ..helpers.env_helper import EnvHelper
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from azure.core.credentials import AzureKeyCredential

logger = logging.getLogger(__name__)
Expand All @@ -33,7 +33,7 @@ def __init__(
(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)
self.integrated_vectorization_config = integrated_vectorization_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def orchestrate(
if response := self.call_content_safety_input(user_message):
return response

system_message = self.env_helper.SEMENTIC_KERNEL_SYSTEM_PROMPT
system_message = self.env_helper.SEMANTIC_KERNEL_SYSTEM_PROMPT
if not system_message:
system_message = """You help employees to navigate only private information sources.
You must prioritize the function call over your general knowledge for any question by calling the search_documents function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.models import VectorizableTextQuery
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from ..common.source_document import SourceDocument
import re

Expand All @@ -21,7 +21,7 @@ def create_search_client(self):
credential=(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)

Expand Down Expand Up @@ -170,7 +170,7 @@ def _check_index_exists(self) -> bool:
credential=(
AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
if self.env_helper.is_auth_type_keys()
else DefaultAzureCredential()
else get_azure_credential()
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from ..helpers.azure_credential_utils import get_azure_credential
from azure.core.exceptions import HttpResponseError
from azure.ai.contentsafety.models import AnalyzeTextOptions
from ..helpers.env_helper import EnvHelper
Expand All @@ -19,7 +19,7 @@ def __init__(self):
logger.info("Initializing ContentSafetyClient with RBAC authentication.")
self.content_safety_client = ContentSafetyClient(
env_helper.AZURE_CONTENT_SAFETY_ENDPOINT,
DefaultAzureCredential(),
get_azure_credential(),
)
else:
logger.info(
Expand Down
Loading