Skip to content

Commit db6160b

Browse files
committed
add unit and integ tests
1 parent 46ae9c5 commit db6160b

File tree

4 files changed

+56
-6
lines changed

4 files changed

+56
-6
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,9 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
409409
)
410410
if not _is_valid_path(source_dir) and not _is_valid_s3_uri(source_dir):
411411
raise ValueError(
412-
f"Invalid 'source_dir' path: {source_dir}. " + "Must be a valid directory.",
412+
f"Invalid 'source_dir' path: {source_dir}. "
413+
+ "Must be a valid local directory, "
414+
"s3 uri or path to tar.gz file stored locally or in s3.",
413415
)
414416
if requirements:
415417
if not source_dir.endswith(".tar.gz"):
@@ -851,10 +853,7 @@ def _prepare_train_script(
851853
if source_code.source_dir:
852854
working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
853855
if source_code.source_dir.endswith(".tar.gz"):
854-
if source_code.source_dir.startswith("s3://"):
855-
tarfile_name = os.path.basename(source_code.source_dir)
856-
else:
857-
tarfile_name = source_code.source_dir
856+
tarfile_name = os.path.basename(source_code.source_dir)
858857
working_dir += f"tar --strip-components=1 -xzf {tarfile_name} \n"
859858

860859
if base_command:
37.1 KB
Binary file not shown.

tests/integ/sagemaker/modules/train/test_model_trainer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@
4444

4545
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
4646

47+
TAR_FILE_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode/code.tar.gz"
48+
TAR_FILE_SOURCE_CODE = SourceCode(
49+
source_dir=TAR_FILE_SOURCE_DIR,
50+
requirements="requirements.txt",
51+
entry_script="custom_script.py",
52+
)
53+
54+
55+
def test_source_dir_local_tar_file(modules_sagemaker_session):
56+
model_trainer = ModelTrainer(
57+
sagemaker_session=modules_sagemaker_session,
58+
training_image=DEFAULT_CPU_IMAGE,
59+
source_code=TAR_FILE_SOURCE_CODE,
60+
base_job_name="source_dir_local_tar_file",
61+
)
62+
63+
model_trainer.train()
64+
4765

4866
def test_hp_contract_basic_py_script(modules_sagemaker_session):
4967
model_trainer = ModelTrainer(

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,46 @@ def model_trainer():
163163
},
164164
"should_throw": False,
165165
},
166+
{
167+
"init_params": {
168+
"training_image": DEFAULT_IMAGE,
169+
"source_code": SourceCode(
170+
source_dir=f"{DEFAULT_SOURCE_DIR}/code.tar.gz",
171+
entry_script="custom_script.py",
172+
),
173+
},
174+
"should_throw": False,
175+
},
176+
{
177+
"init_params": {
178+
"training_image": DEFAULT_IMAGE,
179+
"source_code": SourceCode(
180+
source_dir="s3://bucket/code",
181+
entry_script="custom_script.py",
182+
),
183+
},
184+
"should_throw": False,
185+
},
186+
{
187+
"init_params": {
188+
"training_image": DEFAULT_IMAGE,
189+
"source_code": SourceCode(
190+
source_dir="s3://bucket/code/code.tar.gz",
191+
entry_script="custom_script.py",
192+
),
193+
},
194+
"should_throw": False,
195+
},
166196
],
167197
ids=[
168198
"no_params",
169199
"training_image_and_algorithm_name",
170200
"only_training_image",
171201
"unsupported_source_code",
172-
"supported_source_code",
202+
"supported_source_code_local_dir",
203+
"supported_source_code_local_tar_file",
204+
"supported_source_code_s3_dir",
205+
"supported_source_code_s3_tar_file",
173206
],
174207
)
175208
def test_model_trainer_param_validation(test_case, modules_session):

0 commit comments

Comments
 (0)