Skip to content

Commit 18f43d1

Browse files
committed
Update ModelTrainer to support S3 URIs and tar.gz files as source_dir
1 parent 2dd7111 commit 18f43d1

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ class SourceCode(BaseConfig):
8888
8989
Parameters:
9090
source_dir (Optional[str]):
91-
The local directory containing the source code to be used in the training job container.
91+
The local directory, S3 URI, or path to a tar.gz file (stored locally or in S3) that contains
92+
the source code to be used in the training job container.
9293
requirements (Optional[str]):
9394
The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
9495
requirements will be installed in the training job container.

src/sagemaker/modules/train/model_trainer.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -407,28 +407,26 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408408
+ "'source_dir' must also be provided.",
409409
)
410-
if not _is_valid_path(source_dir, path_type="Directory") or _is_valid_s3_uri(source_dir, path_type="Directory"):
410+
if not _is_valid_path(source_dir) and not _is_valid_s3_uri(source_dir):
411411
raise ValueError(
412412
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
413413
)
414414
if requirements:
415-
if not _is_valid_path(
416-
f"{source_dir}/{requirements}",
417-
path_type="File",
418-
):
419-
raise ValueError(
420-
f"Invalid 'requirements': {requirements}. "
421-
+ "Must be a valid file within the 'source_dir'.",
422-
)
415+
if not source_dir.endswith(".tar.gz"):
416+
if (not _is_valid_path(f"{source_dir}/{requirements}", path_type="File")
417+
and not _is_valid_s3_uri(f"{source_dir}/{requirements}", path_type="File")):
418+
raise ValueError(
419+
f"Invalid 'requirements': {requirements}. "
420+
+ "Must be a valid file within the 'source_dir'.",
421+
)
423422
if entry_script:
424-
if not _is_valid_path(
425-
f"{source_dir}/{entry_script}",
426-
path_type="File",
427-
):
428-
raise ValueError(
429-
f"Invalid 'entry_script': {entry_script}. "
430-
+ "Must be a valid file within the 'source_dir'.",
431-
)
423+
if not source_dir.endswith(".tar.gz"):
424+
if (not _is_valid_path(f"{source_dir}/{entry_script}", path_type="File")
425+
and not _is_valid_s3_uri(f"{source_dir}/{entry_script}", path_type="File")):
426+
raise ValueError(
427+
f"Invalid 'entry_script': {entry_script}. "
428+
+ "Must be a valid file within the 'source_dir'.",
429+
)
432430

433431
def model_post_init(self, __context: Any):
434432
"""Post init method to perform custom validation and set default values."""
@@ -838,12 +836,20 @@ def _prepare_train_script(
838836

839837
install_requirements = ""
840838
if source_code.requirements:
841-
install_requirements = "echo 'Installing requirements'\n"
842-
install_requirements = f"$SM_PIP_CMD install -r {source_code.requirements}"
839+
install_requirements = (
840+
"echo 'Installing requirements'\n" +
841+
f"$SM_PIP_CMD install -r {source_code.requirements}"
842+
)
843843

844844
working_dir = ""
845845
if source_code.source_dir:
846-
working_dir = f"cd {SM_CODE_CONTAINER_PATH}"
846+
working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
847+
if source_code.source_dir.endswith(".tar.gz"):
848+
if source_code.source_dir.startswith("s3://"):
849+
tarfile_name = os.path.basename(source_code.source_dir)
850+
else:
851+
tarfile_name = source_code.source_dir
852+
working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n"
847853

848854
if base_command:
849855
execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)

0 commit comments

Comments
 (0)