Skip to content

Commit ee8a6b4

Browse files
authored
Refactor model download to generate the path prefix for new registry model URI types (Azure#27838)
* fix and refactor model download getting prefix
1 parent a02c1e3 commit ee8a6b4

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@
4040
BLOB_STORAGE_CLIENT_NAME = "BlobStorageClient"
4141
GEN2_STORAGE_CLIENT_NAME = "Gen2StorageClient"
4242
DEFAULT_CONNECTION_TIMEOUT = 14400
43+
# Matches an https URI shaped like https://<host>/<segment>/<rest>.
# Capture groups, in order:
#   1. the scheme together with the host,
#   2. the leading run of host characters that contains no dot,
#   3. the first path segment after the host,
#   4. everything after that segment (may be empty).
STORAGE_URI_REGEX = r"(https:\/\/([a-zA-Z0-9@:%_\\\-+~#?&=]+)[a-zA-Z0-9@:%._\\\-+~#?&=]+\.?)\/([a-zA-Z0-9@:%._\\\-+~#?&=]+)\/(.*)"

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_storage_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import re
77
from typing import Tuple, Union
8+
from azure.ai.ml._artifacts._constants import STORAGE_URI_REGEX
89

910
from azure.ai.ml._artifacts._blob_storage_helper import BlobStorageClient
1011
from azure.ai.ml._artifacts._fileshare_storage_helper import FileStorageClient
@@ -182,7 +183,19 @@ def get_artifact_path_from_storage_url(blob_url: str, container_name: dict) -> s
182183
return blob_url
183184

184185

185-
def get_ds_name_and_path_prefix(asset_uri: str) -> Tuple[str, str]:
186-
ds_name = asset_uri.split("paths")[0].split("/")[-2]
187-
path_prefix = asset_uri.split("paths")[1][1:]
186+
def get_ds_name_and_path_prefix(
    asset_uri: str, registry_name: Union[str, None] = None
) -> Tuple[Union[str, None], str]:
    """Extract the datastore name and in-store path prefix from an asset URI.

    :param asset_uri: URI of the asset. When ``registry_name`` is set it is matched
        against ``STORAGE_URI_REGEX``; otherwise it must contain a
        ``/datastores/<name>/paths/<prefix>`` segment.
    :param registry_name: Name of the registry the asset belongs to. Any truthy value
        selects registry parsing; ``None`` (the default) selects workspace parsing.
    :return: ``(datastore_name, path_prefix)``. ``datastore_name`` is ``None`` for
        registry assets.
    :raises Exception: If the URI does not match the expected shape for its kind.
    """
    # A non-string URI can never match; treat it like any other unparseable input
    # instead of letting ``re`` raise a TypeError.
    uri_is_text = isinstance(asset_uri, str)

    if registry_name:
        match = re.search(STORAGE_URI_REGEX, asset_uri) if uri_is_text else None
        if match is None:
            raise Exception("Registry asset URI could not be parsed.")
        # Group 4 of the pattern is everything after the first path segment.
        return None, match.group(4)

    # Anchor on the full "/datastores/<name>/paths/" segment. Splitting on the bare
    # substring "paths" truncates any prefix that itself contains "paths"
    # (e.g. ".../paths/LocalUpload/my_paths/model").
    match = re.search(r"/datastores/([^/]+)/paths/(.*)", asset_uri) if uri_is_text else None
    if match is None:
        raise Exception("Workspace asset URI could not be parsed.")
    return match.group(1), match.group(2)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def download(self, name: str, version: str, download_path: Union[PathLike, str]
305305
"""
306306

307307
model_uri = self.get(name=name, version=version).path
308+
ds_name, path_prefix = get_ds_name_and_path_prefix(model_uri, self._registry_name)
308309
if self._registry_name:
309310
sas_uri = get_storage_details_for_registry_assets(
310311
service_client=self._service_client,
@@ -316,10 +317,8 @@ def download(self, name: str, version: str, download_path: Union[PathLike, str]
316317
uri=model_uri,
317318
)
318319
storage_client = get_storage_client(credential=None, storage_account=None, account_url=sas_uri)
319-
path_prefix = model_uri.split("/")[-1]
320320

321321
else:
322-
ds_name, path_prefix = get_ds_name_and_path_prefix(model_uri)
323322
ds = self._datastore_operation.get(ds_name, include_secrets=True)
324323
acc_name = ds.account_name
325324

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
import pytest
3+
4+
from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix
5+
6+
7+
@pytest.mark.unittest
class TestStorageUtils:
    """Unit tests for ``get_ds_name_and_path_prefix``."""

    def test_storage_uri_to_prefix(self) -> None:
        """Each supported asset storage pattern yields the expected (datastore, prefix) pair."""
        # (asset URI, registry name or None for a workspace asset, expected result)
        supported_patterns = [
            (
                'https://ccccccccddddd345.blob.core.windows.net/demoregist-16d33653-20bf-549b-a3c1-17d975359581/ExperimentRun/dcid.5823bbb4-bb28-497c-b9f2-1ff3a0778b10/model',
                "registry_name",
                (None, 'ExperimentRun/dcid.5823bbb4-bb28-497c-b9f2-1ff3a0778b10/model'),
            ),
            (
                'https://ccccccccccc1978ccc.blob.core.windows.net/demoregist-b46fb119-d3f8-5994-a971-a9c730227846/LocalUpload/0c225a0230907e61c00ea33eac35a54d/model.pkl',
                "registry_name",
                (None, 'LocalUpload/0c225a0230907e61c00ea33eac35a54d/model.pkl'),
            ),
            (
                'https://ccccccccddr546ddd.blob.core.windows.net/some-reg-9717e928-33c2-50c2-90f5-f410b12b8727/sklearn_regression_model.pkl',
                "registry_name",
                (None, 'sklearn_regression_model.pkl'),
            ),
            (
                'azureml://subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/000000000000000/workspaces/some_test_3/datastores/workspaceblobstore/paths/LocalUpload/26960525964086056a7301dd061fb9be/lightgbm_mlflow_model',
                None,
                ('workspaceblobstore', 'LocalUpload/26960525964086056a7301dd061fb9be/lightgbm_mlflow_model'),
            ),
        ]

        for uri, registry, expected in supported_patterns:
            if registry is None:
                # Workspace assets are parsed through the single-argument form.
                parsed = get_ds_name_and_path_prefix(uri)
            else:
                parsed = get_ds_name_and_path_prefix(uri, registry)
            assert parsed == expected

    def test_storage_uri_to_prefix_malformed(self) -> None:
        """Malformed URIs raise with the message specific to their asset kind."""
        malformed_registry_uri = 'https://ccccccccddd4512d.blob.core.windows.net/5823bbb4-bb28-497c-b9f2-1ff3a0778b10'
        malformed_workspace_uri = 'azureml://subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/000000000000000/workspaces/some_test_3/datastores/workspaceblobstore/path/LocalUpload/26960525964086056a7301dd061fb9be/lightgbm_mlflow_model'

        def failure_message(*call_args) -> str:
            # Read the message only after the context manager exits, so the
            # caller's assertion on it actually executes.
            with pytest.raises(Exception) as caught:
                get_ds_name_and_path_prefix(*call_args)
            return str(caught.value)

        assert 'Registry asset URI could not be parsed.' in failure_message(
            malformed_registry_uri, "registry_name"
        )
        assert 'Workspace asset URI could not be parsed.' in failure_message(malformed_workspace_uri)
43+
44+
45+

0 commit comments

Comments
 (0)