File tree Expand file tree Collapse file tree 2 files changed +36
-2
lines changed Expand file tree Collapse file tree 2 files changed +36
-2
lines changed Original file line number Diff line number Diff line change @@ -530,7 +530,9 @@ def _create_or_update_code_dir(
530530 """
531531 code_dir = os .path .join (model_dir , "code" )
532532 if os .path .exists (code_dir ):
533- shutil .rmtree (code_dir , ignore_errors = True )
533+ for filename in os .listdir (code_dir ):
534+ if filename .endswith (".py" ):
535+ os .remove (os .path .join (code_dir , filename ))
534536 if source_directory and source_directory .lower ().startswith ("s3://" ):
535537 local_code_path = os .path .join (tmp , "local_code.tar.gz" )
536538 download_file_from_url (source_directory , local_code_path , sagemaker_session )
@@ -539,9 +541,12 @@ def _create_or_update_code_dir(
539541 t .extractall (path = code_dir )
540542
541543 elif source_directory :
544+ if os .path .exists (code_dir ):
545+ shutil .rmtree (code_dir )
542546 shutil .copytree (source_directory , code_dir )
543547 else :
544- os .mkdir (code_dir )
548+ if not os .path .exists (code_dir ):
549+ os .mkdir (code_dir )
545550 shutil .copy2 (inference_script , code_dir )
546551
547552 for dependency in dependencies :
Original file line number Diff line number Diff line change @@ -614,6 +614,35 @@ def test_repack_model_from_file_to_folder(tmp):
614614 }
615615
616616
617+ def test_repack_model_with_inference_code_and_requirements (tmp , fake_s3 ):
618+ create_file_tree (
619+ tmp ,
620+ [
621+ "new-inference.py" ,
622+ "model-dir/model" ,
623+ "model-dir/code/old-inference.py" ,
624+ "model-dir/code/requirements.txt" ,
625+ ],
626+ )
627+
628+ fake_s3 .tar_and_upload ("model-dir" , "s3://fake/location" )
629+
630+ sagemaker .utils .repack_model (
631+ os .path .join (tmp , "new-inference.py" ),
632+ None ,
633+ None ,
634+ "s3://fake/location" ,
635+ "s3://destination-bucket/repacked-model" ,
636+ fake_s3 .sagemaker_session ,
637+ )
638+
639+ assert list_tar_files (fake_s3 .fake_upload_path , tmp ) == {
640+ "/code/requirements.txt" ,
641+ "/code/new-inference.py" ,
642+ "/model" ,
643+ }
644+
645+
617646class FakeS3 (object ):
618647 def __init__ (self , tmp ):
619648 self .tmp = tmp
You can’t perform that action at this time.
0 commit comments