|
9 | 9 | import argparse
|
10 | 10 | import pstats
|
11 | 11 | import cProfile
|
12 |
| -from typing import Optional |
| 12 | +from typing import Optional, TYPE_CHECKING |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from azure.core.credentials import TokenCredential |
13 | 16 |
|
14 | 17 |
|
15 | 18 | PSTATS_PRINT_DEFAULT_SORT_KEY = pstats.SortKey.TIME
|
@@ -218,6 +221,35 @@ def get_from_env(variable: str) -> str:
|
218 | 221 | if not value:
|
219 | 222 | raise ValueError("Undefined environment variable {}".format(variable))
|
220 | 223 | return value
|
| 224 | + |
| 225 | + @staticmethod |
| 226 | + def get_credential(is_async: bool = False) -> "TokenCredential": |
| 227 | + # If AzurePipelinesCredential is detected, use it. |
| 228 | + service_connection_id = os.environ.get("AZURESUBSCRIPTION_SERVICE_CONNECTION_ID") |
| 229 | + client_id = os.environ.get("AZURESUBSCRIPTION_CLIENT_ID") |
| 230 | + tenant_id = os.environ.get("AZURESUBSCRIPTION_TENANT_ID") |
| 231 | + system_access_token = os.environ.get("SYSTEM_ACCESSTOKEN") |
| 232 | + if service_connection_id and client_id and tenant_id and system_access_token: |
| 233 | + if is_async: |
| 234 | + from azure.identity.aio import AzurePipelinesCredential |
| 235 | + else: |
| 236 | + from azure.identity import AzurePipelinesCredential |
| 237 | + |
| 238 | + return AzurePipelinesCredential( |
| 239 | + tenant_id=tenant_id, |
| 240 | + client_id=client_id, |
| 241 | + service_connection_id=service_connection_id, |
| 242 | + system_access_token=system_access_token, |
| 243 | + ) |
| 244 | + |
| 245 | + # Fall back to DefaultAzureCredential |
| 246 | + if is_async: |
| 247 | + from azure.identity.aio import DefaultAzureCredential |
| 248 | + else: |
| 249 | + from azure.identity import DefaultAzureCredential |
| 250 | + |
| 251 | + return DefaultAzureCredential(exclude_managed_identity_credential=True) |
| 252 | + |
221 | 253 |
|
222 | 254 | def _save_profile(self, sync: str, output_path: Optional[str] = None) -> None:
|
223 | 255 | """
|
|
0 commit comments