Skip to content

Commit 60dd942

Browse files
diondrapeckkdestin
andauthored
Implement pending upload logic for code snapshots (Azure#30058)
* Generate April GA restclient and update readme with tags * Implement pending upload for code assets --------- Co-authored-by: kdestin <[email protected]>
1 parent 058b0e0 commit 60dd942

File tree

325 files changed

+79159
-160812
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

325 files changed

+79159
-160812
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_artifact_utilities.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_validate_path,
3030
get_ignore_file,
3131
get_object_hash,
32+
get_content_hash,
3233
)
3334
from azure.ai.ml._utils._storage_utils import (
3435
AzureMLDatastorePathUri,
@@ -287,7 +288,8 @@ def _upload_to_datastore(
287288
asset_version: Optional[str] = None,
288289
asset_hash: Optional[str] = None,
289290
ignore_file: Optional[IgnoreFile] = None,
290-
sas_uri: Optional[str] = None, # contains regstry sas url
291+
sas_uri: Optional[str] = None,
292+
blob_uri: Optional[str] = None,
291293
) -> ArtifactStorageInfo:
292294
_validate_path(path, _type=artifact_type)
293295
if not ignore_file:
@@ -306,6 +308,9 @@ def _upload_to_datastore(
306308
ignore_file=ignore_file,
307309
sas_uri=sas_uri,
308310
)
311+
if blob_uri:
312+
artifact.storage_account_url = blob_uri
313+
309314
return artifact
310315

311316

@@ -366,6 +371,7 @@ def _check_and_upload_path(
366371
datastore_name: Optional[str] = None,
367372
sas_uri: Optional[str] = None,
368373
show_progress: bool = True,
374+
blob_uri: Optional[str] = None,
369375
) -> Tuple[T, str]:
370376
"""Checks whether `artifact` is a path or a uri and uploads it to the datastore if necessary.
371377
@@ -406,6 +412,7 @@ def _check_and_upload_path(
406412
artifact_type=artifact_type,
407413
show_progress=show_progress,
408414
ignore_file=getattr(artifact, "_ignore_file", None),
415+
blob_uri=blob_uri,
409416
)
410417
indicator_file = uploaded_artifact.indicator_file # reference to storage contents
411418
if artifact._is_anonymous:
@@ -440,3 +447,36 @@ def _check_and_upload_env_build_context(
440447
# TODO: Depending on decision trailing "/" needs to stay or not. EMS requires it to be present
441448
environment.build.path = uploaded_artifact.full_storage_path + "/"
442449
return environment
450+
451+
452+
def _get_snapshot_path_info(artifact) -> Tuple[str, str, str]:
453+
"""
454+
Validate an Artifact's local path and get its resolved path, ignore file, and hash
455+
:param artifact: Artifact object
456+
:type artifact: azure.ai.ml.entities._assets._artifacts.artifact.Artifact
457+
:return: Artifact's path, ignorefile, and hash
458+
:rtype: Tuple[str, str, str]
459+
"""
460+
if (
461+
hasattr(artifact, "local_path")
462+
and artifact.local_path is not None
463+
or (
464+
hasattr(artifact, "path")
465+
and artifact.path is not None
466+
and not (is_url(artifact.path) or is_mlflow_uri(artifact.path))
467+
)
468+
):
469+
path = (
470+
Path(artifact.path)
471+
if hasattr(artifact, "path") and artifact.path is not None
472+
else Path(artifact.local_path)
473+
)
474+
if not path.is_absolute():
475+
path = Path(artifact.base_path, path).resolve()
476+
477+
_validate_path(path, _type=ErrorTarget.CODE)
478+
479+
ignore_file = get_ignore_file(path)
480+
asset_hash = get_content_hash(path, ignore_file)
481+
482+
return path, ignore_file, asset_hash

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
from azure.ai.ml._restclient.v2023_04_01_preview import (
4646
AzureMachineLearningWorkspaces as ServiceClient042023Preview,
4747
)
48+
from azure.ai.ml._restclient.v2023_04_01 import (
49+
AzureMachineLearningWorkspaces as ServiceClient042023,
50+
)
4851
from azure.ai.ml._scope_dependent_operations import (
4952
OperationConfig,
5053
OperationsContainer,
@@ -274,6 +277,13 @@ def __init__(
274277
**kwargs,
275278
)
276279

280+
self._service_client_04_2023 = ServiceClient042023(
281+
credential=self._credential,
282+
subscription_id=self._operation_scope._subscription_id,
283+
base_url=base_url,
284+
**kwargs,
285+
)
286+
277287
# A general purpose, user-configurable pipeline for making
278288
# http requests
279289
self._requests_pipeline = HttpPipeline(**kwargs)
@@ -373,7 +383,7 @@ def __init__(
373383
self._code = CodeOperations(
374384
self._operation_scope,
375385
self._operation_config,
376-
self._service_client_10_2021_dataplanepreview if registry_name else self._service_client_05_2022,
386+
self._service_client_10_2021_dataplanepreview if registry_name else self._service_client_04_2023,
377387
self._datastores,
378388
**ops_kwargs,
379389
)

sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_asset_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
# pylint: disable=protected-access
5+
# pylint: disable=protected-access,too-many-lines
66

77
import hashlib
88
import logging
@@ -19,6 +19,7 @@
1919
from colorama import Fore
2020
from tqdm import TqdmWarning, tqdm
2121

22+
from azure.ai.ml._restclient.v2023_04_01.models import PendingUploadRequestDto
2223
from azure.ai.ml._artifacts._constants import (
2324
AML_IGNORE_FILE_NAME,
2425
ARTIFACT_ORIGIN,
@@ -966,3 +967,46 @@ def update_to(self, response):
966967
self.completed = current
967968
if current:
968969
self.update(current - self.n)
970+
971+
972+
def get_storage_info_for_non_registry_asset(
973+
service_client, workspace_name, name, version, resource_group
974+
) -> Dict[str, str]:
975+
"""Get SAS uri and blob uri for non-registry asset.
976+
:param service_client: Service client
977+
:type service_client: AzureMachineLearningWorkspaces
978+
:param name: Asset name
979+
:type name: str
980+
:param version: Asset version
981+
:type version: str
982+
:param resource_group: Resource group
983+
:rtype: Dict[str, str]
984+
"""
985+
request_body = PendingUploadRequestDto(pending_upload_type="TemporaryBlobReference")
986+
response = service_client.code_versions.create_or_get_start_pending_upload(
987+
resource_group_name=resource_group,
988+
workspace_name=workspace_name,
989+
name=name,
990+
version=version,
991+
body=request_body,
992+
)
993+
994+
sas_info = {
995+
"sas_uri": response.blob_reference_for_consumption.credential.sas_uri,
996+
"blob_uri": response.blob_reference_for_consumption.blob_uri,
997+
}
998+
999+
return sas_info
1000+
1001+
1002+
def _get_existing_asset_name_and_version(existing_asset):
1003+
import re
1004+
1005+
regex = r"/codes/([^/]+)/versions/([^/]+)"
1006+
1007+
arm_id = existing_asset.id
1008+
match = re.search(regex, arm_id)
1009+
name = match.group(1)
1010+
version = match.group(2)
1011+
1012+
return name, version

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/component.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,12 @@ def _customized_validate(self) -> MutableValidationResult:
477477

478478
return validation_result
479479

480+
def _get_anonymous_component_name_version(self):
481+
return ANONYMOUS_COMPONENT_NAME, self._get_anonymous_hash()
482+
480483
def _get_rest_name_version(self):
481484
if self._is_anonymous:
482-
return ANONYMOUS_COMPONENT_NAME, self._get_anonymous_hash()
485+
return self._get_anonymous_component_name_version()
483486
return self.name, self.version
484487

485488
def _to_rest_object(self) -> ComponentVersion:

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_code_operations.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,22 @@
88

99
from marshmallow.exceptions import ValidationError as SchemaValidationError
1010

11-
from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
11+
from azure.ai.ml._artifacts._artifact_utilities import (
12+
_check_and_upload_path,
13+
_get_snapshot_path_info,
14+
)
15+
from azure.ai.ml._utils._asset_utils import (
16+
get_content_hash_version,
17+
get_storage_info_for_non_registry_asset,
18+
_get_existing_asset_name_and_version,
19+
)
1220
from azure.ai.ml._artifacts._constants import (
1321
ASSET_PATH_ERROR,
1422
CHANGED_ASSET_PATH_MSG,
1523
CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
1624
)
1725
from azure.ai.ml._exception_helper import log_and_raise_error
26+
from azure.ai.ml._restclient.v2023_04_01 import AzureMachineLearningWorkspaces as ServiceClient042023
1827
from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
1928
AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
2029
)
@@ -50,7 +59,7 @@ def __init__(
5059
self,
5160
operation_scope: OperationScope,
5261
operation_config: OperationConfig,
53-
service_client: Union[ServiceClient102022, ServiceClient102021Dataplane],
62+
service_client: Union[ServiceClient102022, ServiceClient102021Dataplane, ServiceClient042023],
5463
datastore_operations: DatastoreOperations,
5564
**kwargs: Dict,
5665
):
@@ -82,6 +91,7 @@ def create_or_update(self, code: Code) -> Code:
8291
name = code.name
8392
version = code.version
8493
sas_uri = None
94+
blob_uri = None
8595

8696
if self._registry_name:
8797
sas_uri = get_sas_uri_for_registry_asset(
@@ -92,12 +102,39 @@ def create_or_update(self, code: Code) -> Code:
92102
registry=self._registry_name,
93103
body=get_asset_body_for_registry_storage(self._registry_name, "codes", name, version),
94104
)
105+
else:
106+
_, _, asset_hash = _get_snapshot_path_info(code)
107+
existing_assets = list(
108+
self._version_operation.list(
109+
resource_group_name=self._resource_group_name,
110+
workspace_name=self._workspace_name,
111+
name=name,
112+
hash=asset_hash,
113+
hash_version=str(get_content_hash_version()),
114+
)
115+
)
116+
117+
if len(existing_assets) > 0:
118+
existing_asset = existing_assets[0]
119+
name, version = _get_existing_asset_name_and_version(existing_asset)
120+
return self.get(name=name, version=version)
121+
sas_info = get_storage_info_for_non_registry_asset(
122+
service_client=self._service_client,
123+
workspace_name=self._workspace_name,
124+
name=name,
125+
version=version,
126+
resource_group=self._resource_group_name,
127+
)
128+
sas_uri = sas_info["sas_uri"]
129+
blob_uri = sas_info["blob_uri"]
130+
95131
code, _ = _check_and_upload_path(
96132
artifact=code,
97133
asset_operations=self,
98134
sas_uri=sas_uri,
99135
artifact_type=ErrorTarget.CODE,
100136
show_progress=self._show_progress,
137+
blob_uri=blob_uri,
101138
)
102139

103140
# For anonymous code, if the code already exists in storage, we reuse the name,

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_operation_orchestrator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
parse_name_label,
2525
parse_prefixed_name_version,
2626
)
27-
from azure.ai.ml._utils._asset_utils import _resolve_label_to_asset
27+
from azure.ai.ml._utils._asset_utils import _resolve_label_to_asset, get_storage_info_for_non_registry_asset
2828
from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri
2929
from azure.ai.ml.constants._common import (
3030
ARM_ID_PREFIX,
@@ -272,11 +272,20 @@ def _get_code_asset_arm_id(self, code_asset: Code, register_asset: bool = True)
272272
if register_asset:
273273
code_asset = self._code_assets.create_or_update(code_asset)
274274
return code_asset.id
275+
sas_info = get_storage_info_for_non_registry_asset(
276+
service_client=self._code_assets._service_client,
277+
workspace_name=self._operation_scope.workspace_name,
278+
name=code_asset.name,
279+
version=code_asset.version,
280+
resource_group=self._operation_scope.resource_group_name,
281+
)
275282
uploaded_code_asset, _ = _check_and_upload_path(
276283
artifact=code_asset,
277284
asset_operations=self._code_assets,
278285
artifact_type=ErrorTarget.CODE,
279286
show_progress=self._operation_config.show_progress,
287+
sas_uri=sas_info["sas_uri"],
288+
blob_uri=sas_info["blob_uri"],
280289
)
281290
uploaded_code_asset._id = get_arm_id_with_version(
282291
self._operation_scope,

sdk/ml/azure-ai-ml/tests/batch_services/e2etests/test_batch_deployment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_batch_deployment(self, client: MLClient, data_with_2_versions: str) ->
7272
)
7373
client.batch_endpoints.begin_delete(name=endpoint.name)
7474

75+
@pytest.mark.skip(reason="TODO (2349252): Scoring script not found in code configuration")
7576
def test_batch_deployment_dependency_label_resolution(
7677
self,
7778
client: MLClient,
@@ -135,6 +136,7 @@ def test_batch_deployment_dependency_label_resolution(
135136
)
136137
assert resolved_model.asset_name == model_name and resolved_model.asset_version == model_versions[-1]
137138

139+
@pytest.mark.skip(reason="TODO (2349249): 'Environment Id' is not a valid ARM resource identifier")
138140
def test_batch_job_download(
139141
self,
140142
client: MLClient,

sdk/ml/azure-ai-ml/tests/batch_services/e2etests/test_batch_endpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
@pytest.mark.e2etest
12-
@pytest.mark.usefixtures("recorded_test")
12+
@pytest.mark.usefixtures("recorded_test", "mock_asset_name")
1313
@pytest.mark.production_experiences_test
1414
class TestBatchEndpoint(AzureRecordedTestCase):
1515
def test_batch_endpoint_create(self, client: MLClient, rand_batch_name: Callable[[], str]) -> None:
@@ -67,6 +67,7 @@ def test_mlflow_batch_endpoint_create_and_update(
6767

6868
raise Exception(f"Batch endpoint {name} is supposed to be deleted.")
6969

70+
@pytest.mark.skip("TODO (2349930) SSL Certificate error")
7071
def test_batch_invoke(
7172
self, client: MLClient, rand_batch_name: Callable[[], str], rand_batch_deployment_name: Callable[[], str]
7273
) -> None:
@@ -142,6 +143,7 @@ def test_batch_component(
142143
)
143144
assert job
144145

146+
@pytest.mark.skip("TODO (2349930) SSL Certificate error")
145147
def test_batch_invoke_outputs(
146148
self,
147149
client: MLClient,

sdk/ml/azure-ai-ml/tests/command_job/e2etests/test_command_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"mock_code_hash",
3636
"mock_asset_name",
3737
"enable_environment_id_arm_expansion",
38+
"mock_anon_component_version",
3839
)
3940
@pytest.mark.training_experiences_test
4041
class TestCommandJob(AzureRecordedTestCase):

sdk/ml/azure-ai-ml/tests/component/e2etests/test_component.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,9 @@ def test_command_component_with_properties_e2e_flow(self, client: MLClient, rand
989989
# TODO(2037030): verify when backend ready
990990
# assert previous_dict == current_dict
991991

992+
@pytest.mark.skip(
993+
reason="TODO (2349965): Message: User/tenant/subscription is not allowed to access registry UnsecureTest-hello-world"
994+
)
992995
@pytest.mark.usefixtures("enable_private_preview_schema_features")
993996
def test_ipp_component_create(self, ipp_registry_client: MLClient, randstr: Callable[[str], str]):
994997
component_path = "./tests/test_configs/components/component_ipp.yml"

0 commit comments

Comments
 (0)