|
1 | | -# pylint: disable=line-too-long,useless-suppression |
2 | | -# ------------------------------------ |
3 | | -# Copyright (c) Microsoft Corporation. |
4 | | -# Licensed under the MIT License. |
5 | | -# ------------------------------------ |
| 1 | +# coding=utf-8 |
| 2 | +# -------------------------------------------------------------------------- |
| 3 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 4 | +# Licensed under the MIT License. See License.txt in the project root for license information. |
| 5 | +# -------------------------------------------------------------------------- |
6 | 6 | """Customize generated code here. |
7 | 7 |
|
8 | 8 | Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize |
9 | 9 | """ |
10 | | -import uuid |
11 | | -from os import PathLike |
12 | | -from pathlib import Path |
13 | | -from typing import Any, Dict, List, Tuple, Union, Optional |
14 | | -from typing_extensions import Self |
| 10 | +from typing import List |
15 | 11 |
|
16 | | -from azure.core import PipelineClient |
17 | | -from azure.core.credentials import TokenCredential |
18 | | -from azure.core.pipeline import policies |
19 | | - |
20 | | -from ._client import AIProjectClient as ClientGenerated |
21 | | -from ._configuration import AIProjectClientConfiguration |
22 | | -from ._serialization import Deserializer, Serializer |
23 | | -from .operations import AgentsOperations, ConnectionsOperations, EvaluationsOperations, TelemetryOperations |
24 | | -from .operations._patch import InferenceOperations |
25 | | - |
26 | | - |
27 | | -class AIProjectClient( |
28 | | - ClientGenerated |
29 | | -): # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes |
30 | | - def __init__( # pylint: disable=super-init-not-called,too-many-statements |
31 | | - self, |
32 | | - endpoint: str, |
33 | | - subscription_id: str, |
34 | | - resource_group_name: str, |
35 | | - project_name: str, |
36 | | - credential: "TokenCredential", |
37 | | - **kwargs: Any, |
38 | | - ) -> None: |
39 | | - # TODO: Validate input formats with regex match (e.g. subscription ID) |
40 | | - if not endpoint: |
41 | | - raise ValueError("endpoint is required") |
42 | | - if not subscription_id: |
43 | | - raise ValueError("subscription_id ID is required") |
44 | | - if not resource_group_name: |
45 | | - raise ValueError("resource_group_name is required") |
46 | | - if not project_name: |
47 | | - raise ValueError("project_name is required") |
48 | | - if not credential: |
49 | | - raise ValueError("credential is required") |
50 | | - if "api_version" in kwargs: |
51 | | - raise ValueError("No support for overriding the API version") |
52 | | - if "credential_scopes" in kwargs: |
53 | | - raise ValueError("No support for overriding the credential scopes") |
54 | | - |
55 | | - kwargs0 = kwargs.copy() |
56 | | - kwargs1 = kwargs.copy() |
57 | | - kwargs2 = kwargs.copy() |
58 | | - kwargs3 = kwargs.copy() |
59 | | - |
60 | | - self._user_agent: Optional[str] = kwargs.get("user_agent", None) |
61 | | - |
62 | | - # For getting AppInsights connection string from the AppInsights resource. |
63 | | - # The AppInsights resource URL is not known at this point. We need to get it from the |
64 | | - # AzureML "Workspace - Get" REST API call. It will have the form: |
65 | | - # https://management.azure.com/subscriptions/{appinsights_subscription_id}/resourceGroups/{appinsights_resource_group_name}/providers/microsoft.insights/components/{appinsights_resource_name} |
66 | | - _endpoint0 = "https://management.azure.com" |
67 | | - self._config0: AIProjectClientConfiguration = AIProjectClientConfiguration( |
68 | | - endpoint=endpoint, |
69 | | - subscription_id=subscription_id, |
70 | | - resource_group_name=resource_group_name, |
71 | | - project_name=project_name, |
72 | | - credential=credential, |
73 | | - api_version="2020-02-02", |
74 | | - credential_scopes=["https://management.azure.com/.default"], |
75 | | - **kwargs0, |
76 | | - ) |
77 | | - |
78 | | - _policies0 = kwargs0.pop("policies", None) |
79 | | - if _policies0 is None: |
80 | | - _policies0 = [ |
81 | | - policies.RequestIdPolicy(**kwargs0), |
82 | | - self._config0.headers_policy, |
83 | | - self._config0.user_agent_policy, |
84 | | - self._config0.proxy_policy, |
85 | | - policies.ContentDecodePolicy(**kwargs0), |
86 | | - self._config0.redirect_policy, |
87 | | - self._config0.retry_policy, |
88 | | - self._config0.authentication_policy, |
89 | | - self._config0.custom_hook_policy, |
90 | | - self._config0.logging_policy, |
91 | | - policies.DistributedTracingPolicy(**kwargs0), |
92 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs0) if self._config0.redirect_policy else None, |
93 | | - self._config0.http_logging_policy, |
94 | | - ] |
95 | | - self._client0: PipelineClient = PipelineClient(base_url=_endpoint0, policies=_policies0, **kwargs0) |
96 | | - |
97 | | - # For Endpoints operations (listing connections, getting connection properties, getting project properties) |
98 | | - _endpoint1 = ( |
99 | | - "https://management.azure.com/" |
100 | | - + f"subscriptions/{subscription_id}/" |
101 | | - + f"resourceGroups/{resource_group_name}/" |
102 | | - + "providers/Microsoft.MachineLearningServices/" |
103 | | - + f"workspaces/{project_name}" |
104 | | - ) |
105 | | - self._config1: AIProjectClientConfiguration = AIProjectClientConfiguration( |
106 | | - endpoint=endpoint, |
107 | | - subscription_id=subscription_id, |
108 | | - resource_group_name=resource_group_name, |
109 | | - project_name=project_name, |
110 | | - credential=credential, |
111 | | - api_version="2024-07-01-preview", |
112 | | - credential_scopes=["https://management.azure.com/.default"], |
113 | | - **kwargs1, |
114 | | - ) |
115 | | - _policies1 = kwargs1.pop("policies", None) |
116 | | - if _policies1 is None: |
117 | | - _policies1 = [ |
118 | | - policies.RequestIdPolicy(**kwargs1), |
119 | | - self._config1.headers_policy, |
120 | | - self._config1.user_agent_policy, |
121 | | - self._config1.proxy_policy, |
122 | | - policies.ContentDecodePolicy(**kwargs1), |
123 | | - self._config1.redirect_policy, |
124 | | - self._config1.retry_policy, |
125 | | - self._config1.authentication_policy, |
126 | | - self._config1.custom_hook_policy, |
127 | | - self._config1.logging_policy, |
128 | | - policies.DistributedTracingPolicy(**kwargs1), |
129 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs1) if self._config1.redirect_policy else None, |
130 | | - self._config1.http_logging_policy, |
131 | | - ] |
132 | | - self._client1: PipelineClient = PipelineClient(base_url=_endpoint1, policies=_policies1, **kwargs1) |
133 | | - |
134 | | - # For Agents operations |
135 | | - _endpoint2 = f"{endpoint}/agents/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
136 | | - self._config2 = AIProjectClientConfiguration( |
137 | | - endpoint=endpoint, |
138 | | - subscription_id=subscription_id, |
139 | | - resource_group_name=resource_group_name, |
140 | | - project_name=project_name, |
141 | | - credential=credential, |
142 | | - api_version="2024-12-01-preview", |
143 | | - credential_scopes=["https://ml.azure.com/.default"], |
144 | | - **kwargs2, |
145 | | - ) |
146 | | - _policies2 = kwargs2.pop("policies", None) |
147 | | - if _policies2 is None: |
148 | | - _policies2 = [ |
149 | | - policies.RequestIdPolicy(**kwargs2), |
150 | | - self._config2.headers_policy, |
151 | | - self._config2.user_agent_policy, |
152 | | - self._config2.proxy_policy, |
153 | | - policies.ContentDecodePolicy(**kwargs2), |
154 | | - self._config2.redirect_policy, |
155 | | - self._config2.retry_policy, |
156 | | - self._config2.authentication_policy, |
157 | | - self._config2.custom_hook_policy, |
158 | | - self._config2.logging_policy, |
159 | | - policies.DistributedTracingPolicy(**kwargs2), |
160 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs2) if self._config2.redirect_policy else None, |
161 | | - self._config2.http_logging_policy, |
162 | | - ] |
163 | | - self._client2: PipelineClient = PipelineClient(base_url=_endpoint2, policies=_policies2, **kwargs2) |
164 | | - |
165 | | - # For Cloud Evaluations operations |
166 | | - # cSpell:disable-next-line |
167 | | - _endpoint3 = f"{endpoint}/raisvc/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
168 | | - self._config3 = AIProjectClientConfiguration( |
169 | | - endpoint=endpoint, |
170 | | - subscription_id=subscription_id, |
171 | | - resource_group_name=resource_group_name, |
172 | | - project_name=project_name, |
173 | | - credential=credential, |
174 | | - api_version="2024-07-01-preview", # TODO: Update me |
175 | | - credential_scopes=["https://ml.azure.com/.default"], # TODO: Update once service changes are ready |
176 | | - **kwargs3, |
177 | | - ) |
178 | | - _policies3 = kwargs3.pop("policies", None) |
179 | | - if _policies3 is None: |
180 | | - _policies3 = [ |
181 | | - policies.RequestIdPolicy(**kwargs3), |
182 | | - self._config3.headers_policy, |
183 | | - self._config3.user_agent_policy, |
184 | | - self._config3.proxy_policy, |
185 | | - policies.ContentDecodePolicy(**kwargs3), |
186 | | - self._config3.redirect_policy, |
187 | | - self._config3.retry_policy, |
188 | | - self._config3.authentication_policy, |
189 | | - self._config3.custom_hook_policy, |
190 | | - self._config3.logging_policy, |
191 | | - policies.DistributedTracingPolicy(**kwargs3), |
192 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs3) if self._config3.redirect_policy else None, |
193 | | - self._config3.http_logging_policy, |
194 | | - ] |
195 | | - self._client3: PipelineClient = PipelineClient(base_url=_endpoint3, policies=_policies3, **kwargs3) |
196 | | - |
197 | | - self._serialize = Serializer() |
198 | | - self._deserialize = Deserializer() |
199 | | - self._serialize.client_side_validation = False |
200 | | - |
201 | | - self.telemetry = TelemetryOperations( |
202 | | - self._client0, self._config0, self._serialize, self._deserialize, outer_instance=self |
203 | | - ) |
204 | | - self.connections = ConnectionsOperations(self._client1, self._config1, self._serialize, self._deserialize) |
205 | | - self.agents = AgentsOperations(self._client2, self._config2, self._serialize, self._deserialize) |
206 | | - self.evaluations = EvaluationsOperations(self._client3, self._config3, self._serialize, self._deserialize) |
207 | | - self.inference = InferenceOperations(self) |
208 | | - |
209 | | - def close(self) -> None: |
210 | | - self._client0.close() |
211 | | - self._client1.close() |
212 | | - self._client2.close() |
213 | | - self._client3.close() |
214 | | - |
215 | | - def __enter__(self) -> Self: |
216 | | - self._client0.__enter__() |
217 | | - self._client1.__enter__() |
218 | | - self._client2.__enter__() |
219 | | - self._client3.__enter__() |
220 | | - return self |
221 | | - |
222 | | - def __exit__(self, *exc_details: Any) -> None: |
223 | | - self._client0.__exit__(*exc_details) |
224 | | - self._client1.__exit__(*exc_details) |
225 | | - self._client2.__exit__(*exc_details) |
226 | | - self._client3.__exit__(*exc_details) |
227 | | - |
228 | | - @classmethod |
229 | | - def from_connection_string(cls, conn_str: str, credential: "TokenCredential", **kwargs) -> Self: |
230 | | - """ |
231 | | - Create an AIProjectClient from a connection string. |
232 | | -
|
233 | | - :param str conn_str: The connection string, copied from your AI Foundry project. |
234 | | - :param TokenCredential credential: Credential used to authenticate requests to the service. |
235 | | - :return: An AIProjectClient instance. |
236 | | - :rtype: AIProjectClient |
237 | | - """ |
238 | | - if not conn_str: |
239 | | - raise ValueError("Connection string is required") |
240 | | - parts = conn_str.split(";") |
241 | | - if len(parts) != 4: |
242 | | - raise ValueError("Invalid connection string format") |
243 | | - endpoint = "https://" + parts[0] |
244 | | - subscription_id = parts[1] |
245 | | - resource_group_name = parts[2] |
246 | | - project_name = parts[3] |
247 | | - return cls(endpoint, subscription_id, resource_group_name, project_name, credential, **kwargs) |
248 | | - |
249 | | - def upload_file(self, file_path: Union[Path, str, PathLike]) -> Tuple[str, str]: |
250 | | - """Upload a file to the Azure AI Foundry project. |
251 | | - This method required *azure-ai-ml* to be installed. |
252 | | -
|
253 | | - :param file_path: The path to the file to upload. |
254 | | - :type file_path: Union[str, Path, PathLike] |
255 | | - :return: The tuple, containing asset id and asset URI of uploaded file. |
256 | | - :rtype: Tuple[str] |
257 | | - """ |
258 | | - try: |
259 | | - from azure.ai.ml import MLClient # type: ignore |
260 | | - from azure.ai.ml.constants import AssetTypes # type: ignore |
261 | | - from azure.ai.ml.entities import Data # type: ignore |
262 | | - except ImportError as e: |
263 | | - raise ImportError( |
264 | | - "azure-ai-ml must be installed to use this function. Please install it using `pip install azure-ai-ml`" |
265 | | - ) from e |
266 | | - |
267 | | - data = Data( |
268 | | - path=str(file_path), |
269 | | - type=AssetTypes.URI_FILE, |
270 | | - name=str(uuid.uuid4()), # generating random name |
271 | | - is_anonymous=True, |
272 | | - version="1", |
273 | | - ) |
274 | | - |
275 | | - ml_client = MLClient( |
276 | | - self._config3.credential, |
277 | | - self._config3.subscription_id, |
278 | | - self._config3.resource_group_name, |
279 | | - self._config3.project_name, |
280 | | - ) |
281 | | - |
282 | | - data_asset = ml_client.data.create_or_update(data) |
283 | | - |
284 | | - return data_asset.id, data_asset.path |
285 | | - |
286 | | - @property |
287 | | - def scope(self) -> Dict[str, str]: |
288 | | - return { |
289 | | - "subscription_id": self._config3.subscription_id, |
290 | | - "resource_group_name": self._config3.resource_group_name, |
291 | | - "project_name": self._config3.project_name, |
292 | | - } |
293 | | - |
294 | | - |
295 | | -__all__: List[str] = [ |
296 | | - "AIProjectClient", |
297 | | -] # Add all objects you want publicly available to users at this package level |
| 12 | +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level |
298 | 13 |
|
299 | 14 |
|
300 | 15 | def patch_sdk(): |
|
0 commit comments