Skip to content

Commit 9fd23a9

Browse files
committed
Make helper methods internal
1 parent 3fc1eff commit 9fd23a9

File tree

5 files changed

+14
-15
lines changed

5 files changed

+14
-15
lines changed

sagemaker-train/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ dependencies = [
4242
"jinja2>=3.0,<4.0",
4343
"sagemaker-mlflow>=0.0.1,<1.0.0",
4444
"mlflow>=3.0.0,<4.0.0",
45-
"nest_asyncio>=1.5.0",
4645
]
4746

4847
[project.urls]

sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.train.aws_batch.boto_client import get_batch_boto_client
2424

2525

26-
def submit_service_job(
26+
def _submit_service_job(
2727
training_payload: Dict,
2828
job_name: str,
2929
job_queue: str,
@@ -71,7 +71,7 @@ def submit_service_job(
7171
return client.submit_service_job(**payload)
7272

7373

74-
def describe_service_job(job_id: str) -> Dict:
74+
def _describe_service_job(job_id: str) -> Dict:
7575
"""Batch describe_service_job API helper function.
7676
7777
Args:
@@ -119,7 +119,7 @@ def describe_service_job(job_id: str) -> Dict:
119119
return client.describe_service_job(jobId=job_id)
120120

121121

122-
def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict:
122+
def _terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict:
123123
"""Batch terminate_service_job API helper function.
124124
125125
Args:
@@ -132,7 +132,7 @@ def terminate_service_job(job_id: str, reason: Optional[str] = "default terminat
132132
return client.terminate_service_job(jobId=job_id, reason=reason)
133133

134134

135-
def list_service_job(
135+
def _list_service_job(
136136
job_queue: str,
137137
job_status: Optional[str] = None,
138138
filters: Optional[List] = None,
@@ -161,7 +161,7 @@ def list_service_job(
161161
next_token = part_of_jobs.get("nextToken")
162162
yield part_of_jobs
163163
if next_token:
164-
yield from list_service_job(job_queue, job_status, filters, next_token)
164+
yield from _list_service_job(job_queue, job_status, filters, next_token)
165165

166166

167167
def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]:

sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
from sagemaker.train.model_trainer import ModelTrainer, Mode
1919
from .training_queued_job import TrainingQueuedJob
20-
from .batch_api_helper import submit_service_job, list_service_job
20+
from .batch_api_helper import _submit_service_job, _list_service_job
2121
from .exception import MissingRequiredArgument
2222
from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING
2323

@@ -85,7 +85,7 @@ def submit(
8585
if job_name is None:
8686
job_name = training_payload["TrainingJobName"]
8787

88-
resp = submit_service_job(
88+
resp = _submit_service_job(
8989
training_payload,
9090
job_name,
9191
self.queue_name,
@@ -177,7 +177,7 @@ def list_jobs(
177177
status = None # job_status is ignored when job_name is specified.
178178
jobs_to_return = []
179179
next_token = None
180-
for job_result_dict in list_service_job(self.queue_name, status, filters, next_token):
180+
for job_result_dict in _list_service_job(self.queue_name, status, filters, next_token):
181181
for job_result in job_result_dict.get("jobSummaryList", []):
182182
if "jobArn" in job_result and "jobName" in job_result:
183183
jobs_to_return.append(

sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
SourceCode,
3030
TrainingImageConfig,
3131
)
32-
from .batch_api_helper import terminate_service_job, describe_service_job
32+
from .batch_api_helper import _terminate_service_job, _describe_service_job
3333
from .exception import NoTrainingJob, MissingRequiredArgument
34-
from ..utils import get_training_job_name_from_training_job_arn
34+
from ..utils import _get_training_job_name_from_training_job_arn
3535
from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS
3636

3737
logging.basicConfig(
@@ -82,15 +82,15 @@ def terminate(self, reason: Optional[str] = "Default terminate reason") -> None:
8282
Returns: None
8383
8484
"""
85-
terminate_service_job(self.job_arn, reason)
85+
_terminate_service_job(self.job_arn, reason)
8686

8787
def describe(self) -> Dict:
8888
"""Describe Batch job.
8989
9090
Returns: A dict which includes job parameters, job status, attempts and so on.
9191
9292
"""
93-
return describe_service_job(self.job_arn)
93+
return _describe_service_job(self.job_arn)
9494

9595
def _training_job_created(self, status: str) -> bool:
9696
"""Return True if a Training job has been created
@@ -320,7 +320,7 @@ def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str:
320320
321321
"""
322322
training_job_arn = latest_attempt.get("serviceResourceId", {}).get("value", None)
323-
return get_training_job_name_from_training_job_arn(training_job_arn)
323+
return _get_training_job_name_from_training_job_arn(training_job_arn)
324324

325325

326326
def _remove_system_tags_in_place_in_model_trainer_object(model_trainer: ModelTrainer) -> None:

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _get_studio_tags(model_id: str, hub_name: str):
249249
]
250250

251251

252-
def get_training_job_name_from_training_job_arn(training_job_arn: str) -> str:
252+
def _get_training_job_name_from_training_job_arn(training_job_arn: str) -> str:
253253
"""Extract Training job name from Training job arn.
254254
Args:
255255
training_job_arn: Training job arn.

0 commit comments

Comments
 (0)