Skip to content

Commit a824151

Browse files
Remove command flag from init pytorch job integ test: (#351)
1 parent 170bf15 commit a824151

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

test/integration_tests/init/test_pytorch_job_creation.py

Lines changed: 35 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -90,17 +90,21 @@ def test_configure_pytorch_job(runner, pytorch_job_name, test_directory):
9090
configure, [
9191
# Required fields only
9292
"--job-name", pytorch_job_name,
93-
"--image", "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel",
94-
"--command", '["python", "-c", "import torch; print(torch.__version__); import time; time.sleep(3600)"]',
93+
"--image", "448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist",
94+
"--pull-policy", "Always",
95+
"--tasks-per-node", "1",
96+
"--max-retry", "1"
9597
], catch_exceptions=False
9698
)
9799
assert_command_succeeded(result)
98100

99101
# Simplified expected_config
100102
expected_config = {
101103
"job_name": pytorch_job_name,
102-
"image": "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel",
103-
"command": ["python", "-c", "import torch; print(torch.__version__); import time; time.sleep(3600)"],
104+
"image": "448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist",
105+
"pull_policy": "Always",
106+
"tasks_per_node": 1,
107+
"max_retry": 1
104108
}
105109
assert_config_values("./", expected_config)
106110

@@ -124,6 +128,31 @@ def test_create_pytorch_job(runner, pytorch_job_name, test_directory):
124128
assert pytorch_job_name in result.output
125129

126130

131+
@pytest.mark.dependency(name="list_pods", depends=["create"])
132+
def test_list_pods(pytorch_job_name, test_directory):
133+
"""Test listing pods for a specific job."""
134+
# Wait a moment to ensure pods are created
135+
time.sleep(10)
136+
137+
list_pods_result = execute_command([
138+
"hyp", "list-pods", "hyp-pytorch-job",
139+
"--job-name", pytorch_job_name,
140+
"--namespace", NAMESPACE
141+
])
142+
assert list_pods_result.returncode == 0
143+
144+
# Verify the output contains expected headers and job name
145+
output = list_pods_result.stdout.strip()
146+
assert f"Pods for job: {pytorch_job_name}" in output
147+
assert "POD NAME" in output
148+
assert "NAMESPACE" in output
149+
150+
# Verify at least one pod is listed (should contain the job name in the pod name)
151+
assert f"{pytorch_job_name}-pod-" in output
152+
153+
print(f"[INFO] Successfully listed pods for job: {pytorch_job_name}")
154+
155+
127156
@pytest.mark.dependency(name="wait", depends=["create"])
128157
def test_wait_for_job_running(pytorch_job_name, test_directory):
129158
"""Poll SDK until PyTorch job reaches Running state."""
@@ -158,31 +187,7 @@ def test_wait_for_job_running(pytorch_job_name, test_directory):
158187
pytest.fail(f"[ERROR] Timed out waiting for job {pytorch_job_name} to be Running")
159188

160189

161-
@pytest.mark.dependency(name="list_pods", depends=["wait"])
162-
def test_list_pods(pytorch_job_name, test_directory):
163-
"""Test listing pods for a specific job."""
164-
# Wait a moment to ensure pods are created
165-
time.sleep(10)
166-
167-
list_pods_result = execute_command([
168-
"hyp", "list-pods", "hyp-pytorch-job",
169-
"--job-name", pytorch_job_name,
170-
"--namespace", NAMESPACE
171-
])
172-
assert list_pods_result.returncode == 0
173-
174-
# Verify the output contains expected headers and job name
175-
output = list_pods_result.stdout.strip()
176-
assert f"Pods for job: {pytorch_job_name}" in output
177-
assert "POD NAME" in output
178-
assert "NAMESPACE" in output
179-
180-
# Verify at least one pod is listed (should contain the job name in the pod name)
181-
assert f"{pytorch_job_name}-pod-" in output
182-
183-
print(f"[INFO] Successfully listed pods for job: {pytorch_job_name}")
184-
185-
190+
@pytest.mark.run(order=99)
186191
@pytest.mark.dependency(depends=["create"])
187192
def test_pytorch_job_delete(pytorch_job_name, test_directory):
188193
"""Clean up deployed PyTorch job using CLI delete command and verify deletion."""
@@ -198,7 +203,7 @@ def test_pytorch_job_delete(pytorch_job_name, test_directory):
198203
time.sleep(5)
199204

200205
# Verify the job is no longer listed
201-
list_result = execute_command(["hyp", "list", "hyp-pytorch-job", "--namespace", NAMESPACE])
206+
list_result = execute_command(["hyp", "list", "hyp-pytorch-job"])
202207
assert list_result.returncode == 0
203208

204209
# The job name should no longer be in the output

0 commit comments

Comments (0)