@@ -317,6 +317,9 @@ def test_repack_model_without_source_dir(tmpdir):
317317 script_path = os .path .join (source_dir , 'inference.py' )
318318 write_file (script_path , 'inference script' )
319319
320+ script_path = os .path .join (source_dir , 'this-file-should-not-be-included.py' )
321+ write_file (script_path , 'This file should not be included' )
322+
320323 contents = [model_path ]
321324
322325 sagemaker_session = MagicMock ()
@@ -334,6 +337,44 @@ def test_repack_model_without_source_dir(tmpdir):
334337 assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
335338
336339
340+ def test_repack_model_with_entry_point_without_path_without_source_dir (tmpdir ):
341+
342+ tmp = str (tmpdir )
343+
344+ model_path = os .path .join (tmp , 'model' )
345+ write_file (model_path , 'model data' )
346+
347+ source_dir = os .path .join (tmp , 'source-dir' )
348+ os .mkdir (source_dir )
349+ script_path = os .path .join (source_dir , 'inference.py' )
350+ write_file (script_path , 'inference script' )
351+
352+ script_path = os .path .join (source_dir , 'this-file-should-not-be-included.py' )
353+ write_file (script_path , 'This file should not be included' )
354+
355+ contents = [model_path ]
356+
357+ sagemaker_session = MagicMock ()
358+ mock_s3_model_tar (contents , sagemaker_session , tmp )
359+ fake_upload_path = mock_s3_upload (sagemaker_session , tmp )
360+
361+ model_uri = 's3://fake/location'
362+
363+ cwd = os .getcwd ()
364+ try :
365+ os .chdir (source_dir )
366+
367+ new_model_uri = sagemaker .utils .repack_model ('inference.py' ,
368+ None ,
369+ model_uri ,
370+ sagemaker_session )
371+ finally :
372+ os .chdir (cwd )
373+
374+ assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/inference.py' , '/model' }
375+ assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
376+
377+
337378def test_repack_model_from_s3_saved_model_to_s3 (tmpdir ):
338379
339380 tmp = str (tmpdir )
@@ -346,6 +387,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
346387 script_path = os .path .join (source_dir , 'inference.py' )
347388 write_file (script_path , 'inference script' )
348389
390+ script_path = os .path .join (source_dir , 'this-file-should-be-included.py' )
391+ write_file (script_path , 'This file should be included' )
392+
349393 contents = [model_path ]
350394
351395 sagemaker_session = MagicMock ()
@@ -359,7 +403,9 @@ def test_repack_model_from_s3_saved_model_to_s3(tmpdir):
359403 model_uri ,
360404 sagemaker_session )
361405
362- assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/inference.py' , '/model' }
406+ assert list_tar_files (fake_upload_path , tmpdir ) == {'/code/this-file-should-be-included.py' ,
407+ '/code/inference.py' ,
408+ '/model' }
363409 assert re .match (r'^s3://fake/model-\d+-\d+.tar.gz$' , new_model_uri )
364410
365411
0 commit comments