Skip to content

Commit e13a975

Browse files
authored
[ML] AOAI Deployment List (#35417)
* list aoai deployments no translation working * list aoai deployments with translation working, just need to add target url and connection name * lint * add connection name and target to deployment response * run black on files
1 parent 2429dfc commit e13a975

File tree

13 files changed

+39718
-38598
lines changed

13 files changed

+39718
-38598
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_ml_client.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from azure.ai.ml.entities._assets import WorkspaceAssetReference
7474
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
7575
from azure.ai.ml.operations import (
76+
AzureOpenAIDeploymentOperations,
7677
BatchDeploymentOperations,
7778
BatchEndpointOperations,
7879
ComponentOperations,
@@ -360,6 +361,13 @@ def __init__(
360361
**kwargs,
361362
)
362363

364+
self._service_client_04_2024_preview = ServiceClient042024Preview(
365+
credential=self._credential,
366+
subscription_id=self._operation_scope._subscription_id,
367+
base_url=base_url,
368+
**kwargs,
369+
)
370+
363371
# A general purpose, user-configurable pipeline for making
364372
# http requests
365373
self._requests_pipeline = HttpPipeline(**kwargs)
@@ -686,6 +694,12 @@ def __init__(
686694
self._service_client_10_2023,
687695
**ops_kwargs, # type: ignore[arg-type]
688696
)
697+
self._azure_openai_deployments = AzureOpenAIDeploymentOperations(
698+
self._operation_scope,
699+
self._operation_config,
700+
self._service_client_04_2024_preview,
701+
self._connections,
702+
)
689703

690704
self._serverless_endpoints = ServerlessEndpointOperations(
691705
self._operation_scope,
@@ -1042,6 +1056,16 @@ def indexes(self) -> IndexOperations:
10421056
"""
10431057
return self._indexes
10441058

1059+
@property
1060+
@experimental
1061+
def azure_openai_deployments(self) -> AzureOpenAIDeploymentOperations:
1062+
"""A collection of Azure OpenAI deployment related operations.
1063+
1064+
:return: Azure OpenAI deployment operations.
1065+
:rtype: ~azure.ai.ml.operations.AzureOpenAIDeploymentOperations
1066+
"""
1067+
return self._azure_openai_deployments
1068+
10451069
@property
10461070
def subscription_id(self) -> str:
10471071
"""Get the subscription ID of an MLClient object.

sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@
3939
from ._assets.environment import BuildContext, Environment
4040
from ._assets.intellectual_property import IntellectualProperty
4141
from ._assets.workspace_asset_reference import WorkspaceAssetReference as WorkspaceModelReference
42-
from ._autogen_entities.models import MarketplaceSubscription, ServerlessEndpoint, MarketplacePlan
42+
from ._autogen_entities.models import (
43+
AzureOpenAIDeployment,
44+
MarketplaceSubscription,
45+
ServerlessEndpoint,
46+
MarketplacePlan,
47+
)
4348
from ._builders import Command, Parallel, Pipeline, Spark, Sweep
4449
from ._component.command_component import CommandComponent
4550
from ._component.component import Component
@@ -480,6 +485,7 @@
480485
"AccountKeyConfiguration",
481486
"AadCredentialConfiguration",
482487
"Index",
488+
"AzureOpenAIDeployment",
483489
]
484490

485491
# Allow importing these types for backwards compatibility

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_autogen_entities/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
77
# --------------------------------------------------------------------------
88

9+
from ._models import AzureOpenAIDeployment
910
from ._models import ServerlessEndpoint
1011
from ._models import MarketplaceSubscription
1112
from ._patch import __all__ as _patch_all
1213
from ._patch import * # pylint: disable=unused-wildcard-import
1314
from ._patch import patch_sdk as _patch_sdk
1415

15-
__all__ = ["ServerlessEndpoint", "MarketplaceSubscription"]
16+
__all__ = ["AzureOpenAIDeployment", "ServerlessEndpoint", "MarketplaceSubscription"]
1617
__all__.extend([p for p in _patch_all if p not in __all__])
1718
_patch_sdk()

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_autogen_entities/models/_models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,43 @@
1717
from .. import models as _models
1818

1919

20+
class AzureOpenAIDeployment(_model_base.Model):
21+
"""Azure OpenAI Deployment Information.
22+
23+
Readonly variables are only populated by the server, and will be ignored when sending a request.
24+
25+
:ivar name: The deployment name.
26+
:vartype name: str
27+
:ivar model_name: The name of the model to deploy.
28+
:vartype model_name: str
29+
:ivar model_version: The model version to deploy.
30+
:vartype model_version: str
31+
:ivar connection_name: The name of the connection to deploy to.
32+
:vartype connection_name: str
33+
:ivar target_url: The target URL of the AOAI resource for the deployment.
34+
:vartype target_url: str
35+
:ivar id: The ARM resource id of the deployment.
36+
:vartype id: str
37+
:ivar properties: Properties of the deployment.
38+
:vartype properties: dict[str, str]
39+
:ivar tags: Tags of the deployment.
40+
:vartype tags: dict[str, str]
41+
"""
42+
43+
name: Optional[str] = rest_field(visibility=["read"])
44+
"""The deployment name."""
45+
model_name: Optional[str] = rest_field(visibility=["read"])
46+
"""The name of the model to deploy."""
47+
model_version: Optional[str] = rest_field(visibility=["read"])
48+
"""The model version to deploy."""
49+
connection_name: Optional[str] = rest_field(visibility=["read"])
50+
"""The name of the connection to deploy to."""
51+
target_url: Optional[str] = rest_field(visibility=["read"])
52+
"""The target URL of the AOAI resource for the deployment."""
53+
id: Optional[str] = rest_field(visibility=["read"])
54+
"""The ARM resource id of the deployment."""
55+
56+
2057
class MarketplacePlan(_model_base.Model):
2158
"""Marketplace Subscription Definition.
2259

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_autogen_entities/models/_patch.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
from azure.ai.ml.entities._system_data import SystemData
1616
from azure.ai.ml._utils._experimental import experimental
17+
from azure.ai.ml._restclient.v2024_04_01_preview.models import (
18+
EndpointDeploymentResourcePropertiesBasicResource,
19+
OpenAIEndpointDeploymentResourceProperties,
20+
)
1721
from azure.ai.ml._utils.utils import camel_to_snake
1822
from azure.ai.ml._restclient.v2024_01_01_preview.models import (
1923
ServerlessEndpoint as RestServerlessEndpoint,
@@ -25,20 +29,23 @@
2529
)
2630

2731
from ._models import (
32+
AzureOpenAIDeployment as _AzureOpenAIDeployment,
2833
ServerlessEndpoint as _ServerlessEndpoint,
2934
MarketplaceSubscription as _MarketplaceSubscription,
3035
MarketplacePlan as _MarketplacePlan,
3136
)
3237
from .._model_base import rest_field
3338

3439
__all__: List[str] = [
40+
"AzureOpenAIDeployment",
3541
"ServerlessEndpoint",
3642
"MarketplaceSubscription",
3743
"MarketplacePlan",
3844
] # Add all objects you want publicly available to users at this package level
3945

4046
_NULL = object()
4147

48+
4249
func_to_attr_type = {
4350
"_deserialize_dict": dict,
4451
"_deserialize_sequence": list,
@@ -81,6 +88,38 @@ def _validate(self) -> None:
8188
raise ValueError(f"Type of attr {attr} is of type {attr_type}, not {rest_field_type}")
8289

8390

91+
@experimental
92+
class AzureOpenAIDeployment(_AzureOpenAIDeployment):
93+
94+
system_data: Optional[SystemData] = rest_field(visibility=["read"])
95+
"""System data of the deployment."""
96+
97+
@classmethod
98+
def _from_rest_object(cls, obj: EndpointDeploymentResourcePropertiesBasicResource) -> "AzureOpenAIDeployment":
99+
properties: OpenAIEndpointDeploymentResourceProperties = obj.properties
100+
return cls(
101+
name=obj.name,
102+
model_name=properties.model.name,
103+
model_version=properties.model.version,
104+
id=obj.id,
105+
system_data=SystemData._from_rest_object(obj.system_data),
106+
)
107+
108+
def as_dict(self, *, exclude_readonly: bool = False) -> Dict[str, Any]:
109+
d = super().as_dict(exclude_readonly=exclude_readonly)
110+
d["system_data"] = json.loads(json.dumps(self.system_data._to_dict())) # type: ignore
111+
return d
112+
113+
114+
AzureOpenAIDeployment.__doc__ += (
115+
_AzureOpenAIDeployment.__doc__.strip() # type: ignore
116+
+ """
117+
:ivar system_data: System data of the deployment.
118+
:vartype system_data: ~azure.ai.ml.entities.SystemData
119+
"""
120+
)
121+
122+
84123
@experimental
85124
class MarketplacePlan(_MarketplacePlan):
86125
pass

sdk/ml/azure-ai-ml/azure/ai/ml/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
1010

1111

12+
from ._azure_openai_deployment_operations import AzureOpenAIDeploymentOperations
1213
from ._batch_deployment_operations import BatchDeploymentOperations
1314
from ._batch_endpoint_operations import BatchEndpointOperations
1415
from ._component_operations import ComponentOperations
@@ -56,4 +57,5 @@
5657
"ServerlessEndpointOperations",
5758
"MarketplaceSubscriptionOperations",
5859
"IndexOperations",
60+
"AzureOpenAIDeploymentOperations",
5961
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
# pylint: disable=protected-access
6+
7+
import logging
8+
from typing import Iterable
9+
10+
from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient2020404Preview
11+
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
12+
from azure.ai.ml.entities._autogen_entities.models import AzureOpenAIDeployment
13+
14+
from ._connections_operations import ConnectionsOperations
15+
16+
module_logger = logging.getLogger(__name__)
17+
18+
19+
class AzureOpenAIDeploymentOperations(_ScopeDependentOperations):
20+
def __init__(
21+
self,
22+
operation_scope: OperationScope,
23+
operation_config: OperationConfig,
24+
service_client: ServiceClient2020404Preview,
25+
connections_operations: ConnectionsOperations,
26+
):
27+
super().__init__(operation_scope, operation_config)
28+
self._service_client = service_client.connection
29+
self._connections_operations = connections_operations
30+
31+
def list(self, connection_name: str, **kwargs) -> Iterable[AzureOpenAIDeployment]:
32+
connection = self._connections_operations.get(connection_name)
33+
34+
def _from_rest_add_connection_name(obj):
35+
from_rest_deployment = AzureOpenAIDeployment._from_rest_object(obj)
36+
from_rest_deployment.connection_name = connection_name
37+
from_rest_deployment.target_url = connection.target
38+
return from_rest_deployment
39+
40+
return self._service_client.list_deployments(
41+
self._resource_group_name,
42+
self._workspace_name,
43+
connection_name,
44+
cls=lambda objs: [_from_rest_add_connection_name(obj) for obj in objs],
45+
**kwargs,
46+
)

0 commit comments

Comments
 (0)