Skip to content

Commit 487d0b4

Browse files
authored
Fix Azure Entra ID only datastore auth access (Azure#38967)
- Fixes a regression introduced when attempting to authenticate with datastores that have shared key access disabled and are accessible only with Entra ID credentials (with the proper roles assigned)
- Always retrieves time-limited SAS tokens for datastores configured with shared key access, per the latest Azure security recommendations
- Adds explicit tests for Entra-ID-only ("none") AI projects
- Updates test recordings

Bug 3686546
1 parent 8259791 commit 487d0b4

File tree

8 files changed

+70
-16
lines changed

8 files changed

+70
-16
lines changed

.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,7 @@
13821382
"fmeasure",
13831383
"upia",
13841384
"xpia",
1385+
"expirable",
13851386
]
13861387
},
13871388
{

sdk/evaluation/azure-ai-evaluation/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/evaluation/azure-ai-evaluation",
5-
"Tag": "python/evaluation/azure-ai-evaluation_4f3f9f39dc"
5+
"Tag": "python/evaluation/azure-ai-evaluation_326efc986d"
66
}

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_clients.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ._models import BlobStoreInfo, Workspace
1818

1919

20-
API_VERSION: Final[str] = "2024-10-01"
20+
API_VERSION: Final[str] = "2024-07-01-preview"
2121
QUERY_KEY_API_VERSION: Final[str] = "api-version"
2222
PATH_ML_WORKSPACES = ("providers", "Microsoft.MachineLearningServices", "workspaces")
2323

@@ -69,7 +69,9 @@ def get_credential(self) -> TokenCredential:
6969
self._get_token_manager()
7070
return cast(TokenCredential, self._credential)
7171

72-
def workspace_get_default_datastore(self, workspace_name: str, include_credentials: bool = False) -> BlobStoreInfo:
72+
def workspace_get_default_datastore(
73+
self, workspace_name: str, *, include_credentials: bool = False, **kwargs: Any
74+
) -> BlobStoreInfo:
7375
# 1. Get the default blob store
7476
# REST API documentation:
7577
# https://learn.microsoft.com/rest/api/azureml/datastores/list?view=rest-azureml-2024-10-01
@@ -92,18 +94,29 @@ def workspace_get_default_datastore(self, workspace_name: str, include_credentia
9294
account_name = props_json["accountName"]
9395
endpoint = props_json["endpoint"]
9496
container_name = props_json["containerName"]
97+
credential_type = props_json.get("credentials", {}).get("credentialsType")
9598

9699
# 2. Get the SAS token to use for accessing the blob store
97100
# REST API documentation:
98101
# https://learn.microsoft.com/rest/api/azureml/datastores/list-secrets?view=rest-azureml-2024-10-01
99-
blob_store_credential: Optional[Union[AzureSasCredential, str]] = None
100-
if include_credentials:
102+
blob_store_credential: Optional[Union[AzureSasCredential, TokenCredential, str]]
103+
if not include_credentials:
104+
blob_store_credential = None
105+
elif credential_type and credential_type.lower() == "none":
106+
# If storage account key access is disabled, and only Microsoft Entra ID authentication is available,
107+
# the credentialsType will be "None" and we should not attempt to get the secrets.
108+
blob_store_credential = self.get_credential()
109+
else:
101110
url = self._generate_path(
102111
*PATH_ML_WORKSPACES, workspace_name, "datastores", "workspaceblobstore", "listSecrets"
103112
)
104113
secrets_response = self._http_client.request(
105114
method="POST",
106115
url=url,
116+
json={
117+
"expirableSecret": True,
118+
"expireAfterHours": int(kwargs.get("key_expiration_hours", 1)),
119+
},
107120
params={
108121
QUERY_KEY_API_VERSION: self._api_version,
109122
},
@@ -114,10 +127,13 @@ def workspace_get_default_datastore(self, workspace_name: str, include_credentia
114127
secrets_json = secrets_response.json()
115128
secrets_type = secrets_json["secretsType"].lower()
116129

130+
# As per this website, only SAS tokens, access tokens, or Entra IDs are valid for accessing blob data
131+
# stores:
132+
# https://learn.microsoft.com/rest/api/storageservices/authorize-requests-to-azure-storage.
117133
if secrets_type == "sas":
118134
blob_store_credential = AzureSasCredential(secrets_json["sasToken"])
119135
elif secrets_type == "accountkey":
120-
# To support olders versions of azure-storage-blob better, we return a string here instead of
136+
# To support older versions of azure-storage-blob better, we return a string here instead of
121137
# an AzureNamedKeyCredential
122138
blob_store_credential = secrets_json["key"]
123139
else:
@@ -164,19 +180,19 @@ def _throw_on_http_error(response: HttpResponse, description: str, valid_status:
164180
# nothing to see here, move along
165181
return
166182

167-
additional_info: Optional[str] = None
183+
message = f"The {description} request failed with HTTP {response.status_code}"
168184
try:
169185
error_json = response.json()["error"]
170186
additional_info = f"({error_json['code']}) {error_json['message']}"
187+
message += f" - {additional_info}"
171188
except (JSONDecodeError, ValueError, KeyError):
172189
pass
173190

174191
raise EvaluationException(
175-
message=f"The {description} request failed with HTTP {response.status_code}",
192+
message=message,
176193
target=ErrorTarget.EVALUATE,
177194
category=ErrorCategory.FAILED_EXECUTION,
178195
blame=ErrorBlame.SYSTEM_ERROR,
179-
internal_message=additional_info,
180196
)
181197

182198
def _generate_path(self, *paths: str) -> str:

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_azure/_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88

99
from typing import Dict, List, NamedTuple, Optional, Union
1010
from msrest.serialization import Model
11-
from azure.core.credentials import AzureSasCredential
11+
from azure.core.credentials import AzureSasCredential, TokenCredential
1212

1313

1414
class BlobStoreInfo(NamedTuple):
1515
name: str
1616
account_name: str
1717
endpoint: str
1818
container_name: str
19-
credential: Optional[Union[AzureSasCredential, str]]
19+
credential: Optional[Union[AzureSasCredential, TokenCredential, str]]
2020

2121

2222
class WorkspaceHubConfig(Model):

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluate/_eval_run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,9 @@ def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ART
421421
local_paths.append(local_file_path)
422422

423423
# We will write the artifacts to the workspaceblobstore
424-
datastore = self._management_client.workspace_get_default_datastore(self._workspace_name, True)
424+
datastore = self._management_client.workspace_get_default_datastore(
425+
self._workspace_name, include_credentials=True
426+
)
425427
account_url = f"{datastore.account_name}.blob.{datastore.endpoint}"
426428

427429
svc_client = BlobServiceClient(account_url=account_url, credential=datastore.credential)

sdk/evaluation/azure-ai-evaluation/tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,22 @@ def project_scope(request, dev_connections: Dict[str, Any]) -> dict:
393393
return dev_connections[conn_name]["value"]
394394

395395

396+
@pytest.fixture
397+
def datastore_project_scopes(connection_file, project_scope, mock_project_scope):
398+
conn_name = "azure_ai_entra_id_project_scope"
399+
if not is_live():
400+
entra_id = mock_project_scope
401+
else:
402+
entra_id = connection_file.get(conn_name)
403+
if not entra_id:
404+
raise ValueError(f"Connection '{conn_name}' not found in dev connections.")
405+
406+
return {
407+
"sas": project_scope,
408+
"none": entra_id,
409+
}
410+
411+
396412
@pytest.fixture
397413
def mock_trace_destination_to_cloud(project_scope: dict):
398414
"""Mock trace destination to cloud."""

sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_lite_management_client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import logging
3-
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential
3+
from azure.core.credentials import AzureSasCredential, TokenCredential
44
from azure.ai.evaluation._azure._clients import LiteMLClient
55

66

@@ -34,7 +34,12 @@ def test_get_token(self, project_scope, azure_cred):
3434

3535
@pytest.mark.azuretest
3636
@pytest.mark.parametrize("include_credentials", [False, True])
37-
def test_workspace_get_default_store(self, project_scope, azure_cred, include_credentials: bool):
37+
@pytest.mark.parametrize("config_name", ["sas", "none"])
38+
def test_workspace_get_default_store(
39+
self, azure_cred, datastore_project_scopes, config_name: str, include_credentials: bool
40+
):
41+
project_scope = datastore_project_scopes[config_name]
42+
3843
client = LiteMLClient(
3944
subscription_id=project_scope["subscription_id"],
4045
resource_group=project_scope["resource_group_name"],
@@ -52,7 +57,11 @@ def test_workspace_get_default_store(self, project_scope, azure_cred, include_cr
5257
assert store.endpoint
5358
assert store.container_name
5459
if include_credentials:
55-
assert isinstance(store.credential, str) or isinstance(store.credential, AzureSasCredential)
60+
assert (
61+
(config_name == "account_key" and isinstance(store.credential, str))
62+
or (config_name == "sas" and isinstance(store.credential, AzureSasCredential))
63+
or (config_name == "none" and isinstance(store.credential, TokenCredential))
64+
)
5665
else:
5766
assert store.credential == None
5867

sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_metrics_upload.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,23 @@ def test_logging_metrics(self, caplog, project_scope, azure_ml_client):
118118
self._assert_no_errors_for_module(caplog.records, EvalRun.__module__)
119119

120120
@pytest.mark.azuretest
121-
def test_log_artifact(self, project_scope, azure_ml_client, caplog, tmp_path):
121+
@pytest.mark.parametrize("config_name", ["sas", "none"])
122+
def test_log_artifact(self, project_scope, azure_cred, datastore_project_scopes, caplog, tmp_path, config_name):
122123
"""Test uploading artifact to the service."""
123124
logger = logging.getLogger(EvalRun.__module__)
124125
# All loggers, having promptflow. prefix will have "promptflow" logger
125126
# as a parent. This logger does not propagate the logs and cannot be
126127
# captured by caplog. Here we will skip this logger to capture logs.
127128
logger.parent = logging.root
129+
130+
project_scope = datastore_project_scopes[config_name]
131+
azure_ml_client = LiteMLClient(
132+
subscription_id=project_scope["subscription_id"],
133+
resource_group=project_scope["resource_group_name"],
134+
logger=logger,
135+
credential=azure_cred,
136+
)
137+
128138
with EvalRun(
129139
run_name="test",
130140
tracking_uri=_get_tracking_uri(azure_ml_client, project_scope),

0 commit comments

Comments (0)