
Commit 6789b61

bryannahm1 and Bryannah Hernandez authored
feat: InferenceSpec support for MMS and testing (#4763)
* feat: InferenceSpec support for MMS and testing
* Fix formatting
* CR Fixes for InferenceSpec MMS
* remove code
* Changes to environment, avoid duplicates
* Remove loggers and add docstring updates
* Changes for unit tests in transformers build
* formatting changes
* Add secret_key to endpoint mode
* get_model, docstring, and if changes
* pre-push fixes
* integ test edits
* formatting fixes
* format changes
* updated value error
* formatting changes for value error update

Co-authored-by: Bryannah Hernandez <[email protected]>
1 parent: b7621dc. Commit: 6789b61.

7 files changed: +321 additions, -7 deletions

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 73 additions & 4 deletions
@@ -13,8 +13,10 @@
 """Transformers build logic with model builder"""
 from __future__ import absolute_import
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Type
+from pathlib import Path
 from packaging.version import Version

 from sagemaker.model import Model
@@ -26,7 +28,12 @@
 from sagemaker.huggingface import HuggingFaceModel
 from sagemaker.serve.model_server.multi_model_server.prepare import (
     _create_dir_structure,
+    prepare_for_mms,
 )
+from sagemaker.serve.detector.image_detector import (
+    auto_detect_container,
+)
+from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.optimize_utils import _is_optimized
 from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
@@ -73,6 +80,8 @@ def __init__(self):
         self.pytorch_version = None
         self.instance_type = None
         self.schema_builder = None
+        self.inference_spec = None
+        self.shared_libs = None

     @abstractmethod
     def _prepare_for_mode(self):
@@ -110,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
         """

         hf_model_md = get_huggingface_model_metadata(
-            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
         )
         hf_config = image_uris.config_for_framework("huggingface").get("inference")
         config = hf_config["versions"]
@@ -245,18 +254,22 @@ def _build_transformers_env(self):

         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})
+
+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})

         logger.info(self.env_vars)

         # TODO: Move to a helper function
         if hasattr(self.env_vars, "HF_API_TOKEN"):
             self.hf_model_config = _get_model_config_properties_from_hf(
-                self.model, self.env_vars.get("HF_API_TOKEN")
+                self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
             )
         else:
             self.hf_model_config = _get_model_config_properties_from_hf(
-                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+                self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
             )

         self.pysdk_model = self._create_transformers_model()
@@ -292,6 +305,42 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
                 versions_to_return.append(base_fw_version)
         return sorted(versions_to_return, reverse=True)[0]

+    def _auto_detect_container(self):
+        """Set image_uri by detecting container via model name or inference spec"""
+        # Auto detect the container image uri
+        if self.image_uri:
+            logger.info(
+                "Skipping auto detection as the image uri is provided %s",
+                self.image_uri,
+            )
+            return
+
+        if self.model:
+            logger.info(
+                "Auto detect container url for the provided model and on instance %s",
+                self.instance_type,
+            )
+            self.image_uri = auto_detect_container(
+                self.model, self.sagemaker_session.boto_region_name, self.instance_type
+            )
+
+        elif self.inference_spec:
+            # TODO: this won't work for larger image.
+            # Fail and let the customer include the image uri
+            logger.warning(
+                "model_path provided with no image_uri. Attempting to autodetect the image "
+                "by loading the model using inference_spec.load()..."
+            )
+            self.image_uri = auto_detect_container(
+                self.inference_spec.load(self.model_path),
+                self.sagemaker_session.boto_region_name,
+                self.instance_type,
+            )
+        else:
+            raise ValueError(
+                "Cannot detect and set image_uri. Please pass model or inference spec."
+            )
+
     def _build_for_transformers(self):
         """Method that triggers model build
@@ -300,6 +349,26 @@
         self.secret_key = None
         self.model_server = ModelServer.MMS

+        if self.inference_spec:
+
+            os.makedirs(self.model_path, exist_ok=True)
+
+            code_path = Path(self.model_path).joinpath("code")
+
+            save_pkl(code_path, (self.inference_spec, self.schema_builder))
+            logger.info("PKL file saved to file: %s", code_path)
+
+            self._auto_detect_container()
+
+            self.secret_key = prepare_for_mms(
+                model_path=self.model_path,
+                shared_libs=self.shared_libs,
+                dependencies=self.dependencies,
+                session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                inference_spec=self.inference_spec,
+            )
+
         self._build_transformers_env()

         if self.role_arn:
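
Taken together, the builder changes let an InferenceSpec drive the MMS path end to end: get_model() supplies HF_MODEL_ID, the container image is auto-detected from the loaded model, and the spec plus schema builder are pickled into code/serve.pkl. A minimal sketch of how a caller might exercise this (not part of the commit; it assumes the public ModelBuilder/SchemaBuilder interfaces, and the pipeline and model id are illustrative):

from transformers import pipeline

from sagemaker.serve import ModelBuilder, SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer


class MyInferenceSpec(InferenceSpec):
    """Hypothetical spec: loads a HF pipeline and reports its model id."""

    def load(self, model_dir: str):
        # Called inside the container by inference.py's model_fn
        return pipeline("text-classification", model="distilbert-base-uncased")

    def invoke(self, input_object: object, model: object):
        return model(input_object)

    def get_model(self):
        # Feeds HF_MODEL_ID in _build_transformers_env instead of self.model
        return "distilbert-base-uncased"


# The spec and schema builder are pickled together into code/serve.pkl
builder = ModelBuilder(
    inference_spec=MyInferenceSpec(),
    schema_builder=SchemaBuilder(
        sample_input="stub input",
        sample_output=[{"label": "POSITIVE", "score": 0.99}],
    ),
    model_server=ModelServer.MMS,
)
model = builder.build()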

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ def prepare(
             model_path=model_path,
             sagemaker_session=sagemaker_session,
             s3_model_data_url=s3_model_data_url,
+            secret_key=secret_key,
             image=image,
             should_upload_artifacts=should_upload_artifacts,
         )

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import cloudpickle
+import shutil
+import platform
+from pathlib import Path
+from functools import partial
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+import logging
+
+logger = logging.getLogger(__name__)
+
+inference_spec = None
+schema_builder = None
+SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
+SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
+METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
+
+
+def model_fn(model_dir):
+    """Overrides default method for loading a model"""
+    shared_libs_path = Path(model_dir + "/shared_libs")
+
+    if shared_libs_path.exists():
+        # before importing, place dynamic linked libraries in shared lib path
+        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)
+
+    serve_path = Path(__file__).parent.joinpath("serve.pkl")
+    with open(str(serve_path), mode="rb") as file:
+        global inference_spec, schema_builder
+        obj = cloudpickle.load(file)
+        if isinstance(obj[0], InferenceSpec):
+            inference_spec, schema_builder = obj
+
+    if inference_spec:
+        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
+
+
+def input_fn(input_data, content_type):
+    """Deserializes the bytes that were received from the model server"""
+    try:
+        if hasattr(schema_builder, "custom_input_translator"):
+            return schema_builder.custom_input_translator.deserialize(
+                io.BytesIO(input_data), content_type
+            )
+        else:
+            return schema_builder.input_deserializer.deserialize(
+                io.BytesIO(input_data), content_type[0]
+            )
+    except Exception as e:
+        logger.error("Encountered error: %s in deserialize_response." % e)
+        raise Exception("Encountered error in deserialize_request.") from e
+
+
+def predict_fn(input_data, predict_callable):
+    """Invokes the model that is taken in by model server"""
+    return predict_callable(input_data)
+
+
+def output_fn(predictions, accept_type):
+    """Prediction is serialized to bytes and sent back to the customer"""
+    try:
+        if hasattr(schema_builder, "custom_output_translator"):
+            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
+        else:
+            return schema_builder.output_serializer.serialize(predictions)
+    except Exception as e:
+        logger.error("Encountered error: %s in serialize_response." % e)
+        raise Exception("Encountered error in serialize_response.") from e
+
+
+def _run_preflight_diagnostics():
+    _py_vs_parity_check()
+    _pickle_file_integrity_check()
+
+
+def _py_vs_parity_check():
+    container_py_vs = platform.python_version()
+    local_py_vs = os.getenv("LOCAL_PYTHON")
+
+    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
+        logger.warning(
+            f"The local python version {local_py_vs} differs from the python version "
+            f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
+        )
+
+
+def _pickle_file_integrity_check():
+    with open(SERVE_PATH, "rb") as f:
+        buffer = f.read()
+
+    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)
+
+
+# on import, execute
+_run_preflight_diagnostics()
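
For orientation, the four handlers compose in the standard SageMaker handler-chain order: model_fn loads once at startup, then each request flows input_fn -> predict_fn -> output_fn. A local simulation of one request cycle (a sketch, not part of the commit; the model directory, payload, and content types are placeholders):

def _simulate_request(raw_bytes: bytes, content_type: str, accept_type: str):
    # model_fn unpickles serve.pkl and returns partial(inference_spec.invoke, model=...)
    predict_callable = model_fn("/opt/ml/model")
    data = input_fn(raw_bytes, content_type)          # bytes -> deserialized input
    predictions = predict_fn(data, predict_callable)  # runs inference_spec.invoke
    return output_fn(predictions, accept_type)        # predictions -> serialized bytes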

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 66 additions & 2 deletions
@@ -14,12 +14,23 @@

 from __future__ import absolute_import
 import logging
-from pathlib import Path
-from typing import List

 from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
 from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage

+from pathlib import Path
+import shutil
+from typing import List
+
+from sagemaker.session import Session
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+    generate_secret_key,
+    compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData
+
 logger = logging.getLogger(__name__)

@@ -63,3 +74,56 @@ def prepare_mms_js_resources(
     model_path, code_dir = _create_dir_structure(model_path)

     return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
+
+
+def prepare_for_mms(
+    model_path: str,
+    shared_libs: List[str],
+    dependencies: dict,
+    session: Session,
+    image_uri: str,
+    inference_spec: InferenceSpec = None,
+) -> str:
+    """Prepares for InferenceSpec using model_path, writes inference.py,
+    and captures dependencies to generate secret_key.
+
+    Args:
+        model_path (str) : Argument
+        shared_libs (List[str]) : Argument
+        dependencies (dict) : Argument
+        session (Session) : Argument
+        image_uri (str) : Argument
+        inference_spec (InferenceSpec, optional) : Argument
+            (default is None)
+    Returns:
+        ( str ) : secret_key
+    """
+    model_path = Path(model_path)
+    if not model_path.exists():
+        model_path.mkdir()
+    elif not model_path.is_dir():
+        raise Exception("model_dir is not a valid directory")
+
+    if inference_spec:
+        inference_spec.prepare(str(model_path))
+
+    code_dir = model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True)
+
+    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+    logger.info("Finished writing inference.py to code directory")
+
+    shared_libs_dir = model_path.joinpath("shared_libs")
+    shared_libs_dir.mkdir(exist_ok=True)
+    for shared_lib in shared_libs:
+        shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+    capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
+    secret_key = generate_secret_key()
+    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+        buffer = f.read()
+    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+        metadata.write(_MetaData(hash_value).to_json())
+
+    return secret_key
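
The secret_key returned here pairs with the container-side preflight in inference.py: the build host signs code/serve.pkl with an HMAC and writes the digest to metadata.json, and perform_integrity_check re-verifies the pickle before it is ever loaded, presumably reading the key from SAGEMAKER_SERVE_SECRET_KEY (the variable both server paths set below). A condensed sketch of the handshake, using only helpers that appear in this diff (the paths are illustrative):

from sagemaker.serve.validations.check_integrity import (
    compute_hash,
    generate_secret_key,
    perform_integrity_check,
)
from sagemaker.remote_function.core.serialization import _MetaData

# Build host: sign the pickle and persist the digest next to it.
secret_key = generate_secret_key()
with open("model/code/serve.pkl", "rb") as f:
    buffer = f.read()
with open("model/code/metadata.json", "wb") as metadata:
    metadata.write(_MetaData(compute_hash(buffer=buffer, secret_key=secret_key)).to_json())

# Container (inference.py preflight): re-read and verify before unpickling.
with open("model/code/serve.pkl", "rb") as f:
    perform_integrity_check(buffer=f.read(), metadata_path="model/code/metadata.json")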

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 24 additions & 1 deletion
@@ -4,6 +4,7 @@

 import requests
 import logging
+import platform
 from pathlib import Path
 from sagemaker import Session, fw_utils
 from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -31,6 +32,17 @@ def _start_serving(
         env_vars: dict,
     ):
         """Placeholder docstring"""
+        env = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+        if env_vars:
+            env_vars.update(env)
+        else:
+            env_vars = env
+
         self.container = client.containers.run(
             image,
             "serve",
@@ -43,7 +55,7 @@
                     "mode": "rw",
                 },
             },
-            environment=_update_env_vars(env_vars),
+            environment=env_vars,
         )

     def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -81,6 +93,7 @@ class SageMakerMultiModelServer:
     def _upload_server_artifacts(
         self,
         model_path: str,
+        secret_key: str,
         sagemaker_session: Session,
         s3_model_data_url: str = None,
         image: str = None,
@@ -127,6 +140,16 @@
             else None
         )

+        if secret_key:
+            env_vars = {
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+                "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+                "LOCAL_PYTHON": platform.python_version(),
+            }
+
         return model_data, _update_env_vars(env_vars)

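
The net effect is that both local-container and endpoint modes point the DLC at the pickled-spec handler: SAGEMAKER_PROGRAM selects inference.py, SAGEMAKER_SUBMIT_DIRECTORY locates it, and SAGEMAKER_SERVE_SECRET_KEY plus LOCAL_PYTHON feed its preflight checks. A hedged deploy sketch continuing the builder example above (assuming the existing sagemaker.serve Mode API and that local-mode deploy takes no required arguments):

from sagemaker.serve.mode.function_pointers import Mode

# Local container mode: _start_serving injects the env block shown above.
local_model = builder.build(mode=Mode.LOCAL_CONTAINER)
predictor = local_model.deploy()
print(predictor.predict("stub input"))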

src/sagemaker/serve/spec/inference_spec.py

Lines changed: 3 additions & 0 deletions
@@ -30,3 +30,6 @@ def invoke(self, input_object: object, model: object):

     def prepare(self, *args, **kwargs):
         """Custom prepare function"""
+
+    def get_model(self):
+        """Return HuggingFace model name for inference spec"""
