|
1 |
| -# pylint: disable=line-too-long,useless-suppression |
2 | 1 | # ------------------------------------
|
3 | 2 | # Copyright (c) Microsoft Corporation.
|
4 | 3 | # Licensed under the MIT License.
|
|
7 | 6 |
|
8 | 7 | Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
|
9 | 8 | """
|
10 |
| -import uuid |
11 |
| -from os import PathLike |
12 |
| -from pathlib import Path |
13 |
| -from typing import Any, Dict, List, Tuple, Union, Optional |
14 |
| -from typing_extensions import Self |
| 9 | +from typing import List |
15 | 10 |
|
16 |
| -from azure.core import PipelineClient |
17 |
| -from azure.core.credentials import TokenCredential |
18 |
| -from azure.core.pipeline import policies |
19 |
| - |
20 |
| -from ._client import AIProjectClient as ClientGenerated |
21 |
| -from ._configuration import AIProjectClientConfiguration |
22 |
| -from ._serialization import Deserializer, Serializer |
23 |
| -from .operations import AgentsOperations, ConnectionsOperations, EvaluationsOperations, TelemetryOperations |
24 |
| -from .operations._patch import InferenceOperations |
25 |
| - |
26 |
| - |
27 |
| -class AIProjectClient( |
28 |
| - ClientGenerated |
29 |
| -): # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes |
30 |
| - def __init__( # pylint: disable=super-init-not-called,too-many-statements |
31 |
| - self, |
32 |
| - endpoint: str, |
33 |
| - subscription_id: str, |
34 |
| - resource_group_name: str, |
35 |
| - project_name: str, |
36 |
| - credential: "TokenCredential", |
37 |
| - **kwargs: Any, |
38 |
| - ) -> None: |
39 |
| - # TODO: Validate input formats with regex match (e.g. subscription ID) |
40 |
| - if not endpoint: |
41 |
| - raise ValueError("endpoint is required") |
42 |
| - if not subscription_id: |
43 |
| - raise ValueError("subscription_id ID is required") |
44 |
| - if not resource_group_name: |
45 |
| - raise ValueError("resource_group_name is required") |
46 |
| - if not project_name: |
47 |
| - raise ValueError("project_name is required") |
48 |
| - if not credential: |
49 |
| - raise ValueError("credential is required") |
50 |
| - if "api_version" in kwargs: |
51 |
| - raise ValueError("No support for overriding the API version") |
52 |
| - if "credential_scopes" in kwargs: |
53 |
| - raise ValueError("No support for overriding the credential scopes") |
54 |
| - |
55 |
| - kwargs0 = kwargs.copy() |
56 |
| - kwargs1 = kwargs.copy() |
57 |
| - kwargs2 = kwargs.copy() |
58 |
| - kwargs3 = kwargs.copy() |
59 |
| - |
60 |
| - self._user_agent: Optional[str] = kwargs.get("user_agent", None) |
61 |
| - |
62 |
| - # For getting AppInsights connection string from the AppInsights resource. |
63 |
| - # The AppInsights resource URL is not known at this point. We need to get it from the |
64 |
| - # AzureML "Workspace - Get" REST API call. It will have the form: |
65 |
| - # https://management.azure.com/subscriptions/{appinsights_subscription_id}/resourceGroups/{appinsights_resource_group_name}/providers/microsoft.insights/components/{appinsights_resource_name} |
66 |
| - _endpoint0 = "https://management.azure.com" |
67 |
| - self._config0: AIProjectClientConfiguration = AIProjectClientConfiguration( |
68 |
| - endpoint=endpoint, |
69 |
| - subscription_id=subscription_id, |
70 |
| - resource_group_name=resource_group_name, |
71 |
| - project_name=project_name, |
72 |
| - credential=credential, |
73 |
| - api_version="2020-02-02", |
74 |
| - credential_scopes=["https://management.azure.com/.default"], |
75 |
| - **kwargs0, |
76 |
| - ) |
77 |
| - |
78 |
| - _policies0 = kwargs0.pop("policies", None) |
79 |
| - if _policies0 is None: |
80 |
| - _policies0 = [ |
81 |
| - policies.RequestIdPolicy(**kwargs0), |
82 |
| - self._config0.headers_policy, |
83 |
| - self._config0.user_agent_policy, |
84 |
| - self._config0.proxy_policy, |
85 |
| - policies.ContentDecodePolicy(**kwargs0), |
86 |
| - self._config0.redirect_policy, |
87 |
| - self._config0.retry_policy, |
88 |
| - self._config0.authentication_policy, |
89 |
| - self._config0.custom_hook_policy, |
90 |
| - self._config0.logging_policy, |
91 |
| - policies.DistributedTracingPolicy(**kwargs0), |
92 |
| - policies.SensitiveHeaderCleanupPolicy(**kwargs0) if self._config0.redirect_policy else None, |
93 |
| - self._config0.http_logging_policy, |
94 |
| - ] |
95 |
| - self._client0: PipelineClient = PipelineClient(base_url=_endpoint0, policies=_policies0, **kwargs0) |
96 |
| - |
97 |
| - # For Endpoints operations (listing connections, getting connection properties, getting project properties) |
98 |
| - _endpoint1 = ( |
99 |
| - "https://management.azure.com/" |
100 |
| - + f"subscriptions/{subscription_id}/" |
101 |
| - + f"resourceGroups/{resource_group_name}/" |
102 |
| - + "providers/Microsoft.MachineLearningServices/" |
103 |
| - + f"workspaces/{project_name}" |
104 |
| - ) |
105 |
| - self._config1: AIProjectClientConfiguration = AIProjectClientConfiguration( |
106 |
| - endpoint=endpoint, |
107 |
| - subscription_id=subscription_id, |
108 |
| - resource_group_name=resource_group_name, |
109 |
| - project_name=project_name, |
110 |
| - credential=credential, |
111 |
| - api_version="2024-07-01-preview", |
112 |
| - credential_scopes=["https://management.azure.com/.default"], |
113 |
| - **kwargs1, |
114 |
| - ) |
115 |
| - _policies1 = kwargs1.pop("policies", None) |
116 |
| - if _policies1 is None: |
117 |
| - _policies1 = [ |
118 |
| - policies.RequestIdPolicy(**kwargs1), |
119 |
| - self._config1.headers_policy, |
120 |
| - self._config1.user_agent_policy, |
121 |
| - self._config1.proxy_policy, |
122 |
| - policies.ContentDecodePolicy(**kwargs1), |
123 |
| - self._config1.redirect_policy, |
124 |
| - self._config1.retry_policy, |
125 |
| - self._config1.authentication_policy, |
126 |
| - self._config1.custom_hook_policy, |
127 |
| - self._config1.logging_policy, |
128 |
| - policies.DistributedTracingPolicy(**kwargs1), |
129 |
| - policies.SensitiveHeaderCleanupPolicy(**kwargs1) if self._config1.redirect_policy else None, |
130 |
| - self._config1.http_logging_policy, |
131 |
| - ] |
132 |
| - self._client1: PipelineClient = PipelineClient(base_url=_endpoint1, policies=_policies1, **kwargs1) |
133 |
| - |
134 |
| - # For Agents operations |
135 |
| - _endpoint2 = f"{endpoint}/agents/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
136 |
| - self._config2 = AIProjectClientConfiguration( |
137 |
| - endpoint=endpoint, |
138 |
| - subscription_id=subscription_id, |
139 |
| - resource_group_name=resource_group_name, |
140 |
| - project_name=project_name, |
141 |
| - credential=credential, |
142 |
| - api_version="2024-12-01-preview", |
143 |
| - credential_scopes=["https://ml.azure.com/.default"], |
144 |
| - **kwargs2, |
145 |
| - ) |
146 |
| - _policies2 = kwargs2.pop("policies", None) |
147 |
| - if _policies2 is None: |
148 |
| - _policies2 = [ |
149 |
| - policies.RequestIdPolicy(**kwargs2), |
150 |
| - self._config2.headers_policy, |
151 |
| - self._config2.user_agent_policy, |
152 |
| - self._config2.proxy_policy, |
153 |
| - policies.ContentDecodePolicy(**kwargs2), |
154 |
| - self._config2.redirect_policy, |
155 |
| - self._config2.retry_policy, |
156 |
| - self._config2.authentication_policy, |
157 |
| - self._config2.custom_hook_policy, |
158 |
| - self._config2.logging_policy, |
159 |
| - policies.DistributedTracingPolicy(**kwargs2), |
160 |
| - policies.SensitiveHeaderCleanupPolicy(**kwargs2) if self._config2.redirect_policy else None, |
161 |
| - self._config2.http_logging_policy, |
162 |
| - ] |
163 |
| - self._client2: PipelineClient = PipelineClient(base_url=_endpoint2, policies=_policies2, **kwargs2) |
164 |
| - |
165 |
| - # For Cloud Evaluations operations |
166 |
| - # cSpell:disable-next-line |
167 |
| - _endpoint3 = f"{endpoint}/raisvc/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
168 |
| - self._config3 = AIProjectClientConfiguration( |
169 |
| - endpoint=endpoint, |
170 |
| - subscription_id=subscription_id, |
171 |
| - resource_group_name=resource_group_name, |
172 |
| - project_name=project_name, |
173 |
| - credential=credential, |
174 |
| - api_version="2024-07-01-preview", # TODO: Update me |
175 |
| - credential_scopes=["https://ml.azure.com/.default"], # TODO: Update once service changes are ready |
176 |
| - **kwargs3, |
177 |
| - ) |
178 |
| - _policies3 = kwargs3.pop("policies", None) |
179 |
| - if _policies3 is None: |
180 |
| - _policies3 = [ |
181 |
| - policies.RequestIdPolicy(**kwargs3), |
182 |
| - self._config3.headers_policy, |
183 |
| - self._config3.user_agent_policy, |
184 |
| - self._config3.proxy_policy, |
185 |
| - policies.ContentDecodePolicy(**kwargs3), |
186 |
| - self._config3.redirect_policy, |
187 |
| - self._config3.retry_policy, |
188 |
| - self._config3.authentication_policy, |
189 |
| - self._config3.custom_hook_policy, |
190 |
| - self._config3.logging_policy, |
191 |
| - policies.DistributedTracingPolicy(**kwargs3), |
192 |
| - policies.SensitiveHeaderCleanupPolicy(**kwargs3) if self._config3.redirect_policy else None, |
193 |
| - self._config3.http_logging_policy, |
194 |
| - ] |
195 |
| - self._client3: PipelineClient = PipelineClient(base_url=_endpoint3, policies=_policies3, **kwargs3) |
196 |
| - |
197 |
| - self._serialize = Serializer() |
198 |
| - self._deserialize = Deserializer() |
199 |
| - self._serialize.client_side_validation = False |
200 |
| - |
201 |
| - self.telemetry = TelemetryOperations( |
202 |
| - self._client0, self._config0, self._serialize, self._deserialize, outer_instance=self |
203 |
| - ) |
204 |
| - self.connections = ConnectionsOperations(self._client1, self._config1, self._serialize, self._deserialize) |
205 |
| - self.agents = AgentsOperations(self._client2, self._config2, self._serialize, self._deserialize) |
206 |
| - self.evaluations = EvaluationsOperations(self._client3, self._config3, self._serialize, self._deserialize) |
207 |
| - self.inference = InferenceOperations(self) |
208 |
| - |
209 |
| - def close(self) -> None: |
210 |
| - self._client0.close() |
211 |
| - self._client1.close() |
212 |
| - self._client2.close() |
213 |
| - self._client3.close() |
214 |
| - |
215 |
| - def __enter__(self) -> Self: |
216 |
| - self._client0.__enter__() |
217 |
| - self._client1.__enter__() |
218 |
| - self._client2.__enter__() |
219 |
| - self._client3.__enter__() |
220 |
| - return self |
221 |
| - |
222 |
| - def __exit__(self, *exc_details: Any) -> None: |
223 |
| - self._client0.__exit__(*exc_details) |
224 |
| - self._client1.__exit__(*exc_details) |
225 |
| - self._client2.__exit__(*exc_details) |
226 |
| - self._client3.__exit__(*exc_details) |
227 |
| - |
228 |
| - @classmethod |
229 |
| - def from_connection_string(cls, conn_str: str, credential: "TokenCredential", **kwargs) -> Self: |
230 |
| - """ |
231 |
| - Create an AIProjectClient from a connection string. |
232 |
| -
|
233 |
| - :param str conn_str: The connection string, copied from your AI Foundry project. |
234 |
| - :param TokenCredential credential: Credential used to authenticate requests to the service. |
235 |
| - :return: An AIProjectClient instance. |
236 |
| - :rtype: AIProjectClient |
237 |
| - """ |
238 |
| - if not conn_str: |
239 |
| - raise ValueError("Connection string is required") |
240 |
| - parts = conn_str.split(";") |
241 |
| - if len(parts) != 4: |
242 |
| - raise ValueError("Invalid connection string format") |
243 |
| - endpoint = "https://" + parts[0] |
244 |
| - subscription_id = parts[1] |
245 |
| - resource_group_name = parts[2] |
246 |
| - project_name = parts[3] |
247 |
| - return cls(endpoint, subscription_id, resource_group_name, project_name, credential, **kwargs) |
248 |
| - |
249 |
| - def upload_file(self, file_path: Union[Path, str, PathLike]) -> Tuple[str, str]: |
250 |
| - """Upload a file to the Azure AI Foundry project. |
251 |
| - This method required *azure-ai-ml* to be installed. |
252 |
| -
|
253 |
| - :param file_path: The path to the file to upload. |
254 |
| - :type file_path: Union[str, Path, PathLike] |
255 |
| - :return: The tuple, containing asset id and asset URI of uploaded file. |
256 |
| - :rtype: Tuple[str] |
257 |
| - """ |
258 |
| - try: |
259 |
| - from azure.ai.ml import MLClient # type: ignore |
260 |
| - from azure.ai.ml.constants import AssetTypes # type: ignore |
261 |
| - from azure.ai.ml.entities import Data # type: ignore |
262 |
| - except ImportError as e: |
263 |
| - raise ImportError( |
264 |
| - "azure-ai-ml must be installed to use this function. Please install it using `pip install azure-ai-ml`" |
265 |
| - ) from e |
266 |
| - |
267 |
| - data = Data( |
268 |
| - path=str(file_path), |
269 |
| - type=AssetTypes.URI_FILE, |
270 |
| - name=str(uuid.uuid4()), # generating random name |
271 |
| - is_anonymous=True, |
272 |
| - version="1", |
273 |
| - ) |
274 |
| - |
275 |
| - ml_client = MLClient( |
276 |
| - self._config3.credential, |
277 |
| - self._config3.subscription_id, |
278 |
| - self._config3.resource_group_name, |
279 |
| - self._config3.project_name, |
280 |
| - ) |
281 |
| - |
282 |
| - data_asset = ml_client.data.create_or_update(data) |
283 |
| - |
284 |
| - return data_asset.id, data_asset.path |
285 |
| - |
286 |
| - @property |
287 |
| - def scope(self) -> Dict[str, str]: |
288 |
| - return { |
289 |
| - "subscription_id": self._config3.subscription_id, |
290 |
| - "resource_group_name": self._config3.resource_group_name, |
291 |
| - "project_name": self._config3.project_name, |
292 |
| - } |
293 |
| - |
294 |
| - |
295 |
| -__all__: List[str] = [ |
296 |
| - "AIProjectClient", |
297 |
| -] # Add all objects you want publicly available to users at this package level |
| 11 | +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level |
298 | 12 |
|
299 | 13 |
|
300 | 14 | def patch_sdk():
|
|
0 commit comments