11import pytest
22import yaml
3+ import json
4+ import copy
35from test_utilities .utils import verify_entity_load_and_dump
4-
6+ from azure .ai .ml ._restclient .v2022_02_01_preview .models import (
7+ OnlineEndpointData ,
8+ EndpointAuthKeys as RestEndpointAuthKeys ,
9+ EndpointAuthToken as RestEndpointAuthToken ,
10+ )
11+ from azure .ai .ml ._restclient .v2023_10_01 .models import BatchEndpoint as BatchEndpointData
512from azure .ai .ml import load_batch_endpoint , load_online_endpoint
6- from azure .ai .ml .entities import BatchEndpoint , Endpoint , ManagedOnlineDeployment , OnlineEndpoint
13+ from azure .ai .ml .entities import (
14+ BatchEndpoint ,
15+ ManagedOnlineEndpoint ,
16+ KubernetesOnlineEndpoint ,
17+ OnlineEndpoint ,
18+ EndpointAuthKeys ,
19+ EndpointAuthToken ,
20+ )
21+ from azure .ai .ml .exceptions import ValidationException
722
823
924@pytest .mark .production_experiences_test
@@ -12,6 +27,7 @@ class TestOnlineEndpointYAML:
1227 SIMPLE_ENDPOINT_WITH_BLUE_BAD = "tests/test_configs/endpoints/online/online_endpoint_create_aks_bad.yml"
1328 MINIMAL_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_minimal.yaml"
1429 MINIMAL_DEPLOYMENT = "tests/test_configs/deployments/online/online_endpoint_deployment_k8s_minimum.yml"
30+ ONLINE_ENDPOINT_REST = "tests/test_configs/endpoints/online/online_endpoint_rest.json"
1531
1632 def test_specific_endpoint_load_and_dump (self ) -> None :
1733 with open (TestOnlineEndpointYAML .MINIMAL_ENDPOINT , "r" ) as f :
@@ -39,10 +55,41 @@ def test_online_endpoint_to_rest_object_with_no_issue(self) -> None:
3955 endpoint = load_online_endpoint (TestOnlineEndpointYAML .MINIMAL_ENDPOINT )
4056 endpoint ._to_rest_online_endpoint ("westus2" )
4157
58+ def test_from_rest_object_kubenetes (self ) -> None :
59+ with open (TestOnlineEndpointYAML .ONLINE_ENDPOINT_REST , "r" ) as f :
60+ online_deployment_rest = OnlineEndpointData .deserialize (json .load (f ))
61+ online_endpoint = OnlineEndpoint ._from_rest_object (online_deployment_rest )
62+ assert isinstance (online_endpoint , KubernetesOnlineEndpoint )
63+ assert online_endpoint .name == online_deployment_rest .name
64+ assert online_endpoint .compute == online_deployment_rest .properties .compute
65+ assert online_endpoint .tags == online_deployment_rest .tags
66+ assert online_endpoint .traffic == online_deployment_rest .properties .traffic
67+ assert online_endpoint .description == online_deployment_rest .properties .description
68+ assert online_endpoint .provisioning_state == online_deployment_rest .properties .provisioning_state
69+ assert online_endpoint .identity .type == "system_assigned"
70+ assert online_endpoint .identity .principal_id == online_deployment_rest .identity .principal_id
71+ assert online_endpoint .properties ["createdBy" ] == online_deployment_rest .system_data .created_by
72+
73+ def test_from_rest_object_managed (self ) -> None :
74+ with open (TestOnlineEndpointYAML .ONLINE_ENDPOINT_REST , "r" ) as f :
75+ online_deployment_rest = OnlineEndpointData .deserialize (json .load (f ))
76+ online_deployment_rest .properties .compute = None
77+ online_endpoint = OnlineEndpoint ._from_rest_object (online_deployment_rest )
78+ assert isinstance (online_endpoint , ManagedOnlineEndpoint )
79+ assert online_endpoint .name == online_deployment_rest .name
80+ assert online_endpoint .tags == online_deployment_rest .tags
81+ assert online_endpoint .traffic == online_deployment_rest .properties .traffic
82+ assert online_endpoint .description == online_deployment_rest .properties .description
83+ assert online_endpoint .provisioning_state == online_deployment_rest .properties .provisioning_state
84+ assert online_endpoint .identity .type == "system_assigned"
85+ assert online_endpoint .identity .principal_id == online_deployment_rest .identity .principal_id
86+ assert online_endpoint .properties ["createdBy" ] == online_deployment_rest .system_data .created_by
87+
4288
4389@pytest .mark .unittest
4490class TestBatchEndpointYAML :
4591 BATCH_ENDPOINT_WITH_BLUE = "tests/test_configs/endpoints/batch/batch_endpoint.yaml"
92+ BATCH_ENDPOINT_REST = "tests/test_configs/endpoints/batch/batch_endpoint_rest.json"
4693
4794 def test_generic_endpoint_load_and_dump_2 (self ) -> None :
4895 with open (TestBatchEndpointYAML .BATCH_ENDPOINT_WITH_BLUE , "r" ) as f :
@@ -71,6 +118,32 @@ def test_to_rest_batch_endpoint(self) -> None:
71118 assert len (rest_batch_endpoint .tags )
72119 assert rest_batch_endpoint .tags == target ["tags" ]
73120
121+ def test_to_dict (self ) -> None :
122+ endpoint = load_batch_endpoint (TestBatchEndpointYAML .BATCH_ENDPOINT_WITH_BLUE )
123+ endpoint_dict = endpoint ._to_dict ()
124+
125+ assert endpoint_dict ["name" ] == endpoint .name
126+ assert endpoint_dict ["description" ] == endpoint .description
127+ assert endpoint_dict ["auth_mode" ] == endpoint .auth_mode
128+ assert endpoint_dict ["tags" ] == endpoint .tags
129+ assert endpoint_dict ["auth_mode" ] == "aad_token"
130+ assert endpoint_dict ["properties" ] == endpoint .properties
131+
132+ def test_from_rest (self ) -> None :
133+ with open (TestBatchEndpointYAML .BATCH_ENDPOINT_REST , "r" ) as f :
134+ batch_endpoint_rest = BatchEndpointData .deserialize (json .load (f ))
135+ batch_endpoint = BatchEndpoint ._from_rest_object (batch_endpoint_rest )
136+ assert batch_endpoint .name == batch_endpoint_rest .name
137+ assert batch_endpoint .id == batch_endpoint_rest .id
138+ assert batch_endpoint .tags == batch_endpoint_rest .tags
139+ assert batch_endpoint .properties == batch_endpoint_rest .properties .properties
140+ assert batch_endpoint .auth_mode == "aad_token"
141+ assert batch_endpoint .description == batch_endpoint_rest .properties .description
142+ assert batch_endpoint .location == batch_endpoint_rest .location
143+ assert batch_endpoint .provisioning_state == batch_endpoint_rest .properties .provisioning_state
144+ assert batch_endpoint .scoring_uri == batch_endpoint_rest .properties .scoring_uri
145+ assert batch_endpoint .openapi_uri == batch_endpoint_rest .properties .swagger_uri
146+
74147 def test_batch_endpoint_with_deployment_name_promoted_param_only (self ) -> None :
75148 endpoint = BatchEndpoint (
76149 name = "my-batch-endpoint" ,
@@ -103,3 +176,226 @@ def test_batch_endpoint_with_deployment_no_defaults(self) -> None:
103176 )
104177
105178 assert endpoint .defaults is None
179+
180+
181+ class TestKubernetesOnlineEndopint :
182+ K8S_ONLINE_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_create_k8s.yml"
183+
184+ def test_merge_with (self ) -> None :
185+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
186+ other_online_endpoint = copy .deepcopy (online_endpoint )
187+ other_online_endpoint .compute = "k8ecompute"
188+ other_online_endpoint .tags = {"tag3" : "value3" }
189+ other_online_endpoint .traffic = {"blue" : 90 , "green" : 10 }
190+ other_online_endpoint .description = "new description"
191+ other_online_endpoint .mirror_traffic = {"blue" : 30 }
192+ other_online_endpoint .auth_mode = "aml_token"
193+ other_online_endpoint .properties = {"some-prop" : "value" }
194+
195+ online_endpoint ._merge_with (other_online_endpoint )
196+
197+ assert isinstance (online_endpoint , KubernetesOnlineEndpoint )
198+ assert online_endpoint .compute == "k8ecompute"
199+ assert online_endpoint .tags == {"tag1" : "value1" , "tag2" : "value2" , "tag3" : "value3" }
200+ assert online_endpoint .description == "new description"
201+ assert online_endpoint .traffic == {"blue" : 90 , "green" : 10 }
202+ assert online_endpoint .mirror_traffic == {"blue" : 30 }
203+ assert online_endpoint .auth_mode == "aml_token"
204+ assert online_endpoint .properties == {"some-prop" : "value" }
205+
206+ def test_merge_with_throws_exception_when_name_masmatch (self ) -> None :
207+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
208+ other_online_endpoint = copy .deepcopy (online_endpoint )
209+ other_online_endpoint .name = "new_name"
210+
211+ with pytest .raises (ValidationException ) as ex :
212+ online_endpoint ._merge_with (other_online_endpoint )
213+ assert (
214+ ex .value .exc_msg
215+ == "The endpoint name: k8se2etest and new_name are not matched when merging., NoneType: None"
216+ )
217+
218+ def test_to_rest_online_endpoint (self ) -> None :
219+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
220+ online_endpoint .public_network_access = "Enabled"
221+ online_endpoint_rest = online_endpoint ._to_rest_online_endpoint ("westus2" )
222+ assert online_endpoint_rest .tags == online_endpoint .tags
223+ assert online_endpoint_rest .properties .compute == online_endpoint .compute
224+ assert online_endpoint_rest .properties .traffic == online_endpoint .traffic
225+ assert online_endpoint_rest .properties .description == online_endpoint .description
226+ assert online_endpoint_rest .properties .mirror_traffic == online_endpoint .mirror_traffic
227+ assert online_endpoint_rest .properties .auth_mode .lower () == online_endpoint .auth_mode
228+ assert online_endpoint_rest .location == "westus2"
229+ assert online_endpoint_rest .identity .type == "SystemAssigned"
230+ assert online_endpoint_rest .properties .public_network_access == online_endpoint .public_network_access
231+
232+ def test_to_rest_online_endpoint_when_identity_none (self ) -> None :
233+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
234+ online_endpoint .identity = None
235+ online_endpoint_rest = online_endpoint ._to_rest_online_endpoint ("westus2" )
236+ assert online_endpoint_rest .tags == online_endpoint .tags
237+ assert online_endpoint_rest .properties .compute == online_endpoint .compute
238+ assert online_endpoint_rest .properties .traffic == online_endpoint .traffic
239+ assert online_endpoint_rest .properties .description == online_endpoint .description
240+ assert online_endpoint_rest .properties .mirror_traffic == online_endpoint .mirror_traffic
241+ assert online_endpoint_rest .properties .auth_mode .lower () == online_endpoint .auth_mode
242+ assert online_endpoint_rest .location == "westus2"
243+ assert online_endpoint_rest .identity .type == "SystemAssigned"
244+
245+ def test_to_rest_online_endpoint_raise_exception_identity_type_none (self ) -> None :
246+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
247+ online_endpoint .identity .type = None
248+ with pytest .raises (ValidationException ) as ex :
249+ online_endpoint ._to_rest_online_endpoint ("westus2" )
250+ assert str (ex .value ) == "Identity type not found in provided yaml file."
251+
252+ def test_to_rest_online_endpoint_traffic_update (self ) -> None :
253+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
254+ online_endpoint_rest = online_endpoint ._to_rest_online_endpoint_traffic_update ("westus2" )
255+ assert online_endpoint_rest .location == "westus2"
256+ assert online_endpoint_rest .tags == online_endpoint .tags
257+ assert online_endpoint_rest .identity .type == "system_assigned"
258+ assert online_endpoint_rest .properties .compute == online_endpoint .compute
259+ assert online_endpoint_rest .properties .description == online_endpoint .description
260+ assert online_endpoint_rest .properties .auth_mode .lower () == online_endpoint .auth_mode
261+ assert online_endpoint_rest .properties .traffic == online_endpoint .traffic
262+
263+ def test_to_dict (self ) -> None :
264+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
265+ online_endpoint_dict = online_endpoint ._to_dict ()
266+ assert online_endpoint_dict ["name" ] == online_endpoint .name
267+ assert online_endpoint_dict ["tags" ] == online_endpoint .tags
268+ assert online_endpoint_dict ["identity" ]["type" ] == online_endpoint .identity .type
269+ assert online_endpoint_dict ["traffic" ] == online_endpoint .traffic
270+ assert online_endpoint_dict ["compute" ] == "azureml:inferencecompute"
271+
272+ def test_dump (self ) -> None :
273+ online_endpoint = load_online_endpoint (TestKubernetesOnlineEndopint .K8S_ONLINE_ENDPOINT )
274+ online_endpoint_dict = online_endpoint .dump ()
275+ assert online_endpoint_dict ["name" ] == online_endpoint .name
276+ assert online_endpoint_dict ["tags" ] == online_endpoint .tags
277+ assert online_endpoint_dict ["identity" ]["type" ] == online_endpoint .identity .type
278+ assert online_endpoint_dict ["traffic" ] == online_endpoint .traffic
279+ assert online_endpoint_dict ["compute" ] == "azureml:inferencecompute"
280+
281+
282+ class TestManagedOnlineEndpoint :
283+ ONLINE_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_create_mir_private.yml"
284+ BATCH_ENDPOINT_WITH_BLUE = "tests/test_configs/endpoints/batch/batch_endpoint.yaml"
285+
286+ def test_merge_with (self ) -> None :
287+ online_endpoint = load_online_endpoint (TestManagedOnlineEndpoint .ONLINE_ENDPOINT )
288+ other_online_endpoint = copy .deepcopy (online_endpoint )
289+ other_online_endpoint .tags = {"tag3" : "value3" }
290+ other_online_endpoint .traffic = {"blue" : 90 , "green" : 10 }
291+ other_online_endpoint .description = "new description"
292+ other_online_endpoint .mirror_traffic = {"blue" : 30 }
293+ other_online_endpoint .auth_mode = "aml_token"
294+ other_online_endpoint .defaults = {"deployment_name" : "blue" }
295+
296+ online_endpoint ._merge_with (other_online_endpoint )
297+
298+ assert isinstance (online_endpoint , ManagedOnlineEndpoint )
299+ assert online_endpoint .tags == {"dummy" : "dummy" , "endpointkey1" : "newval1" , "tag3" : "value3" }
300+ assert online_endpoint .description == "new description"
301+ assert online_endpoint .traffic == {"blue" : 90 , "green" : 10 }
302+ assert online_endpoint .mirror_traffic == {"blue" : 30 }
303+ assert online_endpoint .auth_mode == "aml_token"
304+ assert online_endpoint .defaults == {"deployment_name" : "blue" }
305+
306+ def test_merge_with_throws_exception_when_name_masmatch (self ) -> None :
307+ online_endpoint = load_online_endpoint (TestManagedOnlineEndpoint .ONLINE_ENDPOINT )
308+ other_online_endpoint = copy .deepcopy (online_endpoint )
309+ other_online_endpoint .name = "new_name"
310+
311+ with pytest .raises (ValidationException ) as ex :
312+ online_endpoint ._merge_with (other_online_endpoint )
313+ assert (
314+ ex .value .exc_msg
315+ == "The endpoint name: mire2etest and new_name are not matched when merging., NoneType: None"
316+ )
317+
318+ def test_to_dict (self ) -> None :
319+ online_endpoint = load_online_endpoint (TestManagedOnlineEndpoint .ONLINE_ENDPOINT )
320+ online_endpoint_dict = online_endpoint ._to_dict ()
321+ assert online_endpoint_dict ["name" ] == online_endpoint .name
322+ assert online_endpoint_dict ["tags" ] == online_endpoint .tags
323+ assert online_endpoint_dict ["identity" ]["type" ] == online_endpoint .identity .type
324+ assert online_endpoint_dict ["traffic" ] == online_endpoint .traffic
325+
326+ def test_dump (self ) -> None :
327+ online_endpoint = load_online_endpoint (TestManagedOnlineEndpoint .ONLINE_ENDPOINT )
328+ online_endpoint_dict = online_endpoint .dump ()
329+ assert online_endpoint_dict ["name" ] == online_endpoint .name
330+ assert online_endpoint_dict ["tags" ] == online_endpoint .tags
331+ assert online_endpoint_dict ["identity" ]["type" ] == online_endpoint .identity .type
332+ assert online_endpoint_dict ["traffic" ] == online_endpoint .traffic
333+
334+ def test_equality (self ) -> None :
335+ online_endpoint = load_online_endpoint (TestManagedOnlineEndpoint .ONLINE_ENDPOINT )
336+ batch_online_endpoint = load_batch_endpoint (TestManagedOnlineEndpoint .BATCH_ENDPOINT_WITH_BLUE )
337+
338+ assert online_endpoint .__eq__ (None )
339+ assert online_endpoint .__eq__ (batch_online_endpoint )
340+
341+ other_online_endpoint = copy .deepcopy (online_endpoint )
342+ assert online_endpoint == other_online_endpoint
343+ assert not online_endpoint != other_online_endpoint
344+
345+ other_online_endpoint .auth_mode = None
346+ assert not online_endpoint == other_online_endpoint
347+ assert online_endpoint != other_online_endpoint
348+
349+ other_online_endpoint .auth_mode = online_endpoint .auth_mode
350+ other_online_endpoint .name = "new_name"
351+ assert not online_endpoint == other_online_endpoint
352+
353+ online_endpoint .name = None
354+ assert not online_endpoint == other_online_endpoint
355+
356+ other_online_endpoint .name = None
357+ assert online_endpoint == other_online_endpoint
358+
359+
360+ class TestEndpointAuthKeys :
361+ def test_to_rest_object (self ) -> None :
362+ auth_keys = EndpointAuthKeys (primary_key = "primary_key" , secondary_key = "secondary_key" )
363+ auth_keys_rest = auth_keys ._to_rest_object ()
364+ assert auth_keys_rest .primary_key == "primary_key"
365+ assert auth_keys_rest .secondary_key == "secondary_key"
366+
367+ def test_from_rest_object (self ) -> None :
368+ rest_auth_keys = RestEndpointAuthKeys (primary_key = "primary_key" , secondary_key = "secondary_key" )
369+ auth_keys = EndpointAuthKeys ._from_rest_object (rest_auth_keys )
370+ assert auth_keys .primary_key == "primary_key"
371+ assert auth_keys .secondary_key == "secondary_key"
372+
373+
374+ class TestEndpointAuthToken :
375+ def test_to_rest_object (self ) -> None :
376+ auth_token = (
377+ EndpointAuthToken (
378+ access_token = "token" ,
379+ expiry_time_utc = "2021-10-01T00:00:00Z" ,
380+ refresh_after_time_utc = "2021-10-01T00:00:00Z" ,
381+ token_type = "Bearer" ,
382+ ),
383+ )
384+ auth_token_rest = auth_token [0 ]._to_rest_object ()
385+ assert auth_token_rest .access_token == "token"
386+ assert auth_token_rest .expiry_time_utc == "2021-10-01T00:00:00Z"
387+ assert auth_token_rest .refresh_after_time_utc == "2021-10-01T00:00:00Z"
388+ assert auth_token_rest .token_type == "Bearer"
389+
390+ def test_from_rest_object (self ) -> None :
391+ rest_auth_token = RestEndpointAuthToken (
392+ access_token = "token" ,
393+ expiry_time_utc = "2021-10-01T00:00:00Z" ,
394+ refresh_after_time_utc = "2021-10-01T00:00:00Z" ,
395+ token_type = "Bearer" ,
396+ )
397+ auth_token = EndpointAuthToken ._from_rest_object (rest_auth_token )
398+ assert auth_token .access_token == "token"
399+ assert auth_token .expiry_time_utc == "2021-10-01T00:00:00Z"
400+ assert auth_token .refresh_after_time_utc == "2021-10-01T00:00:00Z"
401+ assert auth_token .token_type == "Bearer"
0 commit comments