Skip to content

Commit e62dd2e

Browse files
authored
fix: copy dependencies into new folder when repacking model (#1021)
1 parent de676a1 commit e62dd2e

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/sagemaker/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,12 @@ def _create_or_update_code_dir(
468468
shutil.copy2(inference_script, code_dir)
469469

470470
for dependency in dependencies:
471+
lib_dir = os.path.join(code_dir, "lib")
471472
if os.path.isdir(dependency):
472-
shutil.copytree(dependency, code_dir)
473+
shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency)))
473474
else:
474-
shutil.copy2(dependency, code_dir)
475+
os.mkdir(lib_dir)
476+
shutil.copy2(dependency, lib_dir)
475477

476478

477479
def _extract_model(model_uri, sagemaker_session, tmp):

tests/unit/test_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
344344
[
345345
"model-dir/model",
346346
"dependencies/a",
347-
"dependencies/b",
347+
"dependencies/some/dir/b",
348348
"source-dir/inference.py",
349349
"source-dir/this-file-should-not-be-included.py",
350350
],
@@ -355,16 +355,19 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
355355
sagemaker.utils.repack_model(
356356
inference_script=os.path.join(tmp, "source-dir/inference.py"),
357357
source_directory=None,
358-
dependencies=[os.path.join(tmp, "dependencies/a"), os.path.join(tmp, "dependencies/b")],
358+
dependencies=[
359+
os.path.join(tmp, "dependencies/a"),
360+
os.path.join(tmp, "dependencies/some/dir"),
361+
],
359362
model_uri="s3://fake/location",
360363
repacked_model_uri="s3://destination-bucket/model.tar.gz",
361364
sagemaker_session=fake_s3.sagemaker_session,
362365
)
363366

364367
assert list_tar_files(fake_s3.fake_upload_path, tmp) == {
365368
"/model",
366-
"/code/a",
367-
"/code/b",
369+
"/code/lib/a",
370+
"/code/lib/dir/b",
368371
"/code/inference.py",
369372
}
370373

@@ -449,7 +452,7 @@ def test_repack_model_from_file_to_file(tmp):
449452
sagemaker_session,
450453
)
451454

452-
assert list_tar_files(destination_path, tmp) == {"/code/a", "/code/inference.py", "/model"}
455+
assert list_tar_files(destination_path, tmp) == {"/code/lib/a", "/code/inference.py", "/model"}
453456

454457

455458
def test_repack_model_with_inference_code_should_replace_the_code(tmp, fake_s3):

0 commit comments

Comments
 (0)