Skip to content

Commit d2d1cdf

Browse files
authored
fix: preserve inference script in model repack. (#1432)
1 parent 11537b9 commit d2d1cdf

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed

src/sagemaker/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,10 +534,6 @@ def _create_or_update_code_dir(
534534
tmp:
535535
"""
536536
code_dir = os.path.join(model_dir, "code")
537-
if os.path.exists(code_dir):
538-
for filename in os.listdir(code_dir):
539-
if filename.endswith(".py"):
540-
os.remove(os.path.join(code_dir, filename))
541537
if source_directory and source_directory.lower().startswith("s3://"):
542538
local_code_path = os.path.join(tmp, "local_code.tar.gz")
543539
download_file_from_url(source_directory, local_code_path, sagemaker_session)
@@ -552,7 +548,13 @@ def _create_or_update_code_dir(
552548
else:
553549
if not os.path.exists(code_dir):
554550
os.mkdir(code_dir)
555-
shutil.copy2(inference_script, code_dir)
551+
try:
552+
shutil.copy2(inference_script, code_dir)
553+
except FileNotFoundError:
554+
if os.path.exists(os.path.join(code_dir, inference_script)):
555+
pass
556+
else:
557+
raise
556558

557559
for dependency in dependencies:
558560
lib_dir = os.path.join(code_dir, "lib")
85.8 KB
Binary file not shown.

tests/integ/test_pytorch_train.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
3333
MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py")
34+
PACKED_MODEL = os.path.join(MNIST_DIR, "packed_model.tar.gz")
3435

3536
EIA_DIR = os.path.join(DATA_DIR, "pytorch_eia")
3637
EIA_MODEL = os.path.join(EIA_DIR, "model_mnist.tar.gz")
@@ -124,6 +125,31 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
124125
assert output.shape == (batch_size, 10)
125126

126127

128+
@pytest.mark.skipif(
129+
PYTHON_VERSION == "py2",
130+
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
131+
)
132+
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
133+
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
134+
135+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
136+
model_data = sagemaker_session.upload_data(path=PACKED_MODEL)
137+
model = PyTorchModel(
138+
model_data,
139+
"SageMakerRole",
140+
entry_point="mnist.py",
141+
framework_version="1.4.0",
142+
sagemaker_session=sagemaker_session,
143+
)
144+
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
145+
146+
batch_size = 100
147+
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
148+
output = predictor.predict(data)
149+
150+
assert output.shape == (batch_size, 10)
151+
152+
127153
@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch EIA does not support Python 2.")
128154
@pytest.mark.skipif(
129155
test_region() not in EI_SUPPORTED_REGIONS, reason="EI isn't supported in that specific region."

tests/unit/test_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,36 @@ def test_repack_model_with_inference_code_and_requirements(tmp, fake_s3):
643643
assert list_tar_files(fake_s3.fake_upload_path, tmp) == {
644644
"/code/requirements.txt",
645645
"/code/new-inference.py",
646+
"/code/old-inference.py",
647+
"/model",
648+
}
649+
650+
651+
def test_repack_model_with_same_inference_file_name(tmp, fake_s3):
652+
create_file_tree(
653+
tmp,
654+
[
655+
"inference.py",
656+
"model-dir/model",
657+
"model-dir/code/inference.py",
658+
"model-dir/code/requirements.txt",
659+
],
660+
)
661+
662+
fake_s3.tar_and_upload("model-dir", "s3://fake/location")
663+
664+
sagemaker.utils.repack_model(
665+
os.path.join(tmp, "inference.py"),
666+
None,
667+
None,
668+
"s3://fake/location",
669+
"s3://destination-bucket/repacked-model",
670+
fake_s3.sagemaker_session,
671+
)
672+
673+
assert list_tar_files(fake_s3.fake_upload_path, tmp) == {
674+
"/code/requirements.txt",
675+
"/code/inference.py",
646676
"/model",
647677
}
648678

0 commit comments

Comments
 (0)