@@ -407,28 +407,45 @@ def _validate_source_code(self, source_code: Optional[SourceCode]):
407
407
"If 'requirements' or 'entry_script' is provided in 'source_code', "
408
408
+ "'source_dir' must also be provided." ,
409
409
)
410
- if not _is_valid_path (source_dir , path_type = "Directory" ):
410
+ if not (
411
+ _is_valid_path (source_dir , path_type = "Directory" )
412
+ or _is_valid_s3_uri (source_dir , path_type = "Directory" )
413
+ or (
414
+ _is_valid_path (source_dir , path_type = "File" )
415
+ and source_dir .endswith (".tar.gz" )
416
+ )
417
+ or (
418
+ _is_valid_s3_uri (source_dir , path_type = "File" )
419
+ and source_dir .endswith (".tar.gz" )
420
+ )
421
+ ):
411
422
raise ValueError (
412
- f"Invalid 'source_dir' path: { source_dir } . " + "Must be a valid directory." ,
423
+ f"Invalid 'source_dir' path: { source_dir } . "
424
+ + "Must be a valid local directory, "
425
+ "s3 uri or path to tar.gz file stored locally or in s3." ,
413
426
)
414
427
if requirements :
415
- if not _is_valid_path (
416
- f"{ source_dir } /{ requirements } " ,
417
- path_type = "File" ,
418
- ):
419
- raise ValueError (
420
- f"Invalid 'requirements': { requirements } . "
421
- + "Must be a valid file within the 'source_dir'." ,
422
- )
428
+ if not source_dir .endswith (".tar.gz" ):
429
+ if not _is_valid_path (
430
+ f"{ source_dir } /{ requirements } " , path_type = "File"
431
+ ) and not _is_valid_s3_uri (
432
+ f"{ source_dir } /{ requirements } " , path_type = "File"
433
+ ):
434
+ raise ValueError (
435
+ f"Invalid 'requirements': { requirements } . "
436
+ + "Must be a valid file within the 'source_dir'." ,
437
+ )
423
438
if entry_script :
424
- if not _is_valid_path (
425
- f"{ source_dir } /{ entry_script } " ,
426
- path_type = "File" ,
427
- ):
428
- raise ValueError (
429
- f"Invalid 'entry_script': { entry_script } . "
430
- + "Must be a valid file within the 'source_dir'." ,
431
- )
439
+ if not source_dir .endswith (".tar.gz" ):
440
+ if not _is_valid_path (
441
+ f"{ source_dir } /{ entry_script } " , path_type = "File"
442
+ ) and not _is_valid_s3_uri (
443
+ f"{ source_dir } /{ entry_script } " , path_type = "File"
444
+ ):
445
+ raise ValueError (
446
+ f"Invalid 'entry_script': { entry_script } . "
447
+ + "Must be a valid file within the 'source_dir'." ,
448
+ )
432
449
433
450
def model_post_init (self , __context : Any ):
434
451
"""Post init method to perform custom validation and set default values."""
@@ -838,12 +855,17 @@ def _prepare_train_script(
838
855
839
856
install_requirements = ""
840
857
if source_code .requirements :
841
- install_requirements = "echo 'Installing requirements'\n "
842
- install_requirements = f"$SM_PIP_CMD install -r { source_code .requirements } "
858
+ install_requirements = (
859
+ "echo 'Installing requirements'\n "
860
+ + f"$SM_PIP_CMD install -r { source_code .requirements } "
861
+ )
843
862
844
863
working_dir = ""
845
864
if source_code .source_dir :
846
- working_dir = f"cd { SM_CODE_CONTAINER_PATH } "
865
+ working_dir = f"cd { SM_CODE_CONTAINER_PATH } \n "
866
+ if source_code .source_dir .endswith (".tar.gz" ):
867
+ tarfile_name = os .path .basename (source_code .source_dir )
868
+ working_dir += f"tar --strip-components=1 -xzf { tarfile_name } \n "
847
869
848
870
if base_command :
849
871
execute_driver = EXECUTE_BASE_COMMANDS .format (base_command = base_command )
0 commit comments