@@ -407,28 +407,26 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407 "If 'requirements' or 'entry_script' is provided in 'source_code', "
408408 + "'source_dir' must also be provided." ,
409409 )
410- if not _is_valid_path (source_dir , path_type = "Directory" ) or _is_valid_s3_uri (source_dir , path_type = "Directory" ):
410+ if not _is_valid_path (source_dir ) and not _is_valid_s3_uri (source_dir ):
411411 raise ValueError (
412412 f"Invalid 'source_dir' path: { source_dir } . " + "Must be a valid directory." ,
413413 )
414414 if requirements :
415- if not _is_valid_path (
416- f"{ source_dir } /{ requirements } " ,
417- path_type = "File" ,
418- ):
419- raise ValueError (
420- f"Invalid 'requirements': { requirements } . "
421- + "Must be a valid file within the 'source_dir'." ,
422- )
415+ if not source_dir .endswith (".tar.gz" ):
416+ if (not _is_valid_path (f"{ source_dir } /{ requirements } " , path_type = "File" )
417+ and not _is_valid_s3_uri (f"{ source_dir } /{ requirements } " , path_type = "File" )):
418+ raise ValueError (
419+ f"Invalid 'requirements': { requirements } . "
420+ + "Must be a valid file within the 'source_dir'." ,
421+ )
423422 if entry_script :
424- if not _is_valid_path (
425- f"{ source_dir } /{ entry_script } " ,
426- path_type = "File" ,
427- ):
428- raise ValueError (
429- f"Invalid 'entry_script': { entry_script } . "
430- + "Must be a valid file within the 'source_dir'." ,
431- )
423+ if not source_dir .endswith (".tar.gz" ):
424+ if (not _is_valid_path (f"{ source_dir } /{ entry_script } " , path_type = "File" )
425+ and not _is_valid_s3_uri (f"{ source_dir } /{ entry_script } " , path_type = "File" )):
426+ raise ValueError (
427+ f"Invalid 'entry_script': { entry_script } . "
428+ + "Must be a valid file within the 'source_dir'." ,
429+ )
432430
433431 def model_post_init (self , __context : Any ):
434432 """Post init method to perform custom validation and set default values."""
@@ -838,12 +836,20 @@ def _prepare_train_script(
838836
839837 install_requirements = ""
840838 if source_code .requirements :
841- install_requirements = "echo 'Installing requirements'\n "
842- install_requirements = f"$SM_PIP_CMD install -r { source_code .requirements } "
839+ install_requirements = (
840+ "echo 'Installing requirements'\n " +
841+ f"$SM_PIP_CMD install -r { source_code .requirements } "
842+ )
843843
844844 working_dir = ""
845845 if source_code .source_dir :
846- working_dir = f"cd { SM_CODE_CONTAINER_PATH } "
846+ working_dir = f"cd { SM_CODE_CONTAINER_PATH } \n "
847+ if source_code .source_dir .endswith (".tar.gz" ):
848+ if source_code .source_dir .startswith ("s3://" ):
849+ tarfile_name = os .path .basename (source_code .source_dir )
850+ else :
851+ tarfile_name = source_code .source_dir
852+ working_dir += f"tar --strip-components=1 -xzf { tarfile_name } \n "
847853
848854 if base_command :
849855 execute_driver = EXECUTE_BASE_COMMANDS .format (base_command = base_command )
0 commit comments