Skip to content

Commit d3b8e9b

Browse files
author
Bryannah Hernandez
committed
mb_inprocess updates
1 parent 3576ea9 commit d3b8e9b

File tree

3 files changed

+164
-4
lines changed

3 files changed

+164
-4
lines changed

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
"""Transformers build logic with model builder"""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from abc import ABC, abstractmethod
1718
from typing import Type
1819
from packaging.version import Version
1920

21+
from pathlib import Path
22+
2023
from sagemaker.model import Model
2124
from sagemaker import image_uris
2225
from sagemaker.serve.utils.local_hardware import (
@@ -26,7 +29,12 @@
2629
from sagemaker.huggingface import HuggingFaceModel
2730
from sagemaker.serve.model_server.multi_model_server.prepare import (
2831
_create_dir_structure,
32+
prepare_for_mms
33+
)
34+
from sagemaker.serve.detector.image_detector import (
35+
auto_detect_container,
2936
)
37+
from sagemaker.serve.detector.pickler import save_pkl
3038
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
3139
from sagemaker.serve.utils.types import ModelServer
3240
from sagemaker.serve.mode.function_pointers import Mode
@@ -73,6 +81,7 @@ def __init__(self):
7381
self.instance_type = None
7482
self.schema_builder = None
7583
self.inference_spec = None
84+
self.shared_libs = None
7685

7786
@abstractmethod
7887
def _prepare_for_mode(self):
@@ -298,6 +307,40 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
298307
versions_to_return.append(base_fw_version)
299308
return sorted(versions_to_return, reverse=True)[0]
300309

310+
def _auto_detect_container(self):
    """Resolve ``self.image_uri``, auto-detecting a container image when absent.

    No-op if ``self.image_uri`` is already set. Otherwise detects the image
    from ``self.model``, or — as a fallback — from the model produced by
    ``self.inference_spec.load(self.model_path)``.

    Raises:
        ValueError: If neither a model nor an inference spec is available.
    """
    if self.image_uri:
        logger.info(
            "Skipping auto detection as the image uri is provided %s",
            self.image_uri,
        )
        return

    if self.model:
        logger.info(
            "Auto detect container url for the provided model and on instance %s",
            self.instance_type,
        )
        self.image_uri = auto_detect_container(
            self.model, self.sagemaker_session.boto_region_name, self.instance_type
        )

    elif self.inference_spec:
        # TODO: this won't work for larger image.
        # Fail and let the customer include the image uri
        # NOTE: the original message used a backslash continuation inside the
        # string literal, which leaked source indentation into the log output.
        logger.warning(
            "model_path provided with no image_uri. Attempting to autodetect the image "
            "by loading the model using inference_spec.load()..."
        )
        self.image_uri = auto_detect_container(
            self.inference_spec.load(self.model_path),
            self.sagemaker_session.boto_region_name,
            self.instance_type,
        )
    else:
        raise ValueError("Cannot detect required model or inference spec")
343+
301344
def _build_for_transformers(self):
302345
"""Method that triggers model build
303346
@@ -306,6 +349,26 @@ def _build_for_transformers(self):
306349
self.secret_key = None
307350
self.model_server = ModelServer.MMS
308351

352+
if not os.path.exists(self.model_path):
353+
os.makedirs(self.model_path)
354+
355+
code_path = Path(self.model_path).joinpath("code")
356+
# save the model or inference spec in cloud pickle format
357+
if self.inference_spec:
358+
save_pkl(code_path, (self.inference_spec, self.schema_builder))
359+
logger.info("PKL file saved to file: {}".format(code_path))
360+
361+
self._auto_detect_container()
362+
363+
self.secret_key = prepare_for_mms(
364+
model_path=self.model_path,
365+
shared_libs=self.shared_libs,
366+
dependencies=self.dependencies,
367+
session=self.sagemaker_session,
368+
image_uri=self.image_uri,
369+
inference_spec=self.inference_spec,
370+
)
371+
309372
self._build_transformers_env()
310373

311374
return self.pysdk_model

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
"""This module is for SageMaker inference.py."""
22

33
from __future__ import absolute_import
4+
import os
45
import io
56
import cloudpickle
67
import shutil
8+
import platform
79
from pathlib import Path
810
from functools import partial
911
from sagemaker.serve.spec.inference_spec import InferenceSpec
12+
from sagemaker.serve.validations.check_integrity import perform_integrity_check
1013
import logging
1114

1215
logger = logging.getLogger(__name__)
1316

1417
inference_spec = None
1518
schema_builder = None
19+
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
20+
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
21+
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
1622

1723

1824
def model_fn(model_dir):
@@ -27,9 +33,11 @@ def model_fn(model_dir):
2733
with open(str(serve_path), mode="rb") as file:
2834
global inference_spec, schema_builder
2935
obj = cloudpickle.load(file)
30-
if isinstance(obj[0], InferenceSpec):
36+
if isinstance(obj[0], InferenceSpec):
3137
inference_spec, schema_builder = obj
32-
38+
39+
logger.info("in model_fn")
40+
3341
if inference_spec:
3442
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
3543

@@ -46,11 +54,13 @@ def input_fn(input_data, content_type):
4654
io.BytesIO(input_data), content_type[0]
4755
)
4856
except Exception as e:
57+
logger.error("Encountered error: %s in deserialize_response." % e)
4958
raise Exception("Encountered error in deserialize_request.") from e
5059

5160

5261
def predict_fn(input_data, predict_callable):
    """Run inference by applying the model callable to the deserialized input.

    Args:
        input_data: Deserialized request payload (the output of ``input_fn``).
        predict_callable: Callable returned by ``model_fn``; wraps the loaded
            model (e.g. a partial of ``inference_spec.invoke``).

    Returns:
        The raw prediction result, to be serialized by ``output_fn``.
    """
    logger.info("in predict_fn")
    return predict_callable(input_data)
5565

5666

@@ -66,3 +76,28 @@ def output_fn(predictions, accept_type):
6676
raise Exception("Encountered error in serialize_response.") from e
6777

6878

79+
def _run_preflight_diagnostics():
    """Run startup sanity checks before the server handles any traffic.

    Checks python-version parity between the client and the container, then
    verifies the integrity of the pickled inference artifact.
    """
    _py_vs_parity_check()
    _pickle_file_integrity_check()
82+
83+
84+
def _py_vs_parity_check():
85+
container_py_vs = platform.python_version()
86+
local_py_vs = os.getenv("LOCAL_PYTHON")
87+
88+
if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
89+
logger.warning(
90+
f"The local python version {local_py_vs} differs from the python version "
91+
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
92+
)
93+
94+
95+
def _pickle_file_integrity_check():
    """Verify ``serve.pkl`` against the hash recorded in ``metadata.json``."""
    payload = SERVE_PATH.read_bytes()
    perform_integrity_check(buffer=payload, metadata_path=METADATA_PATH)
100+
101+
102+
# on import, execute
# Deliberate import-time side effect: fail fast on a bad environment
# (version skew, tampered pickle) before the model server accepts requests.
_run_preflight_diagnostics()

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,23 @@
1414

1515
from __future__ import absolute_import
1616
import logging
17-
from pathlib import Path
18-
from typing import List
1917

2018
from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
2119
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
2220

21+
from pathlib import Path
22+
import shutil
23+
from typing import List
24+
25+
from sagemaker.session import Session
26+
from sagemaker.serve.spec.inference_spec import InferenceSpec
27+
from sagemaker.serve.detector.dependency_manager import capture_dependencies
28+
from sagemaker.serve.validations.check_integrity import (
29+
generate_secret_key,
30+
compute_hash,
31+
)
32+
from sagemaker.remote_function.core.serialization import _MetaData
33+
2334
logger = logging.getLogger(__name__)
2435

2536

@@ -63,3 +74,54 @@ def prepare_mms_js_resources(
6374
model_path, code_dir = _create_dir_structure(model_path)
6475

6576
return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
77+
78+
def prepare_for_mms(
    model_path: str,
    shared_libs: List[str],
    dependencies: dict,
    session: Session,
    image_uri: str,
    inference_spec: InferenceSpec = None,
) -> str:
    """Prepare a model artifact directory for the multi-model server.

    Creates the ``code`` and ``shared_libs`` sub-directories under
    ``model_path``, copies the server's ``inference.py`` and the shared
    libraries into place, captures dependencies, then signs the pickled
    ``serve.pkl`` (which the caller must have written to ``code/`` already)
    with a fresh secret key, recording the hash in ``metadata.json``.

    Args:
        model_path (str): Root directory for model artifacts; created if absent.
        shared_libs (List[str]): Paths of shared libraries to copy into
            ``shared_libs/``.
        dependencies (dict): Dependency configuration forwarded to
            ``capture_dependencies``.
        session (Session): SageMaker session. NOTE(review): not referenced in
            this function body — kept for interface parity; confirm before use.
        image_uri (str): Container image uri. NOTE(review): also unused here.
        inference_spec (InferenceSpec, optional): When provided, its
            ``prepare`` hook is invoked on ``model_path``. Defaults to None.

    Returns:
        str: The secret key used to sign ``serve.pkl``.

    Raises:
        ValueError: If ``model_path`` exists but is not a directory.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        model_path.mkdir()
    elif not model_path.is_dir():
        # More specific than the original bare Exception; still caught by
        # any caller handling Exception.
        raise ValueError("model_dir is not a valid directory")

    if inference_spec:
        inference_spec.prepare(str(model_path))

    code_dir = model_path.joinpath("code")
    code_dir.mkdir(exist_ok=True)

    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

    logger.info("Finished writing inference.py to code directory")

    shared_libs_dir = model_path.joinpath("shared_libs")
    shared_libs_dir.mkdir(exist_ok=True)
    for shared_lib in shared_libs:
        shutil.copy2(Path(shared_lib), shared_libs_dir)

    capture_dependencies(dependencies=dependencies, work_dir=code_dir)

    # serve.pkl must already exist (saved by the model builder before this
    # call); sign it so the container can verify integrity at load time.
    secret_key = generate_secret_key()
    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()
    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())

    return secret_key

0 commit comments

Comments
 (0)