Skip to content

Commit 5baed21

Browse files
authored
[ml] Support of Environment promotion from workspace to registry (Azure#28588)
* promo of environment tests
1 parent f2a27fc commit 5baed21

File tree

9 files changed

+136
-44
lines changed

9 files changed

+136
-44
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
Registry,
5858
Workspace,
5959
)
60-
from azure.ai.ml.entities._assets import WorkspaceModelReference
60+
from azure.ai.ml.entities._assets import WorkspaceAssetReference
6161
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
6262
from azure.ai.ml.operations import (
6363
BatchDeploymentOperations,
@@ -758,7 +758,7 @@ def _get_workspace_info(cls, found_path: Optional[str]) -> Tuple[str, str, str]:
758758

759759
# T = valid inputs/outputs for create_or_update
760760
# Each entry here requires a registered _create_or_update function below
761-
T = TypeVar("T", Job, Model, Environment, Component, Datastore, WorkspaceModelReference)
761+
T = TypeVar("T", Job, Model, Environment, Component, Datastore, WorkspaceAssetReference)
762762

763763
def create_or_update(
764764
self,
@@ -770,7 +770,7 @@ def create_or_update(
770770
:param entity: The resource to create or update.
771771
:type entity: typing.Union[~azure.ai.ml.entities.Job
772772
, ~azure.ai.ml.entities.Model, ~azure.ai.ml.entities.Environment, ~azure.ai.ml.entities.Component
773-
, ~azure.ai.ml.entities.Datastore, ~azure.ai.ml.entities.WorkspaceModelReference]
773+
, ~azure.ai.ml.entities.Datastore, ~azure.ai.ml.entities.WorkspaceAssetReference]
774774
:return: The created or updated resource.
775775
:rtype: typing.Union[~azure.ai.ml.entities.Job, ~azure.ai.ml.entities.Model
776776
, ~azure.ai.ml.entities.Environment, ~azure.ai.ml.entities.Component, ~azure.ai.ml.entities.Datastore]
@@ -835,8 +835,8 @@ def _(entity: Model, operations):
835835
return operations[AzureMLResourceType.MODEL].create_or_update(entity)
836836

837837

838-
@_create_or_update.register(WorkspaceModelReference)
839-
def _(entity: WorkspaceModelReference, operations):
838+
@_create_or_update.register(WorkspaceAssetReference)
839+
def _(entity: WorkspaceAssetReference, operations):
840840
module_logger.debug("Promoting model to registry")
841841
return operations[AzureMLResourceType.MODEL].create_or_update(entity)
842842

@@ -847,6 +847,12 @@ def _(entity: Environment, operations):
847847
return operations[AzureMLResourceType.ENVIRONMENT].create_or_update(entity)
848848

849849

850+
@_create_or_update.register(WorkspaceAssetReference)
851+
def _(entity: WorkspaceAssetReference, operations):
852+
module_logger.debug("Promoting environment to registry")
853+
return operations[AzureMLResourceType.ENVIRONMENT].create_or_update(entity)
854+
855+
850856
@_create_or_update.register(Component)
851857
def _(entity: Component, operations, **kwargs):
852858
module_logger.debug("Creating or updating components")

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .assets.data import DataSchema
1010
from .assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
1111
from .assets.model import ModelSchema
12-
from .assets.workspace_model_reference import WorkspaceModelReferenceSchema
12+
from .assets.workspace_asset_reference import WorkspaceAssetReferenceSchema
1313
from .component import CommandComponentSchema
1414
from .core.fields import (
1515
ArmStr,
@@ -52,5 +52,5 @@
5252
"AnonymousCodeAssetSchema",
5353
"ExperimentalField",
5454
"RegistryStr",
55-
"WorkspaceModelReferenceSchema",
55+
"WorkspaceAssetReferenceSchema",
5656
]

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/workspace_model_reference.py renamed to sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/workspace_asset_reference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
module_logger = logging.getLogger(__name__)
1616

1717

18-
class WorkspaceModelReferenceSchema(AssetSchema):
18+
class WorkspaceAssetReferenceSchema(AssetSchema):
1919
destination_name = fields.Str()
2020
destination_version = fields.Str()
2121
source_asset_id = fields.Str(required=True)
2222

2323
@post_load
2424
def make(self, data, **kwargs):
25-
from azure.ai.ml.entities._assets import WorkspaceModelReference
25+
from azure.ai.ml.entities._assets import WorkspaceAssetReference
2626

27-
return WorkspaceModelReference(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)
27+
return WorkspaceAssetReference(base_path=self.context[BASE_PATH_CONTEXT_KEY], **data)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ._assets._artifacts.model import Model
1616
from ._assets.asset import Asset
1717
from ._assets.environment import BuildContext, Environment
18-
from ._assets.workspace_model_reference import WorkspaceModelReference
18+
from ._assets.workspace_asset_reference import WorkspaceAssetReference
1919
from ._builders import Command, Parallel, Pipeline, Spark, Sweep
2020
from ._component.command_component import CommandComponent
2121
from ._component.component import Component
@@ -246,7 +246,7 @@
246246
"SynapseSparkCompute",
247247
"AutoScaleSettings",
248248
"AutoPauseSettings",
249-
"WorkspaceModelReference",
249+
"WorkspaceAssetReference",
250250
# builders
251251
"Command",
252252
"Parallel",

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
from ._artifacts.data import Data
1111
from ._artifacts.model import Model
1212
from .environment import Environment
13-
from .workspace_model_reference import WorkspaceModelReference
13+
from .workspace_asset_reference import WorkspaceAssetReference
1414

15-
__all__ = ["Artifact", "Model", "Code", "Data", "Environment", "WorkspaceModelReference"]
15+
__all__ = ["Artifact", "Model", "Code", "Data", "Environment", "WorkspaceAssetReference"]

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/workspace_model_reference.py renamed to sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/workspace_asset_reference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
ResourceManagementAssetReferenceData,
1111
ResourceManagementAssetReferenceDetails,
1212
)
13-
from azure.ai.ml._schema import WorkspaceModelReferenceSchema
13+
from azure.ai.ml._schema import WorkspaceAssetReferenceSchema
1414
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
1515
from azure.ai.ml.entities._assets.asset import Asset
1616
from azure.ai.ml.entities._util import load_from_dict
1717

1818

19-
class WorkspaceModelReference(Asset):
19+
class WorkspaceAssetReference(Asset):
2020
"""Workspace Model Reference.
2121
2222
:param name: Model name
@@ -53,14 +53,14 @@ def _load(
5353
yaml_path: Optional[Union[os.PathLike, str]] = None,
5454
params_override: Optional[list] = None,
5555
**kwargs,
56-
) -> "WorkspaceModelReference":
56+
) -> "WorkspaceAssetReference":
5757
data = data or {}
5858
params_override = params_override or []
5959
context = {
6060
BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
6161
PARAMS_OVERRIDE_KEY: params_override,
6262
}
63-
return load_from_dict(WorkspaceModelReferenceSchema, data, context, **kwargs)
63+
return load_from_dict(WorkspaceAssetReferenceSchema, data, context, **kwargs)
6464

6565
def _to_rest_object(self) -> ResourceManagementAssetReferenceData:
6666
resource_management_details = ResourceManagementAssetReferenceDetails(
@@ -72,9 +72,9 @@ def _to_rest_object(self) -> ResourceManagementAssetReferenceData:
7272
return resource_management
7373

7474
@classmethod
75-
def _from_rest_object(cls, resource_object: ResourceManagementAssetReferenceData) -> "WorkspaceModelReference":
75+
def _from_rest_object(cls, resource_object: ResourceManagementAssetReferenceData) -> "WorkspaceAssetReference":
7676

77-
resource_management = WorkspaceModelReference(
77+
resource_management = WorkspaceAssetReference(
7878
name=resource_object.properties.destination_name,
7979
version=resource_object.properties.destination_version,
8080
asset_id=resource_object.properties.source_asset_id,
@@ -84,4 +84,4 @@ def _from_rest_object(cls, resource_object: ResourceManagementAssetReferenceData
8484

8585
def _to_dict(self) -> Dict:
8686
# pylint: disable=no-member
87-
return WorkspaceModelReferenceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
87+
return WorkspaceAssetReferenceSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_environment_operations.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@
3131
)
3232
from azure.ai.ml._utils._logger_utils import OpsLogger
3333
from azure.ai.ml._utils._registry_utils import get_asset_body_for_registry_storage, get_sas_uri_for_registry_asset
34-
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType
35-
from azure.ai.ml.entities._assets import Environment
34+
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, ASSET_ID_FORMAT
35+
from azure.ai.ml.entities._assets import Environment, WorkspaceAssetReference
36+
from azure.core.exceptions import ResourceNotFoundError
3637
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
3738

3839
ops_logger = OpsLogger(__name__)
@@ -92,6 +93,37 @@ def create_or_update(self, environment: Environment) -> Environment:
9293
)
9394
sas_uri = None
9495
if self._registry_name:
96+
if isinstance(environment, WorkspaceAssetReference):
97+
# verify that environment is not already in registry
98+
try:
99+
self._version_operations.get(
100+
name=environment.name,
101+
version=environment.version,
102+
resource_group_name=self._resource_group_name,
103+
registry_name=self._registry_name,
104+
)
105+
except Exception as err: # pylint: disable=broad-except
106+
if isinstance(err, ResourceNotFoundError):
107+
pass
108+
else:
109+
raise err
110+
else:
111+
msg = "A environment with this name and version already exists in registry"
112+
raise ValidationException(
113+
message=msg,
114+
no_personal_data_message=msg,
115+
target=ErrorTarget.ENVIRONMENT,
116+
error_category=ErrorCategory.USER_ERROR,
117+
)
118+
119+
environment = environment._to_rest_object()
120+
result = self._service_client.resource_management_asset_reference.begin_import_method(
121+
resource_group_name=self._resource_group_name,
122+
registry_name=self._registry_name,
123+
body=environment,
124+
)
125+
return result
126+
95127
sas_uri = get_sas_uri_for_registry_asset(
96128
service_client=self._service_client,
97129
name=environment.name,
@@ -111,9 +143,7 @@ def create_or_update(self, environment: Environment) -> Environment:
111143
)
112144
return self.get(name=environment.name, version=environment.version)
113145

114-
environment = _check_and_upload_env_build_context(
115-
environment=environment, operations=self, sas_uri=sas_uri
116-
)
146+
environment = _check_and_upload_env_build_context(environment=environment, operations=self, sas_uri=sas_uri)
117147
env_version_resource = environment._to_rest_object()
118148
env_rest_obj = (
119149
self._version_operations.begin_create_or_update(
@@ -279,7 +309,7 @@ def archive(
279309
name: str,
280310
version: Optional[str] = None,
281311
label: Optional[str] = None,
282-
**kwargs # pylint:disable=unused-argument
312+
**kwargs, # pylint:disable=unused-argument
283313
) -> None:
284314
"""Archive an environment or an environment version.
285315
@@ -307,7 +337,7 @@ def restore(
307337
name: str,
308338
version: Optional[str] = None,
309339
label: Optional[str] = None,
310-
**kwargs # pylint:disable=unused-argument
340+
**kwargs, # pylint:disable=unused-argument
311341
) -> None:
312342
"""Restore an archived environment version.
313343
@@ -343,6 +373,43 @@ def _get_latest_version(self, name: str) -> Environment:
343373
)
344374
return Environment._from_rest_object(result)
345375

376+
# pylint: disable=no-self-use
377+
def _prepare_to_copy(
378+
self, environment: Environment, name: Optional[str] = None, version: Optional[str] = None
379+
) -> WorkspaceAssetReference:
380+
381+
"""Returns WorkspaceAssetReference
382+
to copy a registered environment to registry given the asset id
383+
384+
:param environment: Registered environment
385+
:type environment: Environment
386+
:param name: Destination name
387+
:type name: str
388+
:param version: Destination version
389+
:type version: str
390+
"""
391+
# Get workspace info to get workspace GUID
392+
workspace = self._service_client.workspaces.get(
393+
resource_group_name=self._resource_group_name, workspace_name=self._workspace_name
394+
)
395+
workspace_guid = workspace.workspace_id
396+
workspace_location = workspace.location
397+
398+
# Get environment asset ID
399+
asset_id = ASSET_ID_FORMAT.format(
400+
workspace_location,
401+
workspace_guid,
402+
AzureMLResourceType.ENVIRONMENT,
403+
environment.name,
404+
environment.version,
405+
)
406+
407+
return WorkspaceAssetReference(
408+
name=name if name else environment.name,
409+
version=version if version else environment.version,
410+
asset_id=asset_id,
411+
)
412+
346413

347414
def _preprocess_environment_name(environment_name: str) -> str:
348415
if environment_name.startswith(ARM_ID_PREFIX):

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix, get_storage_client
4444
from azure.ai.ml._utils.utils import resolve_short_datastore_url, validate_ml_flow_folder
4545
from azure.ai.ml.constants._common import ASSET_ID_FORMAT, AzureMLResourceType
46-
from azure.ai.ml.entities._assets import Model, WorkspaceModelReference
46+
from azure.ai.ml.entities._assets import Model, WorkspaceAssetReference
4747
from azure.ai.ml.entities._credentials import AccountKeyConfiguration
4848
from azure.ai.ml.exceptions import (
4949
AssetPathException,
@@ -89,7 +89,7 @@ def __init__(
8989

9090
# @monitor_with_activity(logger, "Model.CreateOrUpdate", ActivityType.PUBLICAPI)
9191
def create_or_update(
92-
self, model: Union[Model, WorkspaceModelReference]
92+
self, model: Union[Model, WorkspaceAssetReference]
9393
) -> Model: # TODO: Are we going to implement job_name?
9494
"""Returns created or updated model asset.
9595
@@ -120,7 +120,7 @@ def create_or_update(
120120

121121
if self._registry_name:
122122
# Case of copy model to registry
123-
if isinstance(model, WorkspaceModelReference):
123+
if isinstance(model, WorkspaceAssetReference):
124124
# verify that model is not already in registry
125125
try:
126126
self._model_versions_operation.get(
@@ -464,9 +464,9 @@ def _get_latest_version(self, name: str) -> Model:
464464
# pylint: disable=no-self-use
465465
def _prepare_to_copy(
466466
self, model: Model, name: Optional[str] = None, version: Optional[str] = None
467-
) -> WorkspaceModelReference:
467+
) -> WorkspaceAssetReference:
468468

469-
"""Returns WorkspaceModelReference
469+
"""Returns WorkspaceAssetReference
470470
to copy a registered model to registry given the asset id
471471
472472
:param model: Registered model
@@ -492,7 +492,7 @@ def _prepare_to_copy(
492492
model.version,
493493
)
494494

495-
return WorkspaceModelReference(
495+
return WorkspaceAssetReference(
496496
name=name if name else model.name,
497497
version=version if version else model.version,
498498
asset_id=asset_id,

0 commit comments

Comments
 (0)