Commit 2cc906b
Author: Bryannah Hernandez
Commit message: InferenceSpec support for HF
1 parent ce43606 commit 2cc906b

3 files changed: +155, -1 lines changed
src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import cloudpickle
+import shutil
+import platform
+import importlib
+from pathlib import Path
+from functools import partial
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.detector.image_detector import _detect_framework_and_version, _get_model_base
+from sagemaker.serve.detector.pickler import load_xgboost_from_json
+import logging
+
+logger = logging.getLogger(__name__)
+
+inference_spec = None
+native_model = None
+schema_builder = None
+
+
+def model_fn(model_dir):
+    """Placeholder docstring"""
+    shared_libs_path = Path(model_dir + "/shared_libs")
+
+    if shared_libs_path.exists():
+        # before importing, place dynamic linked libraries in shared lib path
+        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)
+
+    serve_path = Path(__file__).parent.joinpath("serve.pkl")
+    with open(str(serve_path), mode="rb") as file:
+        global inference_spec, native_model, schema_builder
+        obj = cloudpickle.load(file)
+        if isinstance(obj[0], InferenceSpec):
+            inference_spec, schema_builder = obj
+        else:
+            native_model, schema_builder = obj
+    if native_model:
+        framework, _ = _detect_framework_and_version(
+            model_base=str(_get_model_base(model=native_model))
+        )
+        if framework == "pytorch":
+            native_model.eval()
+        return native_model if callable(native_model) else native_model.predict
+    elif inference_spec:
+        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
+
+
+def input_fn(input_data, content_type):
+    """Placeholder docstring"""
+    try:
+        if hasattr(schema_builder, "custom_input_translator"):
+            return schema_builder.custom_input_translator.deserialize(
+                io.BytesIO(input_data), content_type
+            )
+        else:
+            return schema_builder.input_deserializer.deserialize(
+                io.BytesIO(input_data), content_type[0]
+            )
+    except Exception as e:
+        raise Exception("Encountered error in deserialize_request.") from e
+
+
+def predict_fn(input_data, predict_callable):
+    """Placeholder docstring"""
+    return predict_callable(input_data)
+
+
+def output_fn(predictions, accept_type):
+    """Placeholder docstring"""
+    try:
+        if hasattr(schema_builder, "custom_output_translator"):
+            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
+        else:
+            return schema_builder.output_serializer.serialize(predictions)
+    except Exception as e:
+        logger.error("Encountered error: %s in serialize_response." % e)
+        raise Exception("Encountered error in serialize_response.") from e
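The key branch in model_fn is the isinstance check on the unpickled tuple: when serve.pkl carries an InferenceSpec, serving delegates entirely to the spec's load and invoke methods rather than a pickled model object. For reference, a minimal spec might look like the sketch below; the pipeline task and model id are illustrative assumptions, not part of this commit.

from sagemaker.serve.spec.inference_spec import InferenceSpec
from transformers import pipeline


class MyHuggingFaceSpec(InferenceSpec):
    """Hypothetical InferenceSpec: load a Hugging Face pipeline, run it per request."""

    def load(self, model_dir: str):
        # model_fn passes the SageMaker model directory; the model id here is illustrative
        return pipeline(
            "text-classification",
            model="distilbert-base-uncased-finetuned-sst-2-english",
        )

    def invoke(self, input_object, model):
        # model_fn binds model via partial(inference_spec.invoke, model=inference_spec.load(model_dir))
        return model(input_object)
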
src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 21 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import requests
 import logging
+import platform
 from pathlib import Path
 from sagemaker import Session, fw_utils
 from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -42,7 +43,14 @@ def _start_serving(
                     "mode": "rw",
                 },
             },
-            environment=_update_env_vars(env_vars),
+
+            environment={
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "LOCAL_PYTHON": platform.python_version(),
+                **env_vars,
+            },
         )
 
     def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -80,10 +88,12 @@ class SageMakerMultiModelServer:
     def _upload_server_artifacts(
         self,
         model_path: str,
+        secret_key: str,
         sagemaker_session: Session,
        s3_model_data_url: str = None,
        image: str = None,
        env_vars: dict = None,
+
    ):
        if s3_model_data_url:
            bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
@@ -116,6 +126,16 @@ def _upload_server_artifacts(
                 "S3Uri": model_data_url + "/",
             }
         }
+
+        env_vars = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+            "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+
         return model_data, _update_env_vars(env_vars)

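Because **env_vars is spread last when _start_serving builds the local container environment, user-supplied variables override the serve defaults. A standalone sketch of that merge rule, with illustrative names:

import platform

# defaults mirror the keys set in _start_serving above
serve_defaults = {
    "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
    "SAGEMAKER_PROGRAM": "inference.py",
    "LOCAL_PYTHON": platform.python_version(),
}
user_env_vars = {"SAGEMAKER_PROGRAM": "my_inference.py", "KEY": "VALUE"}

# in a dict literal, later keys win, so user values take precedence
environment = {**serve_defaults, **user_env_vars}
assert environment["SAGEMAKER_PROGRAM"] == "my_inference.py"
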
tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py

Lines changed: 51 additions & 0 deletions
@@ -12,13 +12,64 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+from pathlib import PosixPath
+import platform
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+import numpy as np
+
 from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure
 
+from sagemaker.serve.model_server.multi_model_server.server import (
+    LocalMultiModelServer,
+)
+
+CPU_TF_IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
+MODEL_PATH = "model_path"
+MODEL_REPO = f"{MODEL_PATH}/1"
+ENV_VAR = {"KEY": "VALUE"}
+PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32)
+DTYPE = "TYPE_FP32"
+SECRET_KEY = "secret_key"
+INFER_RESPONSE = {"outputs": [{"name": "output_name"}]}
+
 
 class MultiModelServerPrepareTests(TestCase):
+    def test_start_invoke_destroy_local_multi_model_server(self):
+        mock_container = Mock()
+        mock_docker_client = Mock()
+        mock_docker_client.containers.run.return_value = mock_container
+
+        local_multi_model_server = LocalMultiModelServer()
+        mock_schema_builder = Mock()
+        mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD
+        local_multi_model_server.schema_builder = mock_schema_builder
+
+        local_multi_model_server._start_serving(
+            client=mock_docker_client,
+            model_path=MODEL_PATH,
+            secret_key=SECRET_KEY,
+            env_vars=ENV_VAR,
+            image=CPU_TF_IMAGE,
+        )
+
+        mock_docker_client.containers.run.assert_called_once_with(
+            CPU_TF_IMAGE,
+            "serve",
+            detach=True,
+            auto_remove=True,
+            network_mode="host",
+            volumes={PosixPath("model_path"): {"bind": "/opt/ml/model", "mode": "rw"}},
+            environment={
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
+                "LOCAL_PYTHON": platform.python_version(),
+                "KEY": "VALUE",
+            },
+        )
+
     @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space")
     @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage")
     @patch("sagemaker.serve.model_server.multi_model_server.prepare.Path")
