Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class SourceCode(BaseConfig):

Parameters:
source_dir (Optional[str]):
The local directory containing the source code to be used in the training job container.
The local directory, S3 URI, or path to a tar.gz file (stored locally or in S3) that contains
the source code to be used in the training job container.
requirements (Optional[str]):
The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed
requirements will be installed in the training job container.
Expand Down
46 changes: 26 additions & 20 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,28 +407,26 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
"If 'requirements' or 'entry_script' is provided in 'source_code', "
+ "'source_dir' must also be provided.",
)
if not _is_valid_path(source_dir, path_type="Directory"):
if not _is_valid_path(source_dir) and not _is_valid_s3_uri(source_dir):
raise ValueError(
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
)
if requirements:
if not _is_valid_path(
f"{source_dir}/{requirements}",
path_type="File",
):
raise ValueError(
f"Invalid 'requirements': {requirements}. "
+ "Must be a valid file within the 'source_dir'.",
)
if not source_dir.endswith(".tar.gz"):
if (not _is_valid_path(f"{source_dir}/{requirements}", path_type="File")
and not _is_valid_s3_uri(f"{source_dir}/{requirements}", path_type="File")):
raise ValueError(
f"Invalid 'requirements': {requirements}. "
+ "Must be a valid file within the 'source_dir'.",
)
if entry_script:
if not _is_valid_path(
f"{source_dir}/{entry_script}",
path_type="File",
):
raise ValueError(
f"Invalid 'entry_script': {entry_script}. "
+ "Must be a valid file within the 'source_dir'.",
)
if not source_dir.endswith(".tar.gz"):
if (not _is_valid_path(f"{source_dir}/{entry_script}", path_type="File")
and not _is_valid_s3_uri(f"{source_dir}/{entry_script}", path_type="File")):
raise ValueError(
f"Invalid 'entry_script': {entry_script}. "
+ "Must be a valid file within the 'source_dir'.",
)

def model_post_init(self, __context: Any):
"""Post init method to perform custom validation and set default values."""
Expand Down Expand Up @@ -838,12 +836,20 @@ def _prepare_train_script(

install_requirements = ""
if source_code.requirements:
install_requirements = "echo 'Installing requirements'\n"
install_requirements = f"$SM_PIP_CMD install -r {source_code.requirements}"
install_requirements = (
"echo 'Installing requirements'\n" +
f"$SM_PIP_CMD install -r {source_code.requirements}"
)

working_dir = ""
if source_code.source_dir:
working_dir = f"cd {SM_CODE_CONTAINER_PATH}"
working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
if source_code.source_dir.endswith(".tar.gz"):
if source_code.source_dir.startswith("s3://"):
tarfile_name = os.path.basename(source_code.source_dir)
else:
tarfile_name = source_code.source_dir
working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n"

if base_command:
execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)
Expand Down
Loading