Skip to content

Commit 0c555dd

Browse files
authored
Add unit tests for endpoint entities (Azure#39604)
* add tests for kubenetes online endpoint * add tests for managed online endpoint * code fotmatting * add test for EndpointAuthKeys * add unit tests for online endpoint * code formatting * add unit tests for batch endpoint * add tests for identity validation * add tests for EndpointAuthToken * add equality tests for ManagedOnlineEndpoint * git fix failing test * add defaults to merge with
1 parent ba26689 commit 0c555dd

File tree

4 files changed

+376
-2
lines changed

4 files changed

+376
-2
lines changed

sdk/ml/azure-ai-ml/tests/batch_online_common/unittests/test_endpoint_entity.py

Lines changed: 298 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
import pytest
22
import yaml
3+
import json
4+
import copy
35
from test_utilities.utils import verify_entity_load_and_dump
4-
6+
from azure.ai.ml._restclient.v2022_02_01_preview.models import (
7+
OnlineEndpointData,
8+
EndpointAuthKeys as RestEndpointAuthKeys,
9+
EndpointAuthToken as RestEndpointAuthToken,
10+
)
11+
from azure.ai.ml._restclient.v2023_10_01.models import BatchEndpoint as BatchEndpointData
512
from azure.ai.ml import load_batch_endpoint, load_online_endpoint
6-
from azure.ai.ml.entities import BatchEndpoint, Endpoint, ManagedOnlineDeployment, OnlineEndpoint
13+
from azure.ai.ml.entities import (
14+
BatchEndpoint,
15+
ManagedOnlineEndpoint,
16+
KubernetesOnlineEndpoint,
17+
OnlineEndpoint,
18+
EndpointAuthKeys,
19+
EndpointAuthToken,
20+
)
21+
from azure.ai.ml.exceptions import ValidationException
722

823

924
@pytest.mark.production_experiences_test
@@ -12,6 +27,7 @@ class TestOnlineEndpointYAML:
1227
SIMPLE_ENDPOINT_WITH_BLUE_BAD = "tests/test_configs/endpoints/online/online_endpoint_create_aks_bad.yml"
1328
MINIMAL_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_minimal.yaml"
1429
MINIMAL_DEPLOYMENT = "tests/test_configs/deployments/online/online_endpoint_deployment_k8s_minimum.yml"
30+
ONLINE_ENDPOINT_REST = "tests/test_configs/endpoints/online/online_endpoint_rest.json"
1531

1632
def test_specific_endpoint_load_and_dump(self) -> None:
1733
with open(TestOnlineEndpointYAML.MINIMAL_ENDPOINT, "r") as f:
@@ -39,10 +55,41 @@ def test_online_endpoint_to_rest_object_with_no_issue(self) -> None:
3955
endpoint = load_online_endpoint(TestOnlineEndpointYAML.MINIMAL_ENDPOINT)
4056
endpoint._to_rest_online_endpoint("westus2")
4157

58+
def test_from_rest_object_kubenetes(self) -> None:
59+
with open(TestOnlineEndpointYAML.ONLINE_ENDPOINT_REST, "r") as f:
60+
online_deployment_rest = OnlineEndpointData.deserialize(json.load(f))
61+
online_endpoint = OnlineEndpoint._from_rest_object(online_deployment_rest)
62+
assert isinstance(online_endpoint, KubernetesOnlineEndpoint)
63+
assert online_endpoint.name == online_deployment_rest.name
64+
assert online_endpoint.compute == online_deployment_rest.properties.compute
65+
assert online_endpoint.tags == online_deployment_rest.tags
66+
assert online_endpoint.traffic == online_deployment_rest.properties.traffic
67+
assert online_endpoint.description == online_deployment_rest.properties.description
68+
assert online_endpoint.provisioning_state == online_deployment_rest.properties.provisioning_state
69+
assert online_endpoint.identity.type == "system_assigned"
70+
assert online_endpoint.identity.principal_id == online_deployment_rest.identity.principal_id
71+
assert online_endpoint.properties["createdBy"] == online_deployment_rest.system_data.created_by
72+
73+
def test_from_rest_object_managed(self) -> None:
74+
with open(TestOnlineEndpointYAML.ONLINE_ENDPOINT_REST, "r") as f:
75+
online_deployment_rest = OnlineEndpointData.deserialize(json.load(f))
76+
online_deployment_rest.properties.compute = None
77+
online_endpoint = OnlineEndpoint._from_rest_object(online_deployment_rest)
78+
assert isinstance(online_endpoint, ManagedOnlineEndpoint)
79+
assert online_endpoint.name == online_deployment_rest.name
80+
assert online_endpoint.tags == online_deployment_rest.tags
81+
assert online_endpoint.traffic == online_deployment_rest.properties.traffic
82+
assert online_endpoint.description == online_deployment_rest.properties.description
83+
assert online_endpoint.provisioning_state == online_deployment_rest.properties.provisioning_state
84+
assert online_endpoint.identity.type == "system_assigned"
85+
assert online_endpoint.identity.principal_id == online_deployment_rest.identity.principal_id
86+
assert online_endpoint.properties["createdBy"] == online_deployment_rest.system_data.created_by
87+
4288

4389
@pytest.mark.unittest
4490
class TestBatchEndpointYAML:
4591
BATCH_ENDPOINT_WITH_BLUE = "tests/test_configs/endpoints/batch/batch_endpoint.yaml"
92+
BATCH_ENDPOINT_REST = "tests/test_configs/endpoints/batch/batch_endpoint_rest.json"
4693

4794
def test_generic_endpoint_load_and_dump_2(self) -> None:
4895
with open(TestBatchEndpointYAML.BATCH_ENDPOINT_WITH_BLUE, "r") as f:
@@ -71,6 +118,32 @@ def test_to_rest_batch_endpoint(self) -> None:
71118
assert len(rest_batch_endpoint.tags)
72119
assert rest_batch_endpoint.tags == target["tags"]
73120

121+
def test_to_dict(self) -> None:
122+
endpoint = load_batch_endpoint(TestBatchEndpointYAML.BATCH_ENDPOINT_WITH_BLUE)
123+
endpoint_dict = endpoint._to_dict()
124+
125+
assert endpoint_dict["name"] == endpoint.name
126+
assert endpoint_dict["description"] == endpoint.description
127+
assert endpoint_dict["auth_mode"] == endpoint.auth_mode
128+
assert endpoint_dict["tags"] == endpoint.tags
129+
assert endpoint_dict["auth_mode"] == "aad_token"
130+
assert endpoint_dict["properties"] == endpoint.properties
131+
132+
def test_from_rest(self) -> None:
133+
with open(TestBatchEndpointYAML.BATCH_ENDPOINT_REST, "r") as f:
134+
batch_endpoint_rest = BatchEndpointData.deserialize(json.load(f))
135+
batch_endpoint = BatchEndpoint._from_rest_object(batch_endpoint_rest)
136+
assert batch_endpoint.name == batch_endpoint_rest.name
137+
assert batch_endpoint.id == batch_endpoint_rest.id
138+
assert batch_endpoint.tags == batch_endpoint_rest.tags
139+
assert batch_endpoint.properties == batch_endpoint_rest.properties.properties
140+
assert batch_endpoint.auth_mode == "aad_token"
141+
assert batch_endpoint.description == batch_endpoint_rest.properties.description
142+
assert batch_endpoint.location == batch_endpoint_rest.location
143+
assert batch_endpoint.provisioning_state == batch_endpoint_rest.properties.provisioning_state
144+
assert batch_endpoint.scoring_uri == batch_endpoint_rest.properties.scoring_uri
145+
assert batch_endpoint.openapi_uri == batch_endpoint_rest.properties.swagger_uri
146+
74147
def test_batch_endpoint_with_deployment_name_promoted_param_only(self) -> None:
75148
endpoint = BatchEndpoint(
76149
name="my-batch-endpoint",
@@ -103,3 +176,226 @@ def test_batch_endpoint_with_deployment_no_defaults(self) -> None:
103176
)
104177

105178
assert endpoint.defaults is None
179+
180+
181+
class TestKubernetesOnlineEndopint:
182+
K8S_ONLINE_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_create_k8s.yml"
183+
184+
def test_merge_with(self) -> None:
185+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
186+
other_online_endpoint = copy.deepcopy(online_endpoint)
187+
other_online_endpoint.compute = "k8ecompute"
188+
other_online_endpoint.tags = {"tag3": "value3"}
189+
other_online_endpoint.traffic = {"blue": 90, "green": 10}
190+
other_online_endpoint.description = "new description"
191+
other_online_endpoint.mirror_traffic = {"blue": 30}
192+
other_online_endpoint.auth_mode = "aml_token"
193+
other_online_endpoint.properties = {"some-prop": "value"}
194+
195+
online_endpoint._merge_with(other_online_endpoint)
196+
197+
assert isinstance(online_endpoint, KubernetesOnlineEndpoint)
198+
assert online_endpoint.compute == "k8ecompute"
199+
assert online_endpoint.tags == {"tag1": "value1", "tag2": "value2", "tag3": "value3"}
200+
assert online_endpoint.description == "new description"
201+
assert online_endpoint.traffic == {"blue": 90, "green": 10}
202+
assert online_endpoint.mirror_traffic == {"blue": 30}
203+
assert online_endpoint.auth_mode == "aml_token"
204+
assert online_endpoint.properties == {"some-prop": "value"}
205+
206+
def test_merge_with_throws_exception_when_name_masmatch(self) -> None:
207+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
208+
other_online_endpoint = copy.deepcopy(online_endpoint)
209+
other_online_endpoint.name = "new_name"
210+
211+
with pytest.raises(ValidationException) as ex:
212+
online_endpoint._merge_with(other_online_endpoint)
213+
assert (
214+
ex.value.exc_msg
215+
== "The endpoint name: k8se2etest and new_name are not matched when merging., NoneType: None"
216+
)
217+
218+
def test_to_rest_online_endpoint(self) -> None:
219+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
220+
online_endpoint.public_network_access = "Enabled"
221+
online_endpoint_rest = online_endpoint._to_rest_online_endpoint("westus2")
222+
assert online_endpoint_rest.tags == online_endpoint.tags
223+
assert online_endpoint_rest.properties.compute == online_endpoint.compute
224+
assert online_endpoint_rest.properties.traffic == online_endpoint.traffic
225+
assert online_endpoint_rest.properties.description == online_endpoint.description
226+
assert online_endpoint_rest.properties.mirror_traffic == online_endpoint.mirror_traffic
227+
assert online_endpoint_rest.properties.auth_mode.lower() == online_endpoint.auth_mode
228+
assert online_endpoint_rest.location == "westus2"
229+
assert online_endpoint_rest.identity.type == "SystemAssigned"
230+
assert online_endpoint_rest.properties.public_network_access == online_endpoint.public_network_access
231+
232+
def test_to_rest_online_endpoint_when_identity_none(self) -> None:
233+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
234+
online_endpoint.identity = None
235+
online_endpoint_rest = online_endpoint._to_rest_online_endpoint("westus2")
236+
assert online_endpoint_rest.tags == online_endpoint.tags
237+
assert online_endpoint_rest.properties.compute == online_endpoint.compute
238+
assert online_endpoint_rest.properties.traffic == online_endpoint.traffic
239+
assert online_endpoint_rest.properties.description == online_endpoint.description
240+
assert online_endpoint_rest.properties.mirror_traffic == online_endpoint.mirror_traffic
241+
assert online_endpoint_rest.properties.auth_mode.lower() == online_endpoint.auth_mode
242+
assert online_endpoint_rest.location == "westus2"
243+
assert online_endpoint_rest.identity.type == "SystemAssigned"
244+
245+
def test_to_rest_online_endpoint_raise_exception_identity_type_none(self) -> None:
246+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
247+
online_endpoint.identity.type = None
248+
with pytest.raises(ValidationException) as ex:
249+
online_endpoint._to_rest_online_endpoint("westus2")
250+
assert str(ex.value) == "Identity type not found in provided yaml file."
251+
252+
def test_to_rest_online_endpoint_traffic_update(self) -> None:
253+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
254+
online_endpoint_rest = online_endpoint._to_rest_online_endpoint_traffic_update("westus2")
255+
assert online_endpoint_rest.location == "westus2"
256+
assert online_endpoint_rest.tags == online_endpoint.tags
257+
assert online_endpoint_rest.identity.type == "system_assigned"
258+
assert online_endpoint_rest.properties.compute == online_endpoint.compute
259+
assert online_endpoint_rest.properties.description == online_endpoint.description
260+
assert online_endpoint_rest.properties.auth_mode.lower() == online_endpoint.auth_mode
261+
assert online_endpoint_rest.properties.traffic == online_endpoint.traffic
262+
263+
def test_to_dict(self) -> None:
264+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
265+
online_endpoint_dict = online_endpoint._to_dict()
266+
assert online_endpoint_dict["name"] == online_endpoint.name
267+
assert online_endpoint_dict["tags"] == online_endpoint.tags
268+
assert online_endpoint_dict["identity"]["type"] == online_endpoint.identity.type
269+
assert online_endpoint_dict["traffic"] == online_endpoint.traffic
270+
assert online_endpoint_dict["compute"] == "azureml:inferencecompute"
271+
272+
def test_dump(self) -> None:
273+
online_endpoint = load_online_endpoint(TestKubernetesOnlineEndopint.K8S_ONLINE_ENDPOINT)
274+
online_endpoint_dict = online_endpoint.dump()
275+
assert online_endpoint_dict["name"] == online_endpoint.name
276+
assert online_endpoint_dict["tags"] == online_endpoint.tags
277+
assert online_endpoint_dict["identity"]["type"] == online_endpoint.identity.type
278+
assert online_endpoint_dict["traffic"] == online_endpoint.traffic
279+
assert online_endpoint_dict["compute"] == "azureml:inferencecompute"
280+
281+
282+
class TestManagedOnlineEndpoint:
283+
ONLINE_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_create_mir_private.yml"
284+
BATCH_ENDPOINT_WITH_BLUE = "tests/test_configs/endpoints/batch/batch_endpoint.yaml"
285+
286+
def test_merge_with(self) -> None:
287+
online_endpoint = load_online_endpoint(TestManagedOnlineEndpoint.ONLINE_ENDPOINT)
288+
other_online_endpoint = copy.deepcopy(online_endpoint)
289+
other_online_endpoint.tags = {"tag3": "value3"}
290+
other_online_endpoint.traffic = {"blue": 90, "green": 10}
291+
other_online_endpoint.description = "new description"
292+
other_online_endpoint.mirror_traffic = {"blue": 30}
293+
other_online_endpoint.auth_mode = "aml_token"
294+
other_online_endpoint.defaults = {"deployment_name": "blue"}
295+
296+
online_endpoint._merge_with(other_online_endpoint)
297+
298+
assert isinstance(online_endpoint, ManagedOnlineEndpoint)
299+
assert online_endpoint.tags == {"dummy": "dummy", "endpointkey1": "newval1", "tag3": "value3"}
300+
assert online_endpoint.description == "new description"
301+
assert online_endpoint.traffic == {"blue": 90, "green": 10}
302+
assert online_endpoint.mirror_traffic == {"blue": 30}
303+
assert online_endpoint.auth_mode == "aml_token"
304+
assert online_endpoint.defaults == {"deployment_name": "blue"}
305+
306+
def test_merge_with_throws_exception_when_name_masmatch(self) -> None:
307+
online_endpoint = load_online_endpoint(TestManagedOnlineEndpoint.ONLINE_ENDPOINT)
308+
other_online_endpoint = copy.deepcopy(online_endpoint)
309+
other_online_endpoint.name = "new_name"
310+
311+
with pytest.raises(ValidationException) as ex:
312+
online_endpoint._merge_with(other_online_endpoint)
313+
assert (
314+
ex.value.exc_msg
315+
== "The endpoint name: mire2etest and new_name are not matched when merging., NoneType: None"
316+
)
317+
318+
def test_to_dict(self) -> None:
319+
online_endpoint = load_online_endpoint(TestManagedOnlineEndpoint.ONLINE_ENDPOINT)
320+
online_endpoint_dict = online_endpoint._to_dict()
321+
assert online_endpoint_dict["name"] == online_endpoint.name
322+
assert online_endpoint_dict["tags"] == online_endpoint.tags
323+
assert online_endpoint_dict["identity"]["type"] == online_endpoint.identity.type
324+
assert online_endpoint_dict["traffic"] == online_endpoint.traffic
325+
326+
def test_dump(self) -> None:
327+
online_endpoint = load_online_endpoint(TestManagedOnlineEndpoint.ONLINE_ENDPOINT)
328+
online_endpoint_dict = online_endpoint.dump()
329+
assert online_endpoint_dict["name"] == online_endpoint.name
330+
assert online_endpoint_dict["tags"] == online_endpoint.tags
331+
assert online_endpoint_dict["identity"]["type"] == online_endpoint.identity.type
332+
assert online_endpoint_dict["traffic"] == online_endpoint.traffic
333+
334+
def test_equality(self) -> None:
335+
online_endpoint = load_online_endpoint(TestManagedOnlineEndpoint.ONLINE_ENDPOINT)
336+
batch_online_endpoint = load_batch_endpoint(TestManagedOnlineEndpoint.BATCH_ENDPOINT_WITH_BLUE)
337+
338+
assert online_endpoint.__eq__(None)
339+
assert online_endpoint.__eq__(batch_online_endpoint)
340+
341+
other_online_endpoint = copy.deepcopy(online_endpoint)
342+
assert online_endpoint == other_online_endpoint
343+
assert not online_endpoint != other_online_endpoint
344+
345+
other_online_endpoint.auth_mode = None
346+
assert not online_endpoint == other_online_endpoint
347+
assert online_endpoint != other_online_endpoint
348+
349+
other_online_endpoint.auth_mode = online_endpoint.auth_mode
350+
other_online_endpoint.name = "new_name"
351+
assert not online_endpoint == other_online_endpoint
352+
353+
online_endpoint.name = None
354+
assert not online_endpoint == other_online_endpoint
355+
356+
other_online_endpoint.name = None
357+
assert online_endpoint == other_online_endpoint
358+
359+
360+
class TestEndpointAuthKeys:
361+
def test_to_rest_object(self) -> None:
362+
auth_keys = EndpointAuthKeys(primary_key="primary_key", secondary_key="secondary_key")
363+
auth_keys_rest = auth_keys._to_rest_object()
364+
assert auth_keys_rest.primary_key == "primary_key"
365+
assert auth_keys_rest.secondary_key == "secondary_key"
366+
367+
def test_from_rest_object(self) -> None:
368+
rest_auth_keys = RestEndpointAuthKeys(primary_key="primary_key", secondary_key="secondary_key")
369+
auth_keys = EndpointAuthKeys._from_rest_object(rest_auth_keys)
370+
assert auth_keys.primary_key == "primary_key"
371+
assert auth_keys.secondary_key == "secondary_key"
372+
373+
374+
class TestEndpointAuthToken:
375+
def test_to_rest_object(self) -> None:
376+
auth_token = (
377+
EndpointAuthToken(
378+
access_token="token",
379+
expiry_time_utc="2021-10-01T00:00:00Z",
380+
refresh_after_time_utc="2021-10-01T00:00:00Z",
381+
token_type="Bearer",
382+
),
383+
)
384+
auth_token_rest = auth_token[0]._to_rest_object()
385+
assert auth_token_rest.access_token == "token"
386+
assert auth_token_rest.expiry_time_utc == "2021-10-01T00:00:00Z"
387+
assert auth_token_rest.refresh_after_time_utc == "2021-10-01T00:00:00Z"
388+
assert auth_token_rest.token_type == "Bearer"
389+
390+
def test_from_rest_object(self) -> None:
391+
rest_auth_token = RestEndpointAuthToken(
392+
access_token="token",
393+
expiry_time_utc="2021-10-01T00:00:00Z",
394+
refresh_after_time_utc="2021-10-01T00:00:00Z",
395+
token_type="Bearer",
396+
)
397+
auth_token = EndpointAuthToken._from_rest_object(rest_auth_token)
398+
assert auth_token.access_token == "token"
399+
assert auth_token.expiry_time_utc == "2021-10-01T00:00:00Z"
400+
assert auth_token.refresh_after_time_utc == "2021-10-01T00:00:00Z"
401+
assert auth_token.token_type == "Bearer"
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{
2+
"id": "/subscriptions/some-sub-id/resourceGroups/some-rg/providers/Microsoft.MachineLearningServices/workspaces/some-ws/batchEndpoints/some-batch-endpoint-name",
3+
"name": "some-batch-endpoint-name",
4+
"type": "Microsoft.MachineLearningServices/workspaces/batchEndpoints",
5+
"properties": {
6+
"description": "A hello world endpoint for component deployments",
7+
"properties": {
8+
"BatchEndpointCreationApiVersion": "2023-10-01",
9+
"azureml.onlineendpointid": "/subscriptions/some-sub-id/resourceGroups/some-rg/providers/Microsoft.MachineLearningServices/workspaces/some-ws/batchEndpoints/some-batch-endpoint-name"
10+
},
11+
"scoringUri": "https://some-batch-endpoint-name.eastus.inference.ml.azure.com/jobs",
12+
"swaggerUri": null,
13+
"authMode": "AADToken",
14+
"defaults": {
15+
"deploymentName": "hello-world-1"
16+
},
17+
"provisioningState": "Succeeded"
18+
},
19+
"systemData": {
20+
"createdAt": "2025-01-29T06:39:10.8357986+00:00",
21+
"createdBy": "Someone",
22+
"lastModifiedAt": "2025-02-01T17:53:36.0766595+00:00"
23+
},
24+
"tags": {},
25+
"location": "eastus",
26+
"identity": {
27+
"type": "SystemAssigned",
28+
"principalId": "d6133098-1d22-4ea2-a875-709b47f277d1",
29+
"tenantId": "72f988bf-86f1-41af-91ab-2d7cd011db47"
30+
}
31+
}

0 commit comments

Comments
 (0)