Skip to content

Commit 8f1824c

Browse files
author
pintaoz
committed
Use list endpoints to check endpoint type
1 parent 3482b14 commit 8f1824c

File tree

4 files changed

+44
-36
lines changed

4 files changed

+44
-36
lines changed

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,21 +222,24 @@ def list_pods(cls, namespace=None):
222222
namespace = get_default_namespace()
223223

224224
v1 = client.CoreV1Api()
225-
response = v1.list_namespaced_pod(namespace=namespace)
225+
list_pods_response = v1.list_namespaced_pod(namespace=namespace)
226+
227+
list_response = cls.call_list_api(
228+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
229+
namespace=namespace,
230+
)
231+
232+
endpoints = set()
233+
if list_response and list_response["items"]:
234+
for item in list_response["items"]:
235+
endpoints.add(item["metadata"]["name"])
226236

227237
pods = []
228-
for item in response.items:
238+
for item in list_pods_response.items:
229239
app_name = item.metadata.labels.get("app", None)
230-
try:
240+
if app_name in endpoints:
231241
# list_namespaced_pod will return all pods in the namespace, so we need to filter
232242
# down to only the pods that are created by a custom endpoint
233-
cls.call_get_api(
234-
name=app_name,
235-
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
236-
namespace=namespace,
237-
)
238243
pods.append(item.metadata.name)
239-
except Exception:
240-
continue
241244

242245
return pods

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,21 +251,24 @@ def list_pods(cls, namespace=None):
251251
namespace = get_default_namespace()
252252

253253
v1 = client.CoreV1Api()
254-
response = v1.list_namespaced_pod(namespace=namespace)
254+
list_pods_response = v1.list_namespaced_pod(namespace=namespace)
255+
256+
list_response = cls.call_list_api(
257+
kind=JUMPSTART_MODEL_KIND,
258+
namespace=namespace,
259+
)
260+
261+
endpoints = set()
262+
if list_response and list_response["items"]:
263+
for item in list_response["items"]:
264+
endpoints.add(item["metadata"]["name"])
255265

256266
pods = []
257-
for item in response.items:
267+
for item in list_pods_response.items:
258268
app_name = item.metadata.labels.get("app", None)
259-
try:
269+
if app_name in endpoints:
260270
# list_namespaced_pod will return all pods in the namespace, so we need to filter
261271
# down to only the pods that are created by a jumpstart endpoint
262-
cls.call_get_api(
263-
name=app_name,
264-
kind=JUMPSTART_MODEL_KIND,
265-
namespace=namespace,
266-
)
267272
pods.append(item.metadata.name)
268-
except Exception:
269-
continue
270273

271274
return pods

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,10 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
195195
)
196196
self.assertEqual(result, "response")
197197

198-
@patch.object(HPEndpoint, "call_get_api")
198+
@patch.object(HPEndpoint, "call_list_api")
199199
@patch("kubernetes.client.CoreV1Api")
200200
@patch.object(HPEndpoint, "verify_kube_config")
201-
def test_list_pods(self, mock_verify_config, mock_core_api, mock_get_api):
201+
def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
202202
mock_pod1 = MagicMock()
203203
mock_pod1.metadata.name = "custom-endpoint-pod1"
204204
mock_pod1.metadata.labels = {"app": "custom-endpoint"}
@@ -214,12 +214,13 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_get_api):
214214
mock_pod3,
215215
]
216216

217-
def mock_behavior(name, kind, namespace):
218-
if name.startswith("custom-endpoint"):
219-
return
220-
else:
221-
raise Exception("Endpoint not found")
222-
mock_get_api.side_effect = mock_behavior
217+
mock_list_api.return_value = {
218+
"items": [
219+
{
220+
"metadata": {"name": "custom-endpoint"}
221+
}
222+
]
223+
}
223224

224225
result = self.endpoint.list_pods(namespace="test-ns")
225226

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,10 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
141141
)
142142
self.assertEqual(result, "response")
143143

144-
@patch.object(HPJumpStartEndpoint, "call_get_api")
144+
@patch.object(HPJumpStartEndpoint, "call_list_api")
145145
@patch("kubernetes.client.CoreV1Api")
146146
@patch.object(HPJumpStartEndpoint, "verify_kube_config")
147-
def test_list_pods(self, mock_verify_config, mock_core_api, mock_get_api):
147+
def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
148148
mock_pod1 = MagicMock()
149149
mock_pod1.metadata.name = "js-endpoint-pod1"
150150
mock_pod1.metadata.labels = {"app": "js-endpoint"}
@@ -160,12 +160,13 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_get_api):
160160
mock_pod3,
161161
]
162162

163-
def mock_behavior(name, kind, namespace):
164-
if name.startswith("js-endpoint"):
165-
return
166-
else:
167-
raise Exception("Endpoint not found")
168-
mock_get_api.side_effect = mock_behavior
163+
mock_list_api.return_value = {
164+
"items": [
165+
{
166+
"metadata": {"name": "js-endpoint"}
167+
}
168+
]
169+
}
169170

170171
result = self.endpoint.list_pods(namespace="test-ns")
171172

0 commit comments

Comments
 (0)