Skip to content

Commit 935a4d9

Browse files
pintaoz-aws and pintaoz authored
Update list_pods to only display pods of corresponding endpoint type (#227)
* Update list_pods to only display pods of corresponding endpoint type
* Use list endpoints to check endpoint type

---------

Co-authored-by: pintaoz <[email protected]>
1 parent f571859 commit 935a4d9

File tree

6 files changed

+132
-36
lines changed

6 files changed

+132
-36
lines changed

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Dict, List, Optional
2020
from sagemaker_core.main.resources import Endpoint
2121
from pydantic import Field, ValidationError
22+
from kubernetes import client
2223

2324

2425
class HPEndpoint(_HPEndpoint, HPEndpointBase):
@@ -211,3 +212,34 @@ def validate_instance_type(self, instance_type: str):
211212
raise Exception(
212213
f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
213214
)
215+
216+
@classmethod
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
def list_pods(cls, namespace=None):
    """Return the names of pods that belong to custom inference endpoints.

    All pods in the namespace are listed, then narrowed down to those whose
    ``app`` label equals the name of a resource of kind
    ``INFERENCE_ENDPOINT_CONFIG_KIND`` found in the same namespace.

    Args:
        namespace: Namespace to search. When falsy, the value returned by
            ``get_default_namespace()`` is used.

    Returns:
        A list of pod names (possibly empty), in the order returned by the
        Kubernetes API.
    """
    cls.verify_kube_config()

    if not namespace:
        namespace = get_default_namespace()

    v1 = client.CoreV1Api()
    list_pods_response = v1.list_namespaced_pod(namespace=namespace)

    list_response = cls.call_list_api(
        kind=INFERENCE_ENDPOINT_CONFIG_KIND,
        namespace=namespace,
    )

    # Names of the custom endpoints that currently exist in the namespace.
    endpoints = set()
    if list_response and list_response["items"]:
        endpoints = {item["metadata"]["name"] for item in list_response["items"]}

    # list_namespaced_pod returns every pod in the namespace, so keep only the
    # pods whose "app" label matches one of the custom endpoints.
    # NOTE(review): this relies on endpoint pods carrying an "app" label equal
    # to the endpoint name - confirm against the operator that creates them.
    pods = []
    for item in list_pods_response.items:
        # A pod without any labels has metadata.labels set to None.
        labels = item.metadata.labels or {}
        if labels.get("app") in endpoints:
            pods.append(item.metadata.name)

    return pods

src/sagemaker/hyperpod/inference/hp_endpoint_base.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -209,23 +209,6 @@ def get_logs(
209209

210210
return logs
211211

212-
@classmethod
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
def list_pods(cls, namespace=None):
    """Return the name of every pod in the namespace.

    Args:
        namespace: Namespace to list. When falsy, the value returned by
            ``get_default_namespace()`` is used.

    Returns:
        A list with the name of each pod returned by the Kubernetes API.
    """
    cls.verify_kube_config()

    target_namespace = namespace if namespace else get_default_namespace()

    pod_list = client.CoreV1Api().list_namespaced_pod(namespace=target_namespace)

    return [pod.metadata.name for pod in pod_list.items]
228-
229212
@classmethod
230213
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_namespaces")
231214
def list_namespaces(cls):

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_hyperpod_telemetry_emitter,
2121
)
2222
from sagemaker.hyperpod.common.telemetry.constants import Feature
23+
from kubernetes import client
2324

2425

2526
class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase):
@@ -240,3 +241,34 @@ def validate_instance_type(self, model_id: str, instance_type: str):
240241
raise Exception(
241242
f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
242243
)
244+
245+
@classmethod
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
def list_pods(cls, namespace=None):
    """Return the names of pods that belong to JumpStart endpoints.

    All pods in the namespace are listed, then narrowed down to those whose
    ``app`` label equals the name of a resource of kind
    ``JUMPSTART_MODEL_KIND`` found in the same namespace.

    Args:
        namespace: Namespace to search. When falsy, the value returned by
            ``get_default_namespace()`` is used.

    Returns:
        A list of pod names (possibly empty), in the order returned by the
        Kubernetes API.
    """
    cls.verify_kube_config()

    if not namespace:
        namespace = get_default_namespace()

    v1 = client.CoreV1Api()
    list_pods_response = v1.list_namespaced_pod(namespace=namespace)

    list_response = cls.call_list_api(
        kind=JUMPSTART_MODEL_KIND,
        namespace=namespace,
    )

    # Names of the JumpStart endpoints that currently exist in the namespace.
    endpoints = set()
    if list_response and list_response["items"]:
        endpoints = {item["metadata"]["name"] for item in list_response["items"]}

    # list_namespaced_pod returns every pod in the namespace, so keep only the
    # pods whose "app" label matches one of the JumpStart endpoints.
    # NOTE(review): this relies on endpoint pods carrying an "app" label equal
    # to the endpoint name - confirm against the operator that creates them.
    pods = []
    for item in list_pods_response.items:
        # A pod without any labels has metadata.labels set to None.
        labels = item.metadata.labels or {}
        if labels.get("app") in endpoints:
            pods.append(item.metadata.name)

    return pods

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,37 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
194194
body={"input": "test"}, content_type="application/json"
195195
)
196196
self.assertEqual(result, "response")
197+
198+
@patch.object(HPEndpoint, "call_list_api")
@patch("kubernetes.client.CoreV1Api")
@patch.object(HPEndpoint, "verify_kube_config")
def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
    """list_pods returns only pods whose "app" label names a listed endpoint."""

    def make_pod(pod_name, app_label):
        # Minimal stand-in for a pod: only metadata.name and metadata.labels.
        pod = MagicMock()
        pod.metadata.name = pod_name
        pod.metadata.labels = {"app": app_label}
        return pod

    # Two pods of the custom endpoint plus one unrelated pod.
    pod_specs = [
        ("custom-endpoint-pod1", "custom-endpoint"),
        ("custom-endpoint-pod2", "custom-endpoint"),
        ("not-custom-endpoint-pod", "not-custom-endpoint"),
    ]
    core_api = mock_core_api.return_value
    core_api.list_namespaced_pod.return_value.items = [
        make_pod(pod_name, app_label) for pod_name, app_label in pod_specs
    ]

    # The endpoint listing reports a single endpoint named "custom-endpoint".
    mock_list_api.return_value = {
        "items": [{"metadata": {"name": "custom-endpoint"}}]
    }

    result = self.endpoint.list_pods(namespace="test-ns")

    self.assertEqual(result, ["custom-endpoint-pod1", "custom-endpoint-pod2"])
    core_api.list_namespaced_pod.assert_called_once_with(namespace="test-ns")

test/unit_tests/inference/test_hp_endpoint_base.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -109,25 +109,6 @@ def test_get_logs(self, mock_verify_config, mock_core_api):
109109
timestamps=True,
110110
)
111111

112-
@patch("kubernetes.client.CoreV1Api")
@patch.object(HPEndpointBase, "verify_kube_config")
def test_list_pods(self, mock_verify_config, mock_core_api):
    """list_pods returns the name of every pod reported for the namespace."""
    expected_names = ["pod1", "pod2"]

    # Build one mock pod per expected name; only metadata.name is set.
    fake_pods = []
    for pod_name in expected_names:
        pod = MagicMock()
        pod.metadata.name = pod_name
        fake_pods.append(pod)

    core_api = mock_core_api.return_value
    core_api.list_namespaced_pod.return_value.items = fake_pods

    result = self.base.list_pods(namespace="test-ns")

    self.assertEqual(result, ["pod1", "pod2"])
    core_api.list_namespaced_pod.assert_called_once_with(namespace="test-ns")
130-
131112
@patch("kubernetes.client.CoreV1Api")
132113
@patch.object(HPEndpointBase, "verify_kube_config")
133114
def test_list_namespaces(self, mock_verify_config, mock_core_api):

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,37 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
140140
body={"input": "test"}, content_type="application/json"
141141
)
142142
self.assertEqual(result, "response")
143+
144+
@patch.object(HPJumpStartEndpoint, "call_list_api")
@patch("kubernetes.client.CoreV1Api")
@patch.object(HPJumpStartEndpoint, "verify_kube_config")
def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
    """list_pods returns only pods whose "app" label names a listed endpoint."""

    def make_pod(pod_name, app_label):
        # Minimal stand-in for a pod: only metadata.name and metadata.labels.
        pod = MagicMock()
        pod.metadata.name = pod_name
        pod.metadata.labels = {"app": app_label}
        return pod

    # Two pods of the JumpStart endpoint plus one unrelated pod.
    pod_specs = [
        ("js-endpoint-pod1", "js-endpoint"),
        ("js-endpoint-pod2", "js-endpoint"),
        ("not-js-endpoint-pod", "not-js-endpoint"),
    ]
    core_api = mock_core_api.return_value
    core_api.list_namespaced_pod.return_value.items = [
        make_pod(pod_name, app_label) for pod_name, app_label in pod_specs
    ]

    # The endpoint listing reports a single endpoint named "js-endpoint".
    mock_list_api.return_value = {
        "items": [{"metadata": {"name": "js-endpoint"}}]
    }

    result = self.endpoint.list_pods(namespace="test-ns")

    self.assertEqual(result, ["js-endpoint-pod1", "js-endpoint-pod2"])
    core_api.list_namespaced_pod.assert_called_once_with(namespace="test-ns")

0 commit comments

Comments
 (0)