Skip to content

Commit 91504e9

Browse files
pintaoz-awspintaoz
andauthored
Add enpoint_name argument for list_pods() (#232)
* Add enpoint_name argument for list_pods() * update test name --------- Co-authored-by: pintaoz <[email protected]>
1 parent 6f452bf commit 91504e9

File tree

6 files changed

+90
-22
lines changed

6 files changed

+90
-22
lines changed

src/sagemaker/hyperpod/cli/commands/inference.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,16 +614,23 @@ def custom_delete(
614614
default="default",
615615
help="Optional. The namespace of the jumpstart model to list pods for. Default set to 'default'.",
616616
)
617+
@click.option(
618+
"--endpoint-name",
619+
type=click.STRING,
620+
required=False,
621+
help="Optional. The name of the jumpstart endpoint to list pods.",
622+
)
617623
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pods_js_endpoint_cli")
618624
@handle_cli_exceptions()
619625
def js_list_pods(
620626
namespace: Optional[str],
627+
endpoint_name: Optional[str],
621628
):
622629
"""
623630
List all pods related to jumpstart model endpoint.
624631
"""
625632
my_endpoint = HPJumpStartEndpoint.model_construct()
626-
pods = my_endpoint.list_pods(namespace=namespace)
633+
pods = my_endpoint.list_pods(namespace=namespace, endpoint_name=endpoint_name)
627634
click.echo(pods)
628635

629636

@@ -635,16 +642,23 @@ def js_list_pods(
635642
default="default",
636643
help="Optional. The namespace of the custom model to list pods for. Default set to 'default'.",
637644
)
645+
@click.option(
646+
"--endpoint-name",
647+
type=click.STRING,
648+
required=False,
649+
help="Optional. The name of the custom model endpoint to list pods.",
650+
)
638651
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pods_custom_endpoint_cli")
639652
@handle_cli_exceptions()
640653
def custom_list_pods(
641654
namespace: Optional[str],
655+
endpoint_name: Optional[str],
642656
):
643657
"""
644658
List all pods related to custom model endpoint.
645659
"""
646660
my_endpoint = HPEndpoint.model_construct()
647-
pods = my_endpoint.list_pods(namespace=namespace)
661+
pods = my_endpoint.list_pods(namespace=namespace, endpoint_name=endpoint_name)
648662
click.echo(pods)
649663

650664

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def validate_instance_type(self, instance_type: str):
215215

216216
@classmethod
217217
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
218-
def list_pods(cls, namespace=None):
218+
def list_pods(cls, namespace=None, endpoint_name=None):
219219
cls.verify_kube_config()
220220

221221
if not namespace:
@@ -224,15 +224,17 @@ def list_pods(cls, namespace=None):
224224
v1 = client.CoreV1Api()
225225
list_pods_response = v1.list_namespaced_pod(namespace=namespace)
226226

227-
list_response = cls.call_list_api(
228-
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
229-
namespace=namespace,
230-
)
231-
232227
endpoints = set()
233-
if list_response and list_response["items"]:
234-
for item in list_response["items"]:
235-
endpoints.add(item["metadata"]["name"])
228+
if endpoint_name:
229+
endpoints.add(endpoint_name)
230+
else:
231+
list_response = cls.call_list_api(
232+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
233+
namespace=namespace,
234+
)
235+
if list_response and list_response["items"]:
236+
for item in list_response["items"]:
237+
endpoints.add(item["metadata"]["name"])
236238

237239
pods = []
238240
for item in list_pods_response.items:

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def validate_instance_type(self, model_id: str, instance_type: str):
244244

245245
@classmethod
246246
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
247-
def list_pods(cls, namespace=None):
247+
def list_pods(cls, namespace=None, endpoint_name=None):
248248
cls.verify_kube_config()
249249

250250
if not namespace:
@@ -253,15 +253,17 @@ def list_pods(cls, namespace=None):
253253
v1 = client.CoreV1Api()
254254
list_pods_response = v1.list_namespaced_pod(namespace=namespace)
255255

256-
list_response = cls.call_list_api(
257-
kind=JUMPSTART_MODEL_KIND,
258-
namespace=namespace,
259-
)
260-
261256
endpoints = set()
262-
if list_response and list_response["items"]:
263-
for item in list_response["items"]:
264-
endpoints.add(item["metadata"]["name"])
257+
if endpoint_name:
258+
endpoints.add(endpoint_name)
259+
else:
260+
list_response = cls.call_list_api(
261+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
262+
namespace=namespace,
263+
)
264+
if list_response and list_response["items"]:
265+
for item in list_response["items"]:
266+
endpoints.add(item["metadata"]["name"])
265267

266268
pods = []
267269
for item in list_pods_response.items:

test/unit_tests/cli/test_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def test_js_list_pods(mock_hp, mock_namespace_exists):
309309
inst = Mock(list_pods=Mock(return_value="pods"))
310310
mock_hp.model_construct.return_value = inst
311311
runner = CliRunner()
312-
result = runner.invoke(js_list_pods, ['--namespace', 'ns'])
312+
result = runner.invoke(js_list_pods, ['--namespace', 'ns', '--endpoint-name', 'js-endpoint'])
313313
assert result.exit_code == 0
314314
assert 'pods' in result.output
315315

@@ -320,7 +320,7 @@ def test_custom_list_pods(mock_hp, mock_namespace_exists):
320320
inst = Mock(list_pods=Mock(return_value="pods"))
321321
mock_hp.model_construct.return_value = inst
322322
runner = CliRunner()
323-
result = runner.invoke(custom_list_pods, ['--namespace', 'ns'])
323+
result = runner.invoke(custom_list_pods, ['--namespace', 'ns', '--endpoint-name', 'custom-endpoint'])
324324
assert result.exit_code == 0
325325
assert 'pods' in result.output
326326

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,28 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
228228
mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
229229
namespace="test-ns"
230230
)
231+
232+
@patch("kubernetes.client.CoreV1Api")
233+
@patch.object(HPEndpoint, "verify_kube_config")
234+
def test_list_pods_with_endpoint_name(self, mock_verify_config, mock_core_api):
235+
mock_pod1 = MagicMock()
236+
mock_pod1.metadata.name = "custom-endpoint1-pod1"
237+
mock_pod1.metadata.labels = {"app": "custom-endpoint1"}
238+
mock_pod2 = MagicMock()
239+
mock_pod2.metadata.name = "custom-endpoint1-pod2"
240+
mock_pod2.metadata.labels = {"app": "custom-endpoint1"}
241+
mock_pod3 = MagicMock()
242+
mock_pod3.metadata.name = "custom-endpoint2-pod2"
243+
mock_pod3.metadata.labels = {"app": "custom-endpoint2"}
244+
mock_core_api.return_value.list_namespaced_pod.return_value.items = [
245+
mock_pod1,
246+
mock_pod2,
247+
mock_pod3,
248+
]
249+
250+
result = self.endpoint.list_pods(namespace="test-ns", endpoint_name="custom-endpoint1")
251+
252+
self.assertEqual(result, ["custom-endpoint1-pod1", "custom-endpoint1-pod2"])
253+
mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
254+
namespace="test-ns"
255+
)

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,28 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
174174
mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
175175
namespace="test-ns"
176176
)
177+
178+
@patch("kubernetes.client.CoreV1Api")
179+
@patch.object(HPJumpStartEndpoint, "verify_kube_config")
180+
def test_list_pods_with_endpoint_name(self, mock_verify_config, mock_core_api):
181+
mock_pod1 = MagicMock()
182+
mock_pod1.metadata.name = "js-endpoint1-pod1"
183+
mock_pod1.metadata.labels = {"app": "js-endpoint1"}
184+
mock_pod2 = MagicMock()
185+
mock_pod2.metadata.name = "js-endpoint1-pod2"
186+
mock_pod2.metadata.labels = {"app": "js-endpoint1"}
187+
mock_pod3 = MagicMock()
188+
mock_pod3.metadata.name = "js-endpoint2-pod"
189+
mock_pod3.metadata.labels = {"app": "js-endpoint2"}
190+
mock_core_api.return_value.list_namespaced_pod.return_value.items = [
191+
mock_pod1,
192+
mock_pod2,
193+
mock_pod3,
194+
]
195+
196+
result = self.endpoint.list_pods(namespace="test-ns", endpoint_name="js-endpoint1")
197+
198+
self.assertEqual(result, ["js-endpoint1-pod1", "js-endpoint1-pod2"])
199+
mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
200+
namespace="test-ns"
201+
)

0 commit comments

Comments
 (0)