@@ -407,28 +407,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407407 "If 'requirements' or 'entry_script' is provided in 'source_code', "
408408 + "'source_dir' must also be provided." ,
409409 )
410- if not _is_valid_path (source_dir , path_type = "Directory" ):
410+ if not (
411+ _is_valid_path (source_dir , path_type = "Directory" )
412+ or _is_valid_s3_uri (source_dir , path_type = "Directory" )
413+ or (
414+ _is_valid_path (source_dir , path_type = "File" )
415+ and source_dir .endswith (".tar.gz" )
416+ )
417+ or (
418+ _is_valid_s3_uri (source_dir , path_type = "File" )
419+ and source_dir .endswith (".tar.gz" )
420+ )
421+ ):
411422 raise ValueError (
412- f"Invalid 'source_dir' path: { source_dir } . " + "Must be a valid directory." ,
423+ f"Invalid 'source_dir' path: { source_dir } . "
424+ + "Must be a valid local directory, "
425+ "s3 uri or path to tar.gz file stored locally or in s3." ,
413426 )
414427 if requirements :
415- if not _is_valid_path (
416- f"{ source_dir } /{ requirements } " ,
417- path_type = "File" ,
418- ):
419- raise ValueError (
420- f"Invalid 'requirements': { requirements } . "
421- + "Must be a valid file within the 'source_dir'." ,
422- )
428+ if not source_dir .endswith (".tar.gz" ):
429+ if not _is_valid_path (
430+ f"{ source_dir } /{ requirements } " , path_type = "File"
431+ ) and not _is_valid_s3_uri (
432+ f"{ source_dir } /{ requirements } " , path_type = "File"
433+ ):
434+ raise ValueError (
435+ f"Invalid 'requirements': { requirements } . "
436+ + "Must be a valid file within the 'source_dir'." ,
437+ )
423438 if entry_script :
424- if not _is_valid_path (
425- f"{ source_dir } /{ entry_script } " ,
426- path_type = "File" ,
427- ):
428- raise ValueError (
429- f"Invalid 'entry_script': { entry_script } . "
430- + "Must be a valid file within the 'source_dir'." ,
431- )
439+ if not source_dir .endswith (".tar.gz" ):
440+ if not _is_valid_path (
441+ f"{ source_dir } /{ entry_script } " , path_type = "File"
442+ ) and not _is_valid_s3_uri (
443+ f"{ source_dir } /{ entry_script } " , path_type = "File"
444+ ):
445+ raise ValueError (
446+ f"Invalid 'entry_script': { entry_script } . "
447+ + "Must be a valid file within the 'source_dir'." ,
448+ )
432449
433450 def model_post_init (self , __context : Any ):
434451 """Post init method to perform custom validation and set default values."""
@@ -838,12 +855,17 @@ def _prepare_train_script(
838855
839856 install_requirements = ""
840857 if source_code .requirements :
841- install_requirements = "echo 'Installing requirements'\n "
842- install_requirements = f"$SM_PIP_CMD install -r { source_code .requirements } "
858+ install_requirements = (
859+ "echo 'Installing requirements'\n "
860+ + f"$SM_PIP_CMD install -r { source_code .requirements } "
861+ )
843862
844863 working_dir = ""
845864 if source_code .source_dir :
846- working_dir = f"cd { SM_CODE_CONTAINER_PATH } "
865+ working_dir = f"cd { SM_CODE_CONTAINER_PATH } \n "
866+ if source_code .source_dir .endswith (".tar.gz" ):
867+ tarfile_name = os .path .basename (source_code .source_dir )
868+ working_dir += f"tar --strip-components=1 -xzf { tarfile_name } \n "
847869
848870 if base_command :
849871 execute_driver = EXECUTE_BASE_COMMANDS .format (base_command = base_command )
0 commit comments