|
6 | 6 |
|
7 | 7 | Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize |
8 | 8 | """ |
9 | | -import uuid |
10 | | -from os import PathLike |
11 | | -from pathlib import Path |
12 | | -from typing import Any, Dict, List, Tuple, Union, Optional |
13 | | -from typing_extensions import Self |
| 9 | +from typing import List |
14 | 10 |
|
15 | | -from azure.core import PipelineClient |
16 | | -from azure.core.credentials import TokenCredential |
17 | | -from azure.core.pipeline import policies |
18 | | - |
19 | | -from ._client import AIProjectClient as ClientGenerated |
20 | | -from ._configuration import AIProjectClientConfiguration |
21 | | -from ._serialization import Deserializer, Serializer |
22 | | -from .operations import AgentsOperations, ConnectionsOperations, EvaluationsOperations, TelemetryOperations |
23 | | -from .operations._patch import InferenceOperations |
24 | | - |
25 | | - |
26 | | -class AIProjectClient( |
27 | | - ClientGenerated |
28 | | -): # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes |
29 | | - def __init__( # pylint: disable=super-init-not-called,too-many-statements |
30 | | - self, |
31 | | - endpoint: str, |
32 | | - subscription_id: str, |
33 | | - resource_group_name: str, |
34 | | - project_name: str, |
35 | | - credential: "TokenCredential", |
36 | | - **kwargs: Any, |
37 | | - ) -> None: |
38 | | - # TODO: Validate input formats with regex match (e.g. subscription ID) |
39 | | - if not endpoint: |
40 | | - raise ValueError("endpoint is required") |
41 | | - if not subscription_id: |
42 | | - raise ValueError("subscription_id ID is required") |
43 | | - if not resource_group_name: |
44 | | - raise ValueError("resource_group_name is required") |
45 | | - if not project_name: |
46 | | - raise ValueError("project_name is required") |
47 | | - if not credential: |
48 | | - raise ValueError("credential is required") |
49 | | - if "api_version" in kwargs: |
50 | | - raise ValueError("No support for overriding the API version") |
51 | | - if "credential_scopes" in kwargs: |
52 | | - raise ValueError("No support for overriding the credential scopes") |
53 | | - |
54 | | - kwargs0 = kwargs.copy() |
55 | | - kwargs1 = kwargs.copy() |
56 | | - kwargs2 = kwargs.copy() |
57 | | - kwargs3 = kwargs.copy() |
58 | | - |
59 | | - self._user_agent: Optional[str] = kwargs.get("user_agent", None) |
60 | | - |
61 | | - # For getting AppInsights connection string from the AppInsights resource. |
62 | | - # The AppInsights resource URL is not known at this point. We need to get it from the |
63 | | - # AzureML "Workspace - Get" REST API call. It will have the form: |
64 | | - # https://management.azure.com/subscriptions/{appinsights_subscription_id}/resourceGroups/{appinsights_resource_group_name}/providers/microsoft.insights/components/{appinsights_resource_name} |
65 | | - _endpoint0 = "https://management.azure.com" |
66 | | - self._config0: AIProjectClientConfiguration = AIProjectClientConfiguration( |
67 | | - endpoint=endpoint, |
68 | | - subscription_id=subscription_id, |
69 | | - resource_group_name=resource_group_name, |
70 | | - project_name=project_name, |
71 | | - credential=credential, |
72 | | - api_version="2020-02-02", |
73 | | - credential_scopes=["https://management.azure.com/.default"], |
74 | | - **kwargs0, |
75 | | - ) |
76 | | - |
77 | | - _policies0 = kwargs0.pop("policies", None) |
78 | | - if _policies0 is None: |
79 | | - _policies0 = [ |
80 | | - policies.RequestIdPolicy(**kwargs0), |
81 | | - self._config0.headers_policy, |
82 | | - self._config0.user_agent_policy, |
83 | | - self._config0.proxy_policy, |
84 | | - policies.ContentDecodePolicy(**kwargs0), |
85 | | - self._config0.redirect_policy, |
86 | | - self._config0.retry_policy, |
87 | | - self._config0.authentication_policy, |
88 | | - self._config0.custom_hook_policy, |
89 | | - self._config0.logging_policy, |
90 | | - policies.DistributedTracingPolicy(**kwargs0), |
91 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs0) if self._config0.redirect_policy else None, |
92 | | - self._config0.http_logging_policy, |
93 | | - ] |
94 | | - self._client0: PipelineClient = PipelineClient(base_url=_endpoint0, policies=_policies0, **kwargs0) |
95 | | - |
96 | | - # For Endpoints operations (listing connections, getting connection properties, getting project properties) |
97 | | - _endpoint1 = ( |
98 | | - "https://management.azure.com/" |
99 | | - + f"subscriptions/{subscription_id}/" |
100 | | - + f"resourceGroups/{resource_group_name}/" |
101 | | - + "providers/Microsoft.MachineLearningServices/" |
102 | | - + f"workspaces/{project_name}" |
103 | | - ) |
104 | | - self._config1: AIProjectClientConfiguration = AIProjectClientConfiguration( |
105 | | - endpoint=endpoint, |
106 | | - subscription_id=subscription_id, |
107 | | - resource_group_name=resource_group_name, |
108 | | - project_name=project_name, |
109 | | - credential=credential, |
110 | | - api_version="2024-07-01-preview", |
111 | | - credential_scopes=["https://management.azure.com/.default"], |
112 | | - **kwargs1, |
113 | | - ) |
114 | | - _policies1 = kwargs1.pop("policies", None) |
115 | | - if _policies1 is None: |
116 | | - _policies1 = [ |
117 | | - policies.RequestIdPolicy(**kwargs1), |
118 | | - self._config1.headers_policy, |
119 | | - self._config1.user_agent_policy, |
120 | | - self._config1.proxy_policy, |
121 | | - policies.ContentDecodePolicy(**kwargs1), |
122 | | - self._config1.redirect_policy, |
123 | | - self._config1.retry_policy, |
124 | | - self._config1.authentication_policy, |
125 | | - self._config1.custom_hook_policy, |
126 | | - self._config1.logging_policy, |
127 | | - policies.DistributedTracingPolicy(**kwargs1), |
128 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs1) if self._config1.redirect_policy else None, |
129 | | - self._config1.http_logging_policy, |
130 | | - ] |
131 | | - self._client1: PipelineClient = PipelineClient(base_url=_endpoint1, policies=_policies1, **kwargs1) |
132 | | - |
133 | | - # For Agents operations |
134 | | - _endpoint2 = f"{endpoint}/agents/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
135 | | - self._config2 = AIProjectClientConfiguration( |
136 | | - endpoint=endpoint, |
137 | | - subscription_id=subscription_id, |
138 | | - resource_group_name=resource_group_name, |
139 | | - project_name=project_name, |
140 | | - credential=credential, |
141 | | - api_version="2024-12-01-preview", |
142 | | - credential_scopes=["https://ml.azure.com/.default"], |
143 | | - **kwargs2, |
144 | | - ) |
145 | | - _policies2 = kwargs2.pop("policies", None) |
146 | | - if _policies2 is None: |
147 | | - _policies2 = [ |
148 | | - policies.RequestIdPolicy(**kwargs2), |
149 | | - self._config2.headers_policy, |
150 | | - self._config2.user_agent_policy, |
151 | | - self._config2.proxy_policy, |
152 | | - policies.ContentDecodePolicy(**kwargs2), |
153 | | - self._config2.redirect_policy, |
154 | | - self._config2.retry_policy, |
155 | | - self._config2.authentication_policy, |
156 | | - self._config2.custom_hook_policy, |
157 | | - self._config2.logging_policy, |
158 | | - policies.DistributedTracingPolicy(**kwargs2), |
159 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs2) if self._config2.redirect_policy else None, |
160 | | - self._config2.http_logging_policy, |
161 | | - ] |
162 | | - self._client2: PipelineClient = PipelineClient(base_url=_endpoint2, policies=_policies2, **kwargs2) |
163 | | - |
164 | | - # For Cloud Evaluations operations |
165 | | - # cSpell:disable-next-line |
166 | | - _endpoint3 = f"{endpoint}/raisvc/v1.0/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.MachineLearningServices/workspaces/{project_name}" # pylint: disable=line-too-long |
167 | | - self._config3 = AIProjectClientConfiguration( |
168 | | - endpoint=endpoint, |
169 | | - subscription_id=subscription_id, |
170 | | - resource_group_name=resource_group_name, |
171 | | - project_name=project_name, |
172 | | - credential=credential, |
173 | | - api_version="2024-07-01-preview", # TODO: Update me |
174 | | - credential_scopes=["https://ml.azure.com/.default"], # TODO: Update once service changes are ready |
175 | | - **kwargs3, |
176 | | - ) |
177 | | - _policies3 = kwargs3.pop("policies", None) |
178 | | - if _policies3 is None: |
179 | | - _policies3 = [ |
180 | | - policies.RequestIdPolicy(**kwargs3), |
181 | | - self._config3.headers_policy, |
182 | | - self._config3.user_agent_policy, |
183 | | - self._config3.proxy_policy, |
184 | | - policies.ContentDecodePolicy(**kwargs3), |
185 | | - self._config3.redirect_policy, |
186 | | - self._config3.retry_policy, |
187 | | - self._config3.authentication_policy, |
188 | | - self._config3.custom_hook_policy, |
189 | | - self._config3.logging_policy, |
190 | | - policies.DistributedTracingPolicy(**kwargs3), |
191 | | - policies.SensitiveHeaderCleanupPolicy(**kwargs3) if self._config3.redirect_policy else None, |
192 | | - self._config3.http_logging_policy, |
193 | | - ] |
194 | | - self._client3: PipelineClient = PipelineClient(base_url=_endpoint3, policies=_policies3, **kwargs3) |
195 | | - |
196 | | - self._serialize = Serializer() |
197 | | - self._deserialize = Deserializer() |
198 | | - self._serialize.client_side_validation = False |
199 | | - |
200 | | - self.telemetry = TelemetryOperations( |
201 | | - self._client0, self._config0, self._serialize, self._deserialize, outer_instance=self |
202 | | - ) |
203 | | - self.connections = ConnectionsOperations(self._client1, self._config1, self._serialize, self._deserialize) |
204 | | - self.agents = AgentsOperations(self._client2, self._config2, self._serialize, self._deserialize) |
205 | | - self.evaluations = EvaluationsOperations(self._client3, self._config3, self._serialize, self._deserialize) |
206 | | - self.inference = InferenceOperations(self) |
207 | | - |
208 | | - def close(self) -> None: |
209 | | - self._client0.close() |
210 | | - self._client1.close() |
211 | | - self._client2.close() |
212 | | - self._client3.close() |
213 | | - |
214 | | - def __enter__(self) -> Self: |
215 | | - self._client0.__enter__() |
216 | | - self._client1.__enter__() |
217 | | - self._client2.__enter__() |
218 | | - self._client3.__enter__() |
219 | | - return self |
220 | | - |
221 | | - def __exit__(self, *exc_details: Any) -> None: |
222 | | - self._client0.__exit__(*exc_details) |
223 | | - self._client1.__exit__(*exc_details) |
224 | | - self._client2.__exit__(*exc_details) |
225 | | - self._client3.__exit__(*exc_details) |
226 | | - |
227 | | - @classmethod |
228 | | - def from_connection_string(cls, conn_str: str, credential: "TokenCredential", **kwargs) -> Self: |
229 | | - """ |
230 | | - Create an AIProjectClient from a connection string. |
231 | | -
|
232 | | - :param str conn_str: The connection string, copied from your AI Foundry project. |
233 | | - :param TokenCredential credential: Credential used to authenticate requests to the service. |
234 | | - :return: An AIProjectClient instance. |
235 | | - :rtype: AIProjectClient |
236 | | - """ |
237 | | - if not conn_str: |
238 | | - raise ValueError("Connection string is required") |
239 | | - parts = conn_str.split(";") |
240 | | - if len(parts) != 4: |
241 | | - raise ValueError("Invalid connection string format") |
242 | | - endpoint = "https://" + parts[0] |
243 | | - subscription_id = parts[1] |
244 | | - resource_group_name = parts[2] |
245 | | - project_name = parts[3] |
246 | | - return cls(endpoint, subscription_id, resource_group_name, project_name, credential, **kwargs) |
247 | | - |
248 | | - def upload_file(self, file_path: Union[Path, str, PathLike]) -> Tuple[str, str]: |
249 | | - """Upload a file to the Azure AI Foundry project. |
250 | | - This method required *azure-ai-ml* to be installed. |
251 | | -
|
252 | | - :param file_path: The path to the file to upload. |
253 | | - :type file_path: Union[str, Path, PathLike] |
254 | | - :return: The tuple, containing asset id and asset URI of uploaded file. |
255 | | - :rtype: Tuple[str] |
256 | | - """ |
257 | | - try: |
258 | | - from azure.ai.ml import MLClient # type: ignore |
259 | | - from azure.ai.ml.constants import AssetTypes # type: ignore |
260 | | - from azure.ai.ml.entities import Data # type: ignore |
261 | | - except ImportError as e: |
262 | | - raise ImportError( |
263 | | - "azure-ai-ml must be installed to use this function. Please install it using `pip install azure-ai-ml`" |
264 | | - ) from e |
265 | | - |
266 | | - data = Data( |
267 | | - path=str(file_path), |
268 | | - type=AssetTypes.URI_FILE, |
269 | | - name=str(uuid.uuid4()), # generating random name |
270 | | - is_anonymous=True, |
271 | | - version="1", |
272 | | - ) |
273 | | - |
274 | | - ml_client = MLClient( |
275 | | - self._config3.credential, |
276 | | - self._config3.subscription_id, |
277 | | - self._config3.resource_group_name, |
278 | | - self._config3.project_name, |
279 | | - ) |
280 | | - |
281 | | - data_asset = ml_client.data.create_or_update(data) |
282 | | - |
283 | | - return data_asset.id, data_asset.path |
284 | | - |
285 | | - @property |
286 | | - def scope(self) -> Dict[str, str]: |
287 | | - return { |
288 | | - "subscription_id": self._config3.subscription_id, |
289 | | - "resource_group_name": self._config3.resource_group_name, |
290 | | - "project_name": self._config3.project_name, |
291 | | - } |
292 | | - |
293 | | - |
294 | | -__all__: List[str] = [ |
295 | | - "AIProjectClient", |
296 | | -] # Add all objects you want publicly available to users at this package level |
| 11 | +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level |
297 | 12 |
|
298 | 13 |
|
299 | 14 | def patch_sdk(): |
|
0 commit comments