
Commit aa4a62e

Author: Bryannah Hernandez (committed)
Message: feat: InferenceSpec support for MMS and testing
1 parent ce43606 · commit aa4a62e

File tree: 6 files changed, +312 −9 lines changed


src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
@@ -881,8 +881,8 @@ def _build_for_model_server(self):  # pylint: disable=R0911, R1710
         if self.model_metadata:
             mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
 
-        if not self.model and not mlflow_path:
-            raise ValueError("Missing required parameter `model` or 'ml_flow' path")
+        if not self.model and not mlflow_path and not self.inference_spec:
+            raise ValueError("Missing required parameter `model` or 'ml_flow' path or inf_spec")
 
         if self.model_server == ModelServer.TORCHSERVE:
             return self._build_for_torchserve()
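
With this relaxation, a ModelBuilder can be constructed from an InferenceSpec alone. A minimal sketch of the pattern the check now admits — UppercaseSpec and its toy callable model are illustrative, not part of this commit:

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec


class UppercaseSpec(InferenceSpec):
    """Toy spec whose 'model' is a plain callable."""

    def load(self, model_dir: str):
        # a real spec would deserialize artifacts from model_dir
        return lambda text: text.upper()

    def invoke(self, input_object, model):
        return model(input_object)


# Previously this raised "Missing required parameter `model` or 'ml_flow' path";
# with this change the inference_spec alone passes validation.
builder = ModelBuilder(
    inference_spec=UppercaseSpec(),
    schema_builder=SchemaBuilder(sample_input="hi", sample_output="HI"),
)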

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 75 additions & 4 deletions
@@ -13,10 +13,13 @@
 """Transformers build logic with model builder"""
 from __future__ import absolute_import
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Type
 from packaging.version import Version
 
+from pathlib import Path
+
 from sagemaker.model import Model
 from sagemaker import image_uris
 from sagemaker.serve.utils.local_hardware import (
@@ -26,7 +29,12 @@
 from sagemaker.huggingface import HuggingFaceModel
 from sagemaker.serve.model_server.multi_model_server.prepare import (
     _create_dir_structure,
+    prepare_for_mms,
+)
+from sagemaker.serve.detector.image_detector import (
+    auto_detect_container,
 )
+from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
@@ -72,6 +80,8 @@ def __init__(self):
         self.pytorch_version = None
         self.instance_type = None
         self.schema_builder = None
+        self.inference_spec = None
+        self.shared_libs = None
 
     @abstractmethod
     def _prepare_for_mode(self):
@@ -109,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
         """
 
         hf_model_md = get_huggingface_model_metadata(
-            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
        )
         hf_config = image_uris.config_for_framework("huggingface").get("inference")
         config = hf_config["versions"]
@@ -246,25 +256,31 @@ def _build_transformers_env(self):
 
         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})
+
+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})
 
             logger.info(self.env_vars)
 
             # TODO: Move to a helper function
             if hasattr(self.env_vars, "HF_API_TOKEN"):
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HF_API_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
                 )
             else:
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
                 )
 
         self.pysdk_model = self._create_transformers_model()
 
         if self.mode == Mode.LOCAL_CONTAINER:
             self._prepare_for_mode()
 
+        logger.info("Model configuration %s", self.pysdk_model)
+
         return self.pysdk_model
 
     def _set_instance(self, **kwargs):
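
The inference_spec branch above relies on the spec exposing get_model(); that method is not exercised by the handlers elsewhere in this diff, so read it as the contract an MMS-bound spec must now satisfy. An illustrative subclass — the model id and method bodies are placeholders:

from sagemaker.serve.spec.inference_spec import InferenceSpec


class BertSpec(InferenceSpec):
    def get_model(self):
        # exported by the builder as the HF_MODEL_ID env var
        return "bert-base-uncased"  # illustrative HuggingFace model id

    def load(self, model_dir: str):
        ...  # deserialize artifacts from model_dir

    def invoke(self, input_object, model):
        ...  # run inference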
@@ -293,6 +309,41 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
             versions_to_return.append(base_fw_version)
         return sorted(versions_to_return, reverse=True)[0]
 
+    def _auto_detect_container(self):
+        """Auto detect the container image uri."""
+        # prefer an explicit image_uri, then the model, then the inference spec
+        if self.image_uri:
+            logger.info(
+                "Skipping auto detection as the image uri is provided %s",
+                self.image_uri,
+            )
+            return
+
+        if self.model:
+            logger.info(
+                "Auto detect container url for the provided model and on instance %s",
+                self.instance_type,
+            )
+            self.image_uri = auto_detect_container(
+                self.model, self.sagemaker_session.boto_region_name, self.instance_type
+            )
+
+        elif self.inference_spec:
+            # TODO: this won't work for larger images.
+            # Fail and let the customer include the image uri
+            logger.warning(
+                "model_path provided with no image_uri. Attempting to autodetect the image"
+                " by loading the model using inference_spec.load()..."
+            )
+            self.image_uri = auto_detect_container(
+                self.inference_spec.load(self.model_path),
+                self.sagemaker_session.boto_region_name,
+                self.instance_type,
+            )
+        else:
+            raise ValueError("Cannot detect required model or inference spec")
+
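
For readers skimming the hunk, a condensed restatement of the precedence _auto_detect_container implements; the standalone function form is illustrative (the real method mutates self.image_uri), while auto_detect_container is the import added above:

from sagemaker.serve.detector.image_detector import auto_detect_container


def resolve_image_uri(image_uri, model, inference_spec, model_path, region, instance_type):
    if image_uri:
        return image_uri  # an explicit URI always wins; detection is skipped
    if model:
        return auto_detect_container(model, region, instance_type)
    if inference_spec:
        # last resort: materialize the model via load() just to inspect it,
        # which the TODO above flags as impractical for very large models
        return auto_detect_container(inference_spec.load(model_path), region, instance_type)
    raise ValueError("Cannot detect required model or inference spec")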
@@ -301,6 +352,26 @@ def _build_for_transformers(self):
         self.secret_key = None
         self.model_server = ModelServer.MMS
 
+        if not os.path.exists(self.model_path):
+            os.makedirs(self.model_path)
+
+        code_path = Path(self.model_path).joinpath("code")
+        # save the model or inference spec in cloud pickle format
+        if self.inference_spec:
+            save_pkl(code_path, (self.inference_spec, self.schema_builder))
+            logger.info("PKL file saved to file: {}".format(code_path))
+
+        self._auto_detect_container()
+
+        self.secret_key = prepare_for_mms(
+            model_path=self.model_path,
+            shared_libs=self.shared_libs,
+            dependencies=self.dependencies,
+            session=self.sagemaker_session,
+            image_uri=self.image_uri,
+            inference_spec=self.inference_spec,
+        )
+
         self._build_transformers_env()
 
         return self.pysdk_model
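
Taken together, once _build_for_transformers runs with an InferenceSpec the model directory should look roughly like this (the requirements.txt name assumes capture_dependencies writes one; treat that as an assumption rather than something this diff shows):

model_path/
├── code/
│   ├── serve.pkl        # cloudpickled (inference_spec, schema_builder) tuple
│   ├── inference.py     # handler copied in by prepare_for_mms
│   ├── metadata.json    # HMAC of serve.pkl keyed by the returned secret
│   └── requirements.txt # captured dependencies (assumed name)
└── shared_libs/         # copies of any shared_libs passed to the builder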
src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import cloudpickle
+import shutil
+import platform
+from pathlib import Path
+from functools import partial
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+import logging
+
+logger = logging.getLogger(__name__)
+
+inference_spec = None
+schema_builder = None
+SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
+SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
+METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
+
+
+def model_fn(model_dir):
+    """Deserialize the InferenceSpec and SchemaBuilder, then return a predict callable."""
+    shared_libs_path = Path(model_dir + "/shared_libs")
+
+    if shared_libs_path.exists():
+        # before importing, place dynamic linked libraries in shared lib path
+        shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)
+
+    serve_path = Path(__file__).parent.joinpath("serve.pkl")
+    with open(str(serve_path), mode="rb") as file:
+        global inference_spec, schema_builder
+        obj = cloudpickle.load(file)
+        if isinstance(obj[0], InferenceSpec):
+            inference_spec, schema_builder = obj
+
+    logger.info("in model_fn")
+
+    if inference_spec:
+        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
+
+
+def input_fn(input_data, content_type):
+    """Deserialize the request payload via the schema builder."""
+    try:
+        if hasattr(schema_builder, "custom_input_translator"):
+            return schema_builder.custom_input_translator.deserialize(
+                io.BytesIO(input_data), content_type
+            )
+        else:
+            return schema_builder.input_deserializer.deserialize(
+                io.BytesIO(input_data), content_type[0]
+            )
+    except Exception as e:
+        logger.error("Encountered error: %s in deserialize_request." % e)
+        raise Exception("Encountered error in deserialize_request.") from e
+
+
+def predict_fn(input_data, predict_callable):
+    """Invoke the callable returned by model_fn."""
+    logger.info("in predict_fn")
+    return predict_callable(input_data)
+
+
+def output_fn(predictions, accept_type):
+    """Serialize the predictions via the schema builder."""
+    try:
+        if hasattr(schema_builder, "custom_output_translator"):
+            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
+        else:
+            return schema_builder.output_serializer.serialize(predictions)
+    except Exception as e:
+        logger.error("Encountered error: %s in serialize_response." % e)
+        raise Exception("Encountered error in serialize_response.") from e
+
+
+def _run_preflight_diagnostics():
+    _py_vs_parity_check()
+    _pickle_file_integrity_check()
+
+
+def _py_vs_parity_check():
+    container_py_vs = platform.python_version()
+    local_py_vs = os.getenv("LOCAL_PYTHON")
+
+    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
+        logger.warning(
+            f"The local python version {local_py_vs} differs from the python version "
+            f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
+        )
+
+
+def _pickle_file_integrity_check():
+    with open(SERVE_PATH, "rb") as f:
+        buffer = f.read()
+
+    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)
+
+
+# on import, execute
+_run_preflight_diagnostics()
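
The custom_input_translator / custom_output_translator attributes probed by input_fn and output_fn come from the user's SchemaBuilder. A hedged sketch of the duck-typed interface the handler expects — the class name and JSON encoding are illustrative, not mandated by this commit:

import io
import json


class JsonTranslator:
    """Illustrative translator exposing the two hooks the handler probes for."""

    def deserialize(self, stream: io.BytesIO, content_type: str):
        # input_fn wraps the raw request body in BytesIO before calling this
        return json.loads(stream.read().decode("utf-8"))

    def serialize(self, predictions, accept_type: str):
        # output_fn passes the model output plus the Accept header value
        return json.dumps(predictions)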

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 64 additions & 2 deletions
@@ -14,12 +14,23 @@
 
 from __future__ import absolute_import
 import logging
-from pathlib import Path
-from typing import List
 
 from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
 from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
 
+from pathlib import Path
+import shutil
+from typing import List
+
+from sagemaker.session import Session
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+    generate_secret_key,
+    compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,3 +74,54 @@ def prepare_mms_js_resources(
     model_path, code_dir = _create_dir_structure(model_path)
 
     return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
+
+def prepare_for_mms(
+    model_path: str,
+    shared_libs: List[str],
+    dependencies: dict,
+    session: Session,
+    image_uri: str,
+    inference_spec: InferenceSpec = None,
+) -> str:
+    """Prepare the model artifacts and serving code for the multi-model server.
+    Args:
+        model_path (str) : Argument
+        shared_libs (List[]) : Argument
+        dependencies (dict) : Argument
+        session (Session) : Argument
+        image_uri (str) : Argument
+        inference_spec (InferenceSpec, optional) : Argument
+            (default is None)
+    Returns:
+        ( str ) : secret key used to sign serve.pkl for the integrity check
+    """
+    model_path = Path(model_path)
+    if not model_path.exists():
+        model_path.mkdir()
+    elif not model_path.is_dir():
+        raise Exception("model_dir is not a valid directory")
+
+    if inference_spec:
+        inference_spec.prepare(str(model_path))
+
+    code_dir = model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True)
+
+    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+    logger.info("Finished writing inference.py to code directory")
+
+    shared_libs_dir = model_path.joinpath("shared_libs")
+    shared_libs_dir.mkdir(exist_ok=True)
+    for shared_lib in shared_libs:
+        shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+    capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
+    secret_key = generate_secret_key()
+    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+        buffer = f.read()
+    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+        metadata.write(_MetaData(hash_value).to_json())
+
+    return secret_key
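
The returned secret key closes a loop with the new inference.py above: compute_hash signs serve.pkl at build time, the key reaches the container as SAGEMAKER_SERVE_SECRET_KEY (see server.py below), and perform_integrity_check re-verifies the pickle before it is loaded. A minimal standard-library sketch of that round trip — HMAC-SHA256 is an assumption about what the check_integrity helpers use internally:

import hashlib
import hmac
import secrets

secret_key = secrets.token_hex(32)  # stands in for generate_secret_key()
payload = b"serve.pkl bytes"        # stands in for the pickled tuple

# build time: the digest is persisted alongside the pickle in metadata.json
build_digest = hmac.new(secret_key.encode(), payload, hashlib.sha256).hexdigest()

# container start: recompute and compare before cloudpickle.load runs
restart_digest = hmac.new(secret_key.encode(), payload, hashlib.sha256).hexdigest()
assert hmac.compare_digest(build_digest, restart_digest)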

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 17 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import requests
 import logging
+import platform
 from pathlib import Path
 from sagemaker import Session, fw_utils
 from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -42,7 +43,13 @@ def _start_serving(
                     "mode": "rw",
                 },
             },
-            environment=_update_env_vars(env_vars),
+            environment={
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "LOCAL_PYTHON": platform.python_version(),
+                **env_vars,
+            },
         )
 
     def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -116,6 +123,15 @@ def _upload_server_artifacts(
                 "S3Uri": model_data_url + "/",
             }
         }
+
+        env_vars = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+            "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+
         return model_data, _update_env_vars(env_vars)
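
The LOCAL_PYTHON value injected here is what _py_vs_parity_check in the new inference.py reads on container start, and only the minor version component is compared. An illustrative check with hard-coded versions:

import platform

local_py = "3.10.6"                       # deploy-time value carried in LOCAL_PYTHON
container_py = platform.python_version()  # e.g. "3.10.13" inside the container

# only the minor component is compared, so 3.10.6 vs 3.10.13 passes silently
if local_py.split(".")[1] != container_py.split(".")[1]:
    print(f"warning: local {local_py} differs from container {container_py}")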
