Skip to content

Commit 2b1052d

Browse files
committed
update logic and unit test to raise value error if the file is not .tar.gz
1 parent db6160b commit 2b1052d

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,18 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408408
+ "'source_dir' must also be provided.",
409409
)
410-
if not _is_valid_path(source_dir) and not _is_valid_s3_uri(source_dir):
410+
if not (
411+
_is_valid_path(source_dir, path_type="Directory")
412+
or _is_valid_s3_uri(source_dir, path_type="Directory")
413+
or (
414+
_is_valid_path(source_dir, path_type="File")
415+
and source_dir.endswith(".tar.gz")
416+
)
417+
or (
418+
_is_valid_s3_uri(source_dir, path_type="File")
419+
and source_dir.endswith(".tar.gz")
420+
)
421+
):
411422
raise ValueError(
412423
f"Invalid 'source_dir' path: {source_dir}. "
413424
+ "Must be a valid local directory, "

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@
9292
source_dir=DEFAULT_SOURCE_DIR,
9393
entry_script="custom_script.py",
9494
)
95-
UNSUPPORTED_SOURCE_CODE = SourceCode(
96-
entry_script="train.py",
97-
)
9895
DEFAULT_ENTRYPOINT = ["/bin/bash"]
9996
DEFAULT_ARGUMENTS = [
10097
"-c",
@@ -152,7 +149,19 @@ def model_trainer():
152149
{
153150
"init_params": {
154151
"training_image": DEFAULT_IMAGE,
155-
"source_code": UNSUPPORTED_SOURCE_CODE,
152+
"source_code": SourceCode(
153+
entry_script="train.py",
154+
),
155+
},
156+
"should_throw": True,
157+
},
158+
{
159+
"init_params": {
160+
"training_image": DEFAULT_IMAGE,
161+
"source_code": SourceCode(
162+
source_dir="s3://bucket/requirements.txt",
163+
entry_script="custom_script.py",
164+
),
156165
},
157166
"should_throw": True,
158167
},
@@ -177,7 +186,7 @@ def model_trainer():
177186
"init_params": {
178187
"training_image": DEFAULT_IMAGE,
179188
"source_code": SourceCode(
180-
source_dir="s3://bucket/code",
189+
source_dir="s3://bucket/code/",
181190
entry_script="custom_script.py",
182191
),
183192
},
@@ -198,7 +207,8 @@ def model_trainer():
198207
"no_params",
199208
"training_image_and_algorithm_name",
200209
"only_training_image",
201-
"unsupported_source_code",
210+
"unsupported_source_code_missing_source_dir",
211+
"unsupported_source_code_s3_other_file",
202212
"supported_source_code_local_dir",
203213
"supported_source_code_local_tar_file",
204214
"supported_source_code_s3_dir",

0 commit comments

Comments
 (0)