Skip to content

Commit 9c51633

Browse files
author
Roja Reddy Sareddy
committed
feat: add get_operator_logs to pytorch job
1 parent c2cc4b0 commit 9c51633

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
PLURAL = "hyperpodpytorchjobs"
2525
KIND = "HyperPodPyTorchJob"
2626
TRAINING_OPERATOR_NAMESPACE = "aws-hyperpod"
27+
TRAINING_OPERATOR_POD_PREFIX = "hp-training-operator-hp-training-controller-manager-"
2728

2829

2930
class HyperPodPytorchJob(_HyperPodPytorchJob):
@@ -248,9 +249,19 @@ def get_operator_logs(cls, since_hours: float):
248249
f"No pod found in namespace {TRAINING_OPERATOR_NAMESPACE}"
249250
)
250251

251-
# Get logs from first pod
252-
first_pod = pods.items[0]
253-
pod_name = first_pod.metadata.name
252+
# Find the training operator pod
253+
operator_pod = None
254+
for pod in pods.items:
255+
if pod.metadata.name.startswith(TRAINING_OPERATOR_POD_PREFIX):
256+
operator_pod = pod
257+
break
258+
259+
if not operator_pod:
260+
raise Exception(
261+
f"No training operator pod found with prefix {TRAINING_OPERATOR_POD_PREFIX}"
262+
)
263+
264+
pod_name = operator_pod.metadata.name
254265

255266
try:
256267
logs = v1.read_namespaced_pod_log(

test/unit_tests/training/test_hyperpod_pytorch_job.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -286,16 +286,21 @@ def test_get_logs_from_pod_with_container_name(
286286
@patch("kubernetes.client.CoreV1Api")
287287
@patch.object(HyperPodPytorchJob, "verify_kube_config")
288288
def test_get_operator_logs(self, mock_verify_config, mock_core_api):
289-
mock_pod = MagicMock()
290-
mock_pod.metadata.name = "training-operator-pod"
291-
mock_core_api.return_value.list_namespaced_pod.return_value.items = [mock_pod]
289+
# Mock multiple pods, including the training operator pod
290+
mock_other_pod = MagicMock()
291+
mock_other_pod.metadata.name = "other-pod-123"
292+
293+
mock_operator_pod = MagicMock()
294+
mock_operator_pod.metadata.name = "hp-training-operator-hp-training-controller-manager-abc123"
295+
296+
mock_core_api.return_value.list_namespaced_pod.return_value.items = [mock_other_pod, mock_operator_pod]
292297
mock_core_api.return_value.read_namespaced_pod_log.return_value = "training operator logs"
293298

294299
result = HyperPodPytorchJob.get_operator_logs(2.5)
295300

296301
self.assertEqual(result, "training operator logs")
297302
mock_core_api.return_value.read_namespaced_pod_log.assert_called_once_with(
298-
name="training-operator-pod",
303+
name="hp-training-operator-hp-training-controller-manager-abc123",
299304
namespace="aws-hyperpod",
300305
timestamps=True,
301306
since_seconds=9000,

0 commit comments

Comments
 (0)