@@ -407,28 +407,26 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407
407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408
408
+ "'source_dir' must also be provided." ,
409
409
)
410
- if not _is_valid_path (source_dir , path_type = "Directory" ) or _is_valid_s3_uri (source_dir , path_type = "Directory" ):
410
+ if not _is_valid_path (source_dir ) and not _is_valid_s3_uri (source_dir ):
411
411
raise ValueError (
412
412
f"Invalid 'source_dir' path: { source_dir } . " + "Must be a valid directory." ,
413
413
)
414
414
if requirements :
415
- if not _is_valid_path (
416
- f"{ source_dir } /{ requirements } " ,
417
- path_type = "File" ,
418
- ):
419
- raise ValueError (
420
- f"Invalid 'requirements': { requirements } . "
421
- + "Must be a valid file within the 'source_dir'." ,
422
- )
415
+ if not source_dir .endswith (".tar.gz" ):
416
+ if (not _is_valid_path (f"{ source_dir } /{ requirements } " , path_type = "File" )
417
+ and not _is_valid_s3_uri (f"{ source_dir } /{ requirements } " , path_type = "File" )):
418
+ raise ValueError (
419
+ f"Invalid 'requirements': { requirements } . "
420
+ + "Must be a valid file within the 'source_dir'." ,
421
+ )
423
422
if entry_script :
424
- if not _is_valid_path (
425
- f"{ source_dir } /{ entry_script } " ,
426
- path_type = "File" ,
427
- ):
428
- raise ValueError (
429
- f"Invalid 'entry_script': { entry_script } . "
430
- + "Must be a valid file within the 'source_dir'." ,
431
- )
423
+ if not source_dir .endswith (".tar.gz" ):
424
+ if (not _is_valid_path (f"{ source_dir } /{ entry_script } " , path_type = "File" )
425
+ and not _is_valid_s3_uri (f"{ source_dir } /{ entry_script } " , path_type = "File" )):
426
+ raise ValueError (
427
+ f"Invalid 'entry_script': { entry_script } . "
428
+ + "Must be a valid file within the 'source_dir'." ,
429
+ )
432
430
433
431
def model_post_init (self , __context : Any ):
434
432
"""Post init method to perform custom validation and set default values."""
@@ -838,12 +836,20 @@ def _prepare_train_script(
838
836
839
837
install_requirements = ""
840
838
if source_code .requirements :
841
- install_requirements = "echo 'Installing requirements'\n "
842
- install_requirements = f"$SM_PIP_CMD install -r { source_code .requirements } "
839
+ install_requirements = (
840
+ "echo 'Installing requirements'\n " +
841
+ f"$SM_PIP_CMD install -r { source_code .requirements } "
842
+ )
843
843
844
844
working_dir = ""
845
845
if source_code .source_dir :
846
- working_dir = f"cd { SM_CODE_CONTAINER_PATH } "
846
+ working_dir = f"cd { SM_CODE_CONTAINER_PATH } \n "
847
+ if source_code .source_dir .endswith (".tar.gz" ):
848
+ if source_code .source_dir .startswith ("s3://" ):
849
+ tarfile_name = os .path .basename (source_code .source_dir )
850
+ else :
851
+ tarfile_name = source_code .source_dir
852
+ working_dir += f"tar --strip-components=1 -xzf { tarfile_name } \n "
847
853
848
854
if base_command :
849
855
execute_driver = EXECUTE_BASE_COMMANDS .format (base_command = base_command )
0 commit comments