Skip to content

Commit 5148e14

Browse files
committed
Updating unit tests for internal-external method changes
1 parent 848dff9 commit 5148e14

File tree

3 files changed

+42
-42
lines changed

3 files changed

+42
-42
lines changed

sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from unittest.mock import Mock, patch, MagicMock
1818

1919
from sagemaker.train.aws_batch.batch_api_helper import (
20-
submit_service_job,
21-
describe_service_job,
22-
terminate_service_job,
23-
list_service_job,
20+
_submit_service_job,
21+
_describe_service_job,
22+
_terminate_service_job,
23+
_list_service_job,
2424
)
2525
from .conftest import (
2626
JOB_NAME,
@@ -56,7 +56,7 @@ def test_submit_service_job_basic(self, mock_get_client):
5656
mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
5757
mock_get_client.return_value = mock_client
5858

59-
result = submit_service_job(
59+
result = _submit_service_job(
6060
TRAINING_JOB_PAYLOAD,
6161
JOB_NAME,
6262
JOB_QUEUE,
@@ -74,7 +74,7 @@ def test_submit_service_job_with_all_params(self, mock_get_client):
7474
mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
7575
mock_get_client.return_value = mock_client
7676

77-
result = submit_service_job(
77+
result = _submit_service_job(
7878
TRAINING_JOB_PAYLOAD,
7979
JOB_NAME,
8080
JOB_QUEUE,
@@ -103,7 +103,7 @@ def test_submit_service_job_with_tags(self, mock_get_client):
103103
payload = TRAINING_JOB_PAYLOAD.copy()
104104
payload["Tags"] = TRAINING_TAGS
105105

106-
result = submit_service_job(
106+
result = _submit_service_job(
107107
payload,
108108
JOB_NAME,
109109
JOB_QUEUE,
@@ -125,7 +125,7 @@ def test_submit_service_job_payload_serialized(self, mock_get_client):
125125
mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
126126
mock_get_client.return_value = mock_client
127127

128-
submit_service_job(
128+
_submit_service_job(
129129
TRAINING_JOB_PAYLOAD,
130130
JOB_NAME,
131131
JOB_QUEUE,
@@ -148,7 +148,7 @@ def test_describe_service_job(self, mock_get_client):
148148
mock_client.describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING
149149
mock_get_client.return_value = mock_client
150150

151-
result = describe_service_job(JOB_ID)
151+
result = _describe_service_job(JOB_ID)
152152

153153
assert result["jobId"] == JOB_ID
154154
assert result["status"] == "RUNNING"
@@ -165,7 +165,7 @@ def test_terminate_service_job(self, mock_get_client):
165165
mock_client.terminate_service_job.return_value = {}
166166
mock_get_client.return_value = mock_client
167167

168-
result = terminate_service_job(JOB_ID, REASON)
168+
result = _terminate_service_job(JOB_ID, REASON)
169169

170170
assert result == {}
171171
mock_client.terminate_service_job.assert_called_once_with(
@@ -179,7 +179,7 @@ def test_terminate_service_job_default_reason(self, mock_get_client):
179179
mock_client.terminate_service_job.return_value = {}
180180
mock_get_client.return_value = mock_client
181181

182-
terminate_service_job(JOB_ID)
182+
_terminate_service_job(JOB_ID)
183183

184184
call_kwargs = mock_client.terminate_service_job.call_args[1]
185185
assert call_kwargs["jobId"] == JOB_ID
@@ -196,7 +196,7 @@ def test_list_service_job_empty(self, mock_get_client):
196196
mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_EMPTY
197197
mock_get_client.return_value = mock_client
198198

199-
gen = list_service_job(JOB_QUEUE)
199+
gen = _list_service_job(JOB_QUEUE)
200200
result = next(gen)
201201

202202
assert result["jobSummaryList"] == []
@@ -209,7 +209,7 @@ def test_list_service_job_with_jobs(self, mock_get_client):
209209
mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_WITH_JOBS
210210
mock_get_client.return_value = mock_client
211211

212-
gen = list_service_job(JOB_QUEUE)
212+
gen = _list_service_job(JOB_QUEUE)
213213
result = next(gen)
214214

215215
assert len(result["jobSummaryList"]) == 2
@@ -225,7 +225,7 @@ def test_list_service_job_with_pagination(self, mock_get_client):
225225
]
226226
mock_get_client.return_value = mock_client
227227

228-
gen = list_service_job(JOB_QUEUE)
228+
gen = _list_service_job(JOB_QUEUE)
229229
first_result = next(gen)
230230
assert first_result["nextToken"] == NEXT_TOKEN
231231

@@ -240,7 +240,7 @@ def test_list_service_job_with_filters(self, mock_get_client):
240240
mock_get_client.return_value = mock_client
241241

242242
filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}]
243-
gen = list_service_job(JOB_QUEUE, filters=filters)
243+
gen = _list_service_job(JOB_QUEUE, filters=filters)
244244
result = next(gen)
245245

246246
call_kwargs = mock_client.list_service_jobs.call_args[1]
@@ -253,7 +253,7 @@ def test_list_service_job_with_status(self, mock_get_client):
253253
mock_client.list_service_jobs.return_value = LIST_SERVICE_JOB_RESP_WITH_JOBS
254254
mock_get_client.return_value = mock_client
255255

256-
gen = list_service_job(JOB_QUEUE, job_status=JOB_STATUS_RUNNING)
256+
gen = _list_service_job(JOB_QUEUE, job_status=JOB_STATUS_RUNNING)
257257
result = next(gen)
258258

259259
call_kwargs = mock_client.list_service_jobs.call_args[1]

sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_training_queue_init(self):
4646
class TestTrainingQueueSubmit:
4747
"""Tests for TrainingQueue.submit method"""
4848

49-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
49+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
5050
def test_submit_model_trainer(self, mock_submit_service_job):
5151
"""Test submit with ModelTrainer"""
5252
mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
@@ -71,7 +71,7 @@ def test_submit_model_trainer(self, mock_submit_service_job):
7171
assert queued_job.job_arn == JOB_ARN
7272
mock_submit_service_job.assert_called_once()
7373

74-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
74+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
7575
def test_submit_with_default_timeout(self, mock_submit_service_job):
7676
"""Test submit uses default timeout when not provided"""
7777
mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
@@ -96,7 +96,7 @@ def test_submit_with_default_timeout(self, mock_submit_service_job):
9696
# Timeout should be set to default
9797
assert call_kwargs[5] is not None
9898

99-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
99+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
100100
def test_submit_with_generated_job_name(self, mock_submit_service_job):
101101
"""Test submit generates job name from payload if not provided"""
102102
mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
@@ -156,7 +156,7 @@ def test_submit_invalid_training_mode(self):
156156
BATCH_TAGS,
157157
)
158158

159-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
159+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
160160
def test_submit_missing_job_arn_in_response(self, mock_submit_service_job):
161161
"""Test submit raises error when jobArn missing from response"""
162162
mock_submit_service_job.return_value = {"jobName": JOB_NAME} # Missing jobArn
@@ -183,7 +183,7 @@ def test_submit_missing_job_arn_in_response(self, mock_submit_service_job):
183183
class TestTrainingQueueMap:
184184
"""Tests for TrainingQueue.map method"""
185185

186-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
186+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
187187
def test_map_multiple_inputs(self, mock_submit_service_job):
188188
"""Test map submits multiple jobs"""
189189
mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
@@ -208,7 +208,7 @@ def test_map_multiple_inputs(self, mock_submit_service_job):
208208
assert len(queued_jobs) == 3
209209
assert mock_submit_service_job.call_count == 3
210210

211-
@patch("sagemaker.train.aws_batch.training_queue.submit_service_job")
211+
@patch("sagemaker.train.aws_batch.training_queue._submit_service_job")
212212
def test_map_with_job_names(self, mock_submit_service_job):
213213
"""Test map with explicit job names"""
214214
mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP
@@ -258,7 +258,7 @@ def test_map_mismatched_job_names_length(self):
258258
class TestTrainingQueueList:
259259
"""Tests for TrainingQueue.list_jobs method"""
260260

261-
@patch("sagemaker.train.aws_batch.training_queue.list_service_job")
261+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
262262
def test_list_jobs_default(self, mock_list_service_job):
263263
"""Test list_jobs with default parameters"""
264264
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS])
@@ -269,7 +269,7 @@ def test_list_jobs_default(self, mock_list_service_job):
269269
assert len(jobs) == 2
270270
assert jobs[0].job_name == JOB_NAME
271271

272-
@patch("sagemaker.train.aws_batch.training_queue.list_service_job")
272+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
273273
def test_list_jobs_with_name_filter(self, mock_list_service_job):
274274
"""Test list_jobs with job name filter"""
275275
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS])
@@ -292,7 +292,7 @@ def test_list_jobs_with_name_filter(self, mock_list_service_job):
292292
assert filters[0]["name"] == "JOB_NAME", "JOB_NAME filter should be present"
293293
assert filters[0]["values"] == [JOB_NAME], "Filter values should contain the job name"
294294

295-
@patch("sagemaker.train.aws_batch.training_queue.list_service_job")
295+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
296296
def test_list_jobs_empty(self, mock_list_service_job):
297297
"""Test list_jobs returns empty list"""
298298
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY])
@@ -306,7 +306,7 @@ def test_list_jobs_empty(self, mock_list_service_job):
306306
class TestTrainingQueueGet:
307307
"""Tests for TrainingQueue.get_job method"""
308308

309-
@patch("sagemaker.train.aws_batch.training_queue.list_service_job")
309+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
310310
def test_get_job_found(self, mock_list_service_job):
311311
"""Test get_job returns job when found"""
312312
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_WITH_JOBS])
@@ -317,7 +317,7 @@ def test_get_job_found(self, mock_list_service_job):
317317
assert job.job_name == JOB_NAME
318318
assert job.job_arn == JOB_ARN
319319

320-
@patch("sagemaker.train.aws_batch.training_queue.list_service_job")
320+
@patch("sagemaker.train.aws_batch.training_queue._list_service_job")
321321
def test_get_job_not_found(self, mock_list_service_job):
322322
"""Test get_job raises error when job not found"""
323323
mock_list_service_job.return_value = iter([LIST_SERVICE_JOB_RESP_EMPTY])

sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_training_queued_job_init(self):
4949
class TestTrainingQueuedJobDescribe:
5050
"""Tests for TrainingQueuedJob.describe method"""
5151

52-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
52+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
5353
def test_describe(self, mock_describe_service_job):
5454
"""Test describe returns job details"""
5555
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING
@@ -64,7 +64,7 @@ def test_describe(self, mock_describe_service_job):
6464
class TestTrainingQueuedJobTerminate:
6565
"""Tests for TrainingQueuedJob.terminate method"""
6666

67-
@patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job")
67+
@patch("sagemaker.train.aws_batch.training_queued_job._terminate_service_job")
6868
def test_terminate(self, mock_terminate_service_job):
6969
"""Test terminate calls terminate API"""
7070
mock_terminate_service_job.return_value = {}
@@ -74,7 +74,7 @@ def test_terminate(self, mock_terminate_service_job):
7474

7575
mock_terminate_service_job.assert_called_once_with(JOB_ARN, REASON)
7676

77-
@patch("sagemaker.train.aws_batch.training_queued_job.terminate_service_job")
77+
@patch("sagemaker.train.aws_batch.training_queued_job._terminate_service_job")
7878
def test_terminate_default_reason(self, mock_terminate_service_job):
7979
"""Test terminate with default reason"""
8080
mock_terminate_service_job.return_value = {}
@@ -89,7 +89,7 @@ def test_terminate_default_reason(self, mock_terminate_service_job):
8989
class TestTrainingQueuedJobWait:
9090
"""Tests for TrainingQueuedJob.wait method"""
9191

92-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
92+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
9393
def test_wait_immediate_completion(self, mock_describe_service_job):
9494
"""Test wait returns immediately when job is completed"""
9595
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED
@@ -99,7 +99,7 @@ def test_wait_immediate_completion(self, mock_describe_service_job):
9999

100100
assert result["status"] == JOB_STATUS_SUCCEEDED
101101

102-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
102+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
103103
def test_wait_with_polling(self, mock_describe_service_job):
104104
"""Test wait polls until job completes"""
105105
mock_describe_service_job.side_effect = [
@@ -114,7 +114,7 @@ def test_wait_with_polling(self, mock_describe_service_job):
114114
assert result["status"] == JOB_STATUS_SUCCEEDED
115115
assert mock_describe_service_job.call_count == 3
116116

117-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
117+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
118118
def test_wait_with_timeout(self, mock_describe_service_job):
119119
"""Test wait respects timeout"""
120120
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING
@@ -128,7 +128,7 @@ def test_wait_with_timeout(self, mock_describe_service_job):
128128
assert end_time - start_time >= 2
129129
assert result["status"] == JOB_STATUS_RUNNING
130130

131-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
131+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
132132
def test_wait_job_failed(self, mock_describe_service_job):
133133
"""Test wait returns failed status"""
134134
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED
@@ -144,7 +144,7 @@ class TestTrainingQueuedJobGetModelTrainer:
144144

145145
@patch("sagemaker.train.aws_batch.training_queued_job._remove_system_tags_in_place_in_model_trainer_object")
146146
@patch("sagemaker.train.aws_batch.training_queued_job._construct_model_trainer_from_training_job_name")
147-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
147+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
148148
def test_get_model_trainer_success(self, mock_describe_service_job, mock_construct_trainer, mock_remove_tags):
149149
"""Test get_model_trainer returns ModelTrainer when training job created"""
150150
# Return a real dict (not a mock) so nested dict access works
@@ -160,7 +160,7 @@ def test_get_model_trainer_success(self, mock_describe_service_job, mock_constru
160160
mock_construct_trainer.assert_called_once()
161161
mock_remove_tags.assert_called_once_with(mock_trainer)
162162

163-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
163+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
164164
def test_get_model_trainer_no_training_job_pending(self, mock_describe_service_job):
165165
"""Test get_model_trainer raises error when job still pending"""
166166
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_PENDING
@@ -170,7 +170,7 @@ def test_get_model_trainer_no_training_job_pending(self, mock_describe_service_j
170170
with pytest.raises(NoTrainingJob):
171171
queued_job.get_model_trainer()
172172

173-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
173+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
174174
def test_get_model_trainer_no_latest_attempt(self, mock_describe_service_job):
175175
"""Test get_model_trainer raises error when latestAttempt missing"""
176176
resp = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED.copy()
@@ -186,7 +186,7 @@ def test_get_model_trainer_no_latest_attempt(self, mock_describe_service_job):
186186
class TestTrainingQueuedJobResult:
187187
"""Tests for TrainingQueuedJob.result method"""
188188

189-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
189+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
190190
def test_result_success(self, mock_describe_service_job):
191191
"""Test result returns job result when completed"""
192192
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED
@@ -196,7 +196,7 @@ def test_result_success(self, mock_describe_service_job):
196196

197197
assert result["status"] == JOB_STATUS_SUCCEEDED
198198

199-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
199+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
200200
def test_result_timeout(self, mock_describe_service_job):
201201
"""Test result raises TimeoutError when timeout exceeded"""
202202
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING
@@ -210,7 +210,7 @@ def test_result_timeout(self, mock_describe_service_job):
210210
class TestTrainingQueuedJobAsync:
211211
"""Tests for TrainingQueuedJob async methods"""
212212

213-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
213+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
214214
def test_fetch_job_results_success(self, mock_describe_service_job):
215215
"""Test fetch_job_results returns result when job succeeds"""
216216
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED
@@ -220,7 +220,7 @@ def test_fetch_job_results_success(self, mock_describe_service_job):
220220

221221
assert result["status"] == JOB_STATUS_SUCCEEDED
222222

223-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
223+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
224224
def test_fetch_job_results_failed(self, mock_describe_service_job):
225225
"""Test fetch_job_results raises error when job fails"""
226226
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_FAILED
@@ -230,7 +230,7 @@ def test_fetch_job_results_failed(self, mock_describe_service_job):
230230
with pytest.raises(RuntimeError):
231231
asyncio.run(queued_job.fetch_job_results())
232232

233-
@patch("sagemaker.train.aws_batch.training_queued_job.describe_service_job")
233+
@patch("sagemaker.train.aws_batch.training_queued_job._describe_service_job")
234234
def test_fetch_job_results_timeout(self, mock_describe_service_job):
235235
"""Test fetch_job_results raises TimeoutError when timeout exceeded"""
236236
mock_describe_service_job.return_value = DESCRIBE_SERVICE_JOB_RESP_RUNNING

0 commit comments

Comments
 (0)