Skip to content

Commit 5b61a95

Browse files
authored
Add method for registry model create/update with system metadata (Azure#40064)
* Add method for registry model create/update with system metadata * add e2e test * fix test * fix test recording * fix test recording * fix test recording * fix tests * fix mypy * run black * fix pylint * move request method * refactor code * fix tests * fix compiling * update recording * fix black
1 parent bc0faa3 commit 5b61a95

File tree

6 files changed

+210
-11
lines changed

6 files changed

+210
-11
lines changed

sdk/ml/azure-ai-ml/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/ml/azure-ai-ml",
5-
"Tag": "python/ml/azure-ai-ml_305b890d5b"
5+
"Tag": "python/ml/azure-ai-ml_9b9dd0f330"
66
}

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_set_cloud,
2222
)
2323
from azure.ai.ml._file_utils.file_utils import traverse_up_path_and_find_file
24+
from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataPlane
2425
from azure.ai.ml._restclient.v2020_09_01_dataplanepreview import (
2526
AzureMachineLearningWorkspaces as ServiceClient092020DataplanePreview,
2627
)
@@ -260,6 +261,13 @@ def __init__(
260261
if not workspace_name:
261262
workspace_name = workspace_reference
262263

264+
self._service_client_model_dataplane = ServiceClientModelDataPlane(
265+
credential=self._credential,
266+
subscription_id=subscription_id,
267+
base_url=self._service_client_10_2021_dataplanepreview._client._base_url,
268+
**kwargs,
269+
)
270+
263271
self._operation_scope = OperationScope(
264272
str(subscription_id),
265273
str(resource_group_name),
@@ -577,6 +585,7 @@ def __init__(
577585
else self._service_client_08_2023_preview
578586
),
579587
self._datastores,
588+
(self._service_client_model_dataplane if registry_name or registry_reference else None),
580589
self._operation_container,
581590
requests_pipeline=self._requests_pipeline,
582591
control_plane_client=self._service_client_08_2023_preview,

sdk/ml/azure-ai-ml/azure/ai/ml/_restclient/model_dataplane/operations/_models_operations.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceExistsError, ResourceNotFoundError, map_error
1313
from azure.core.pipeline import PipelineResponse
1414
from azure.core.pipeline.transport import HttpResponse
15+
from azure.core.polling import LROPoller, NoPolling
1516
from azure.core.rest import HttpRequest
1617
from azure.core.tracing.decorator import distributed_trace
1718
from azure.mgmt.core.exceptions import ARMErrorFormat
19+
from azure.mgmt.core.polling.arm_polling import ARMPolling
1820
from msrest import Serializer
1921

2022
from .. import models as _models
@@ -457,6 +459,50 @@ def build_deployment_settings_request(
457459
**kwargs
458460
)
459461

462+
463+
def build_create_or_update_request_initial(
464+
name, # type: str
465+
version, # type: str
466+
subscription_id, # type: str
467+
resource_group_name, # type: str
468+
registry_name, # type: str
469+
**kwargs # type: Any
470+
):
471+
# type: (...) -> HttpRequest
472+
api_version = kwargs.pop('api_version', "2021-10-01-dataplanepreview") # type: str
473+
content_type = kwargs.pop('content_type', None) # type: Optional[str]
474+
475+
accept = "application/json"
476+
# Construct URL
477+
url = kwargs.pop("template_url", '/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/registries/{registryName}/models/{name}/versions/{version}')
478+
path_format_arguments = {
479+
"name": _SERIALIZER.url("name", name, 'str', pattern=r'^(?![\-_.])[a-zA-Z0-9\-_.]{1,255}(?<!\.)$'),
480+
"version": _SERIALIZER.url("version", version, 'str'),
481+
"subscriptionId": _SERIALIZER.url("subscription_id", subscription_id, 'str', min_length=1),
482+
"resourceGroupName": _SERIALIZER.url("resource_group_name", resource_group_name, 'str', max_length=90, min_length=1),
483+
"registryName": _SERIALIZER.url("registry_name", registry_name, 'str'),
484+
}
485+
486+
url = _format_url_section(url, **path_format_arguments)
487+
488+
# Construct parameters
489+
query_parameters = kwargs.pop("params", {}) # type: Dict[str, Any]
490+
query_parameters['api-version'] = _SERIALIZER.query("api_version", api_version, 'str')
491+
492+
# Construct headers
493+
header_parameters = kwargs.pop("headers", {}) # type: Dict[str, Any]
494+
if content_type is not None:
495+
header_parameters['Content-Type'] = _SERIALIZER.header("content_type", content_type, 'str')
496+
header_parameters['Accept'] = _SERIALIZER.header("accept", accept, 'str')
497+
498+
return HttpRequest(
499+
method="PUT",
500+
url=url,
501+
params=query_parameters,
502+
headers=header_parameters,
503+
**kwargs
504+
)
505+
460506
# fmt: on
461507
class ModelsOperations(object):
462508
"""ModelsOperations operations.
@@ -825,6 +871,107 @@ def create_unregistered_output_model(
825871
create_unregistered_output_model.metadata = {'url': '/modelregistry/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/models/createUnregisteredOutput'} # type: ignore
826872

827873

874+
# this method is only used for model create/update with system metadata
875+
@distributed_trace
876+
def begin_create_or_update_model_with_system_metadata(
877+
self,
878+
subscription_id, # type: str
879+
name, # type: str
880+
version, # type: str
881+
resource_group_name, # type: str
882+
registry_name, # type: str
883+
body, # type: "ModelVersion"
884+
**kwargs, # type: Any
885+
):
886+
# type: (...) -> LROPoller[None]
887+
"""Create or update version.
888+
889+
Create or update version.
890+
891+
:param name: Container name.
892+
:type name: str
893+
:param version: Version identifier.
894+
:type version: str
895+
:param resource_group_name: The name of the resource group. The name is case insensitive.
896+
:type resource_group_name: str
897+
:param registry_name: Name of Azure Machine Learning registry.
898+
:type registry_name: str
899+
:param body: Version entity to create or update.
900+
:type body: ~azure.mgmt.machinelearningservices.models.ModelVersionData
901+
:keyword api_version: Api Version. The default value is "2021-10-01-dataplanepreview". Note
902+
that overriding this default value may result in unsupported behavior.
903+
:paramtype api_version: str
904+
:keyword callable cls: A custom type or function that will be passed the direct response
905+
:keyword str continuation_token: A continuation token to restart a poller from a saved state.
906+
:keyword polling: By default, your polling method will be ARMPolling. Pass in False for this
907+
operation to not poll, or pass in your own initialized polling object for a personal polling
908+
strategy.
909+
:paramtype polling: bool or ~azure.core.polling.PollingMethod
910+
:keyword int polling_interval: Default waiting time between two polls for LRO operations if no
911+
Retry-After header is present.
912+
:return: An instance of LROPoller that returns either None or the result of cls(response)
913+
:rtype: ~azure.core.polling.LROPoller[None]
914+
:raises: ~azure.core.exceptions.HttpResponseError
915+
"""
916+
917+
error_map = {401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError}
918+
error_map.update(kwargs.pop("error_map", {}))
919+
920+
_json = self._serialize.body(body, "ModelVersionData")
921+
_json["properties"]["system_metadata"] = body.properties.system_metadata
922+
923+
request = build_create_or_update_request_initial(
924+
name=name,
925+
version=version,
926+
subscription_id=subscription_id,
927+
resource_group_name=resource_group_name,
928+
registry_name=registry_name,
929+
json=_json,
930+
template_url=self.begin_create_or_update_model_with_system_metadata.metadata["url"],
931+
)
932+
request = _convert_request(request)
933+
request.url = self._client.format_url(request.url)
934+
935+
pipeline_response = self._client._pipeline.run(request, stream=False, **kwargs)
936+
response = pipeline_response.http_response
937+
938+
if response.status_code not in [202]:
939+
map_error(status_code=response.status_code, response=response, error_map=error_map)
940+
raise HttpResponseError(response=response, error_format=ARMErrorFormat)
941+
942+
response_headers = {}
943+
response_headers["x-ms-async-operation-timeout"] = self._deserialize(
944+
"duration", response.headers.get("x-ms-async-operation-timeout")
945+
)
946+
response_headers["Location"] = self._deserialize(
947+
"str", response.headers.get("Location")
948+
)
949+
response_headers["Retry-After"] = self._deserialize(
950+
"int", response.headers.get("Retry-After")
951+
)
952+
953+
cls = kwargs.pop("cls", None)
954+
955+
def get_long_running_output(pipeline_response):
956+
if cls:
957+
return cls(pipeline_response, None, {})
958+
return None
959+
960+
polling = kwargs.pop("polling", True)
961+
if polling is True:
962+
lro_delay = kwargs.pop("polling_interval", self._config.polling_interval)
963+
polling_method = ARMPolling(lro_delay, **kwargs)
964+
elif polling is False:
965+
polling_method = NoPolling()
966+
else:
967+
polling_method = polling
968+
969+
return LROPoller(
970+
self._client, pipeline_response, get_long_running_output, polling_method
971+
)
972+
973+
begin_create_or_update_model_with_system_metadata.metadata = {"url": "/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/registries/{registryName}/models/{name}/versions/{version}"} # type: ignore
974+
828975
@distributed_trace
829976
def batch_get_resolved_uris(
830977
self,

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
2323
)
2424
from azure.ai.ml._exception_helper import log_and_raise_error
25+
26+
from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataPlane
2527
from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
2628
AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
2729
)
@@ -105,13 +107,16 @@ def __init__(
105107
operation_config: OperationConfig,
106108
service_client: Union[ServiceClient082023Preview, ServiceClient102021Dataplane],
107109
datastore_operations: DatastoreOperations,
110+
service_client_model_dataplane: ServiceClientModelDataPlane = None,
108111
all_operations: Optional[OperationsContainer] = None,
109112
**kwargs,
110113
):
111114
super(ModelOperations, self).__init__(operation_scope, operation_config)
112115
ops_logger.update_filter()
113116
self._model_versions_operation = service_client.model_versions
114117
self._model_container_operation = service_client.model_containers
118+
if service_client_model_dataplane is not None:
119+
self._model_dataplane_operation = service_client_model_dataplane.models
115120
self._service_client = service_client
116121
self._datastore_operation = datastore_operations
117122
self._all_operations = all_operations
@@ -251,14 +256,14 @@ def create_or_update( # type: ignore
251256
model_version_resource = model._to_rest_object()
252257
auto_increment_version = model._auto_increment_version
253258
try:
259+
cont_token = self._scope_kwargs.pop("continuation_token", None) # type: Optional[str]
254260
result = (
255-
self._model_versions_operation.begin_create_or_update(
256-
name=name,
257-
version=version,
258-
body=model_version_resource,
259-
registry_name=self._registry_name,
260-
**self._scope_kwargs,
261-
).result()
261+
self._begin_create_or_update_registry_model(
262+
name,
263+
version,
264+
model_version_resource,
265+
cont_token,
266+
)
262267
if self._registry_name
263268
else self._model_versions_operation.create_or_update(
264269
name=name,
@@ -295,6 +300,31 @@ def create_or_update( # type: ignore
295300
else:
296301
raise ex
297302

303+
def _begin_create_or_update_registry_model(self, name, version, model_version_resource, cont_token):
304+
# if continuation token is None and system_metadata attribute found
305+
# we need to send the system_metadata values in the request
306+
return (
307+
self._model_dataplane_operation.begin_create_or_update_model_with_system_metadata(
308+
subscription_id=self._operation_scope._subscription_id,
309+
name=str(name),
310+
version=str(version),
311+
body=model_version_resource,
312+
registry_name=self._registry_name,
313+
**self._scope_kwargs,
314+
).result()
315+
if self._registry_name
316+
and cont_token is None
317+
and hasattr(model_version_resource.properties, "system_metadata")
318+
and self._model_dataplane_operation is not None
319+
else self._model_versions_operation.begin_create_or_update(
320+
name=name,
321+
version=version,
322+
body=model_version_resource,
323+
registry_name=self._registry_name,
324+
**self._scope_kwargs,
325+
).result()
326+
)
327+
298328
def _get(self, name: str, version: Optional[str] = None) -> ModelVersion: # name:latest
299329
if version:
300330
return (

sdk/ml/azure-ai-ml/tests/model/e2etests/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from azure.ai.ml._restclient.v2022_05_01.models import ListViewType
1414
from azure.ai.ml.constants._common import LONG_URI_REGEX_FORMAT
1515
from azure.ai.ml.entities._assets import Model
16+
from azure.core.exceptions import HttpResponseError
1617
from azure.core.paging import ItemPaged
1718

1819

@@ -97,6 +98,19 @@ def test_crud_model_with_stage(self, client: MLClient, randstr: Callable[[], str
9798
model_stage_list = [m.stage for m in model_list if m is not None]
9899
assert model.stage in model_stage_list
99100

101+
def test_crud_model_with_system_metadata(self, registry_client: MLClient, randstr: Callable[[], str]) -> None:
102+
path = Path("./tests/test_configs/model/model_with_system_metadata.yml")
103+
model_name = randstr("model_with_system_metadata")
104+
105+
model = load_model(path)
106+
model.name = model_name
107+
108+
model = registry_client.models.create_or_update(model)
109+
assert model.name == model_name
110+
assert model.version == "1"
111+
assert model.description == "this is my test model with system metadata"
112+
assert model.type == "mlflow_model"
113+
100114
def test_list_no_name(self, client: MLClient) -> None:
101115
models = client.models.list()
102116
assert isinstance(models, Iterator)

sdk/ml/azure-ai-ml/tests/test_configs/model/model_with_system_metadata.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
name: my-production-model
22
type: mlflow_model
33
path: ./lightgbm_mlflow_model
4-
version: 3
5-
description: "this is my test model with stage"
6-
stage: "Production"
4+
version: 1
5+
description: "this is my test model with system metadata"
76
tags:
87
foo: bar
98
abc: 123

0 commit comments

Comments
 (0)