diff --git a/scripts/data_scripts/azure_credential_utils.py b/scripts/data_scripts/azure_credential_utils.py new file mode 100644 index 000000000..e8d9d7051 --- /dev/null +++ b/scripts/data_scripts/azure_credential_utils.py @@ -0,0 +1,48 @@ +import os +from azure.identity import ManagedIdentityCredential, DefaultAzureCredential +from azure.identity.aio import ( + ManagedIdentityCredential as AioManagedIdentityCredential, + DefaultAzureCredential as AioDefaultAzureCredential, +) + + +async def get_azure_credential_async(client_id=None): + """ + Returns an Azure credential asynchronously based on the application environment. + + If the environment is 'dev', it uses AioDefaultAzureCredential. + Otherwise, it uses AioManagedIdentityCredential. + + Args: + client_id (str, optional): The client ID for the Managed Identity Credential. + + Returns: + Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential. + """ + if os.getenv("APP_ENV", "prod").lower() == "dev": + return ( + AioDefaultAzureCredential() + ) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development + else: + return AioManagedIdentityCredential(client_id=client_id) + + +def get_azure_credential(client_id=None): + """ + Returns an Azure credential based on the application environment. + + If the environment is 'dev', it uses DefaultAzureCredential. + Otherwise, it uses ManagedIdentityCredential. + + Args: + client_id (str, optional): The client ID for the Managed Identity Credential. + + Returns: + Credential object: Either DefaultAzureCredential or ManagedIdentityCredential. + """ + if os.getenv("APP_ENV", "prod").lower() == "dev": + return ( + DefaultAzureCredential() + ) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development + else: + return ManagedIdentityCredential(client_id=client_id) diff --git a/scripts/data_scripts/create_postgres_tables.py b/scripts/data_scripts/create_postgres_tables.py index 805fd7621..14e395d4f 100644 --- a/scripts/data_scripts/create_postgres_tables.py +++ b/scripts/data_scripts/create_postgres_tables.py @@ -1,4 +1,4 @@ -from azure.identity import DefaultAzureCredential +from azure_credential_utils import get_azure_credential import psycopg2 from psycopg2 import sql @@ -61,7 +61,7 @@ def grant_permissions(cursor, dbname, schema_name, principal_name): # Acquire the access token -cred = DefaultAzureCredential() +cred = get_azure_credential() access_token = cred.get_token("https://ossrdbms-aad.database.windows.net/.default") # Combine the token with the connection string to establish the connection. diff --git a/scripts/run_create_table_script.sh b/scripts/run_create_table_script.sh index 8777ecbc5..90a20c50e 100644 --- a/scripts/run_create_table_script.sh +++ b/scripts/run_create_table_script.sh @@ -23,6 +23,7 @@ az postgres flexible-server firewall-rule create --resource-group $resourceGroup # Download the create table python file curl --output "create_postgres_tables.py" ${baseUrl}"scripts/data_scripts/create_postgres_tables.py" +curl --output "azure_credential_utils.py" ${baseUrl}"scripts/data_scripts/azure_credential_utils.py" # Download the requirement file curl --output "$requirementFile" "$requirementFileUrl"