Skip to content

Commit 05235c1

Browse files
authored
Merge branch 'master' into master
2 parents f3e221b + 9ba3997 commit 05235c1

File tree

19 files changed

+1151
-66
lines changed

19 files changed

+1151
-66
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"processors": ["cpu", "gpu"],
3+
"scope": ["inference"],
4+
"version_aliases": {
5+
"3.0": "3.0.0"
6+
},
7+
"versions": {
8+
"3.0.0": {
9+
"registries": {
10+
"us-east-1": "885854791233",
11+
"us-east-2": "137914896644",
12+
"us-west-1": "053634841547",
13+
"us-west-2": "542918446943",
14+
"af-south-1": "238384257742",
15+
"ap-east-1": "523751269255",
16+
"ap-south-1": "245090515133",
17+
"ap-northeast-2": "064688005998",
18+
"ap-southeast-1": "022667117163",
19+
"ap-southeast-2": "648430277019",
20+
"ap-northeast-1": "010972774902",
21+
"ca-central-1": "481561238223",
22+
"eu-central-1": "545423591354",
23+
"eu-west-1": "819792524951",
24+
"eu-west-2": "021081402939",
25+
"eu-west-3": "856416204555",
26+
"eu-north-1": "175620155138",
27+
"eu-south-1": "810671768855",
28+
"sa-east-1": "567556641782",
29+
"ap-northeast-3": "564864627153",
30+
"ap-southeast-3": "370607712162",
31+
"me-south-1": "523774347010",
32+
"me-central-1": "358593528301"
33+
},
34+
"repository": "sagemaker-distribution-prod"
35+
}
36+
}
37+
}

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ class SourceCode(BaseConfig):
8888
8989
Parameters:
9090
source_dir (Optional[str]):
91-
The local directory containing the source code to be used in the training job container.
91+
The local directory, s3 uri, or path to tar.gz file stored locally or in s3 that contains
92+
the source code to be used in the training job container.
9293
requirements (Optional[str]):
9394
The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
9495
requirements will be installed in the training job container.

src/sagemaker/modules/train/model_trainer.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -407,28 +407,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408408
+ "'source_dir' must also be provided.",
409409
)
410-
if not _is_valid_path(source_dir, path_type="Directory"):
410+
if not (
411+
_is_valid_path(source_dir, path_type="Directory")
412+
or _is_valid_s3_uri(source_dir, path_type="Directory")
413+
or (
414+
_is_valid_path(source_dir, path_type="File")
415+
and source_dir.endswith(".tar.gz")
416+
)
417+
or (
418+
_is_valid_s3_uri(source_dir, path_type="File")
419+
and source_dir.endswith(".tar.gz")
420+
)
421+
):
411422
raise ValueError(
412-
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
423+
f"Invalid 'source_dir' path: {source_dir}. "
424+
+ "Must be a valid local directory, "
425+
"s3 uri or path to tar.gz file stored locally or in s3.",
413426
)
414427
if requirements:
415-
if not _is_valid_path(
416-
f"{source_dir}/{requirements}",
417-
path_type="File",
418-
):
419-
raise ValueError(
420-
f"Invalid 'requirements': {requirements}. "
421-
+ "Must be a valid file within the 'source_dir'.",
422-
)
428+
if not source_dir.endswith(".tar.gz"):
429+
if not _is_valid_path(
430+
f"{source_dir}/{requirements}", path_type="File"
431+
) and not _is_valid_s3_uri(
432+
f"{source_dir}/{requirements}", path_type="File"
433+
):
434+
raise ValueError(
435+
f"Invalid 'requirements': {requirements}. "
436+
+ "Must be a valid file within the 'source_dir'.",
437+
)
423438
if entry_script:
424-
if not _is_valid_path(
425-
f"{source_dir}/{entry_script}",
426-
path_type="File",
427-
):
428-
raise ValueError(
429-
f"Invalid 'entry_script': {entry_script}. "
430-
+ "Must be a valid file within the 'source_dir'.",
431-
)
439+
if not source_dir.endswith(".tar.gz"):
440+
if not _is_valid_path(
441+
f"{source_dir}/{entry_script}", path_type="File"
442+
) and not _is_valid_s3_uri(
443+
f"{source_dir}/{entry_script}", path_type="File"
444+
):
445+
raise ValueError(
446+
f"Invalid 'entry_script': {entry_script}. "
447+
+ "Must be a valid file within the 'source_dir'.",
448+
)
432449

433450
def model_post_init(self, __context: Any):
434451
"""Post init method to perform custom validation and set default values."""
@@ -838,12 +855,17 @@ def _prepare_train_script(
838855

839856
install_requirements = ""
840857
if source_code.requirements:
841-
install_requirements = "echo 'Installing requirements'\n"
842-
install_requirements = f"$SM_PIP_CMD install -r {source_code.requirements}"
858+
install_requirements = (
859+
"echo 'Installing requirements'\n"
860+
+ f"$SM_PIP_CMD install -r {source_code.requirements}"
861+
)
843862

844863
working_dir = ""
845864
if source_code.source_dir:
846-
working_dir = f"cd {SM_CODE_CONTAINER_PATH}"
865+
working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
866+
if source_code.source_dir.endswith(".tar.gz"):
867+
tarfile_name = os.path.basename(source_code.source_dir)
868+
working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n"
847869

848870
if base_command:
849871
execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)

0 commit comments

Comments
 (0)