Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions scripts/data_scripts/azure_credential_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from azure.identity import ManagedIdentityCredential, DefaultAzureCredential
from azure.identity.aio import (
ManagedIdentityCredential as AioManagedIdentityCredential,
DefaultAzureCredential as AioDefaultAzureCredential,
)


async def get_azure_credential_async(client_id=None):
"""
Returns an Azure credential asynchronously based on the application environment.

If the environment is 'dev', it uses AioDefaultAzureCredential.
Otherwise, it uses AioManagedIdentityCredential.

Args:
client_id (str, optional): The client ID for the Managed Identity Credential.

Returns:
Credential object: Either AioDefaultAzureCredential or AioManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == "dev":
return (
AioDefaultAzureCredential()
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return AioManagedIdentityCredential(client_id=client_id)


def get_azure_credential(client_id=None):
"""
Returns an Azure credential based on the application environment.

If the environment is 'dev', it uses DefaultAzureCredential.
Otherwise, it uses ManagedIdentityCredential.

Args:
client_id (str, optional): The client ID for the Managed Identity Credential.

Returns:
Credential object: Either DefaultAzureCredential or ManagedIdentityCredential.
"""
if os.getenv("APP_ENV", "prod").lower() == "dev":
return (
DefaultAzureCredential()
) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development
else:
return ManagedIdentityCredential(client_id=client_id)
4 changes: 2 additions & 2 deletions scripts/data_scripts/create_postgres_tables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from azure.identity import DefaultAzureCredential
from azure_credential_utils import get_azure_credential
import psycopg2
from psycopg2 import sql

Expand Down Expand Up @@ -61,7 +61,7 @@ def grant_permissions(cursor, dbname, schema_name, principal_name):


# Acquire the access token
cred = DefaultAzureCredential()
cred = get_azure_credential()
access_token = cred.get_token("https://ossrdbms-aad.database.windows.net/.default")

# Combine the token with the connection string to establish the connection.
Expand Down
1 change: 1 addition & 0 deletions scripts/run_create_table_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ az postgres flexible-server firewall-rule create --resource-group $resourceGroup

# Download the create table python file
curl --output "create_postgres_tables.py" ${baseUrl}"scripts/data_scripts/create_postgres_tables.py"
curl --output "azure_credential_utils.py" ${baseUrl}"scripts/data_scripts/azure_credential_utils.py"

# Download the requirement file
curl --output "$requirementFile" "$requirementFileUrl"
Expand Down