feat: InferenceSpec support for MMS and testing #4763
File: src/sagemaker/serve/builder/transformers_builder.py
@@ -13,10 +13,13 @@
 """Transformers build logic with model builder"""
 from __future__ import absolute_import
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Type
 from packaging.version import Version

+from pathlib import Path
+
 from sagemaker.model import Model
 from sagemaker import image_uris
 from sagemaker.serve.utils.local_hardware import (
@@ -26,7 +29,12 @@
 from sagemaker.huggingface import HuggingFaceModel
 from sagemaker.serve.model_server.multi_model_server.prepare import (
     _create_dir_structure,
+    prepare_for_mms,
 )
+from sagemaker.serve.detector.image_detector import (
+    auto_detect_container,
+)
+from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
@@ -72,6 +80,8 @@ def __init__(self):
         self.pytorch_version = None
         self.instance_type = None
         self.schema_builder = None
+        self.inference_spec = None
+        self.shared_libs = None

     @abstractmethod
     def _prepare_for_mode(self):
@@ -109,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
         """

         hf_model_md = get_huggingface_model_metadata(
-            self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+            self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
         )
         hf_config = image_uris.config_for_framework("huggingface").get("inference")
         config = hf_config["versions"]
@@ -246,18 +256,22 @@ def _build_transformers_env(self):

         _create_dir_structure(self.model_path)
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})

+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
Reviewer: This means that the customer needs to implement
Author: Will make that update now, good catch
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})
+
+            logger.info(self.env_vars)

             # TODO: Move to a helper function
             if hasattr(self.env_vars, "HF_API_TOKEN"):
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HF_API_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
                 )
             else:
                 self.hf_model_config = _get_model_config_properties_from_hf(
-                    self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+                    self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
                 )

             self.pysdk_model = self._create_transformers_model()
@@ -293,6 +307,40 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
                 versions_to_return.append(base_fw_version)
         return sorted(versions_to_return, reverse=True)[0]

+    def _auto_detect_container(self):
+        """Placeholder docstring"""
Reviewer: replace with actual docstring
Author: Replaced, thank you
+        # Auto detect the container image uri
+        if self.image_uri:
+            logger.info(
+                "Skipping auto detection as the image uri is provided %s",
+                self.image_uri,
+            )
+            return
+
+        if self.model:
+            logger.info(
+                "Auto detect container url for the provided model and on instance %s",
+                self.instance_type,
+            )
+            self.image_uri = auto_detect_container(
+                self.model, self.sagemaker_session.boto_region_name, self.instance_type
+            )
+
+        elif self.inference_spec:
+            # TODO: this won't work for larger image.
+            # Fail and let the customer include the image uri
+            logger.warning(
+                "model_path provided with no image_uri. Attempting to autodetect the image\
+                    by loading the model using inference_spec.load()..."
+            )
+            self.image_uri = auto_detect_container(
+                self.inference_spec.load(self.model_path),
+                self.sagemaker_session.boto_region_name,
+                self.instance_type,
+            )
+        else:
+            raise ValueError("Cannot detect required model or inference spec")
Reviewer: Can we add more details on how to fix this error, such as what parameter the customer needs to pass?
Author: Yes, it has just been updated. Thank you.
+

     def _build_for_transformers(self):
         """Method that triggers model build
@@ -301,6 +349,26 @@ def _build_for_transformers(self):
         self.secret_key = None
         self.model_server = ModelServer.MMS

+        if not os.path.exists(self.model_path):
+            os.makedirs(self.model_path)
+
+        code_path = Path(self.model_path).joinpath("code")
+        # save the model or inference spec in cloud pickle format
+        if self.inference_spec:
+            save_pkl(code_path, (self.inference_spec, self.schema_builder))
+            logger.info("PKL file saved to file: {}".format(code_path))
+
+        self._auto_detect_container()
+
+        self.secret_key = prepare_for_mms(
+            model_path=self.model_path,
+            shared_libs=self.shared_libs,
+            dependencies=self.dependencies,
+            session=self.sagemaker_session,
+            image_uri=self.image_uri,
+            inference_spec=self.inference_spec,
+        )
+
         self._build_transformers_env()

         return self.pysdk_model
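For orientation, here is a minimal sketch of how a customer might exercise this new MMS path end to end. Only `get_model()`, `load()`, `invoke()`, `ModelServer.MMS`, and the InferenceSpec/SchemaBuilder pairing come from calls made in this PR; the `ModelBuilder` wiring, the example model ID, and the sample payloads are illustrative assumptions, not part of the diff.

```python
from transformers import pipeline

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer


class MyInferenceSpec(InferenceSpec):
    def get_model(self) -> str:
        # Read by _build_transformers_env() to populate HF_MODEL_ID.
        return "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative model ID

    def load(self, model_dir: str):
        # Called by _auto_detect_container() and by model_fn in inference.py.
        return pipeline("text-classification", model=self.get_model())

    def invoke(self, input_object, model):
        # Bound to the request path via functools.partial in model_fn.
        return model(input_object)


builder = ModelBuilder(
    model_server=ModelServer.MMS,
    inference_spec=MyInferenceSpec(),
    schema_builder=SchemaBuilder(
        sample_input="sagemaker is great",
        sample_output=[{"label": "POSITIVE", "score": 0.99}],
    ),
)
model = builder.build()
```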
File: src/sagemaker/serve/model_server/multi_model_server/inference.py (new file)

@@ -0,0 +1,103 @@
"""This module is for SageMaker inference.py.""" | ||
|
||
from __future__ import absolute_import | ||
import os | ||
import io | ||
import cloudpickle | ||
import shutil | ||
import platform | ||
from pathlib import Path | ||
from functools import partial | ||
from sagemaker.serve.spec.inference_spec import InferenceSpec | ||
from sagemaker.serve.validations.check_integrity import perform_integrity_check | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
inference_spec = None | ||
schema_builder = None | ||
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs") | ||
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl") | ||
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json") | ||
|
||
|
||
def model_fn(model_dir): | ||
"""Placeholder docstring""" | ||
shared_libs_path = Path(model_dir + "/shared_libs") | ||
|
||
if shared_libs_path.exists(): | ||
# before importing, place dynamic linked libraries in shared lib path | ||
shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True) | ||
|
||
serve_path = Path(__file__).parent.joinpath("serve.pkl") | ||
with open(str(serve_path), mode="rb") as file: | ||
global inference_spec, schema_builder | ||
obj = cloudpickle.load(file) | ||
if isinstance(obj[0], InferenceSpec): | ||
inference_spec, schema_builder = obj | ||
|
||
logger.info("in model_fn") | ||
Reviewer: This log statement can be removed
Author: It has been removed
    if inference_spec:
        return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


def input_fn(input_data, content_type):
    """Placeholder docstring"""
    try:
        if hasattr(schema_builder, "custom_input_translator"):
            return schema_builder.custom_input_translator.deserialize(
                io.BytesIO(input_data), content_type
            )
        else:
            return schema_builder.input_deserializer.deserialize(
                io.BytesIO(input_data), content_type[0]
            )
    except Exception as e:
        logger.error("Encountered error: %s in deserialize_response." % e)
        raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
    """Placeholder docstring"""
Reviewer: Can we update the docstring here with what the method does?
Author: Updated, thank you
    logger.info("in predict_fn")
Reviewer: This log statement can be removed
Author: It's now removed
    return predict_callable(input_data)


def output_fn(predictions, accept_type):
    """Placeholder docstring"""
Reviewer: Nit: Update docstring
Author: It's now updated
    try:
        if hasattr(schema_builder, "custom_output_translator"):
            return schema_builder.custom_output_translator.serialize(predictions, accept_type)
        else:
            return schema_builder.output_serializer.serialize(predictions)
    except Exception as e:
        logger.error("Encountered error: %s in serialize_response." % e)
        raise Exception("Encountered error in serialize_response.") from e


def _run_preflight_diagnostics():
    _py_vs_parity_check()
    _pickle_file_integrity_check()


def _py_vs_parity_check():
    container_py_vs = platform.python_version()
    local_py_vs = os.getenv("LOCAL_PYTHON")

    if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
        logger.warning(
            f"The local python version {local_py_vs} differs from the python version "
            f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
        )


def _pickle_file_integrity_check():
    with open(SERVE_PATH, "rb") as f:
        buffer = f.read()

    perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)


# on import, execute
_run_preflight_diagnostics()
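The handlers above duck-type against the unpickled `schema_builder`: `input_fn` and `output_fn` probe for `custom_input_translator` / `custom_output_translator` with `hasattr` and otherwise fall back to the default (de)serializers. Below is a minimal sketch of an object that would satisfy those probes; the class name is hypothetical, and how such a translator gets attached to the SchemaBuilder is not shown in this diff — only the method names and call shapes come from the handlers above.

```python
import io
import json


class JsonPayloadTranslator:
    """Hypothetical translator shaped to fit the duck-typed calls in inference.py."""

    def deserialize(self, stream: io.BytesIO, content_type: str):
        # input_fn calls: custom_input_translator.deserialize(io.BytesIO(input_data), content_type)
        return json.loads(stream.read().decode("utf-8"))

    def serialize(self, payload, content_type: str = "application/json") -> str:
        # output_fn calls: custom_output_translator.serialize(predictions, accept_type)
        return json.dumps(payload)
```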
File: src/sagemaker/serve/model_server/multi_model_server/prepare.py
@@ -14,12 +14,23 @@

 from __future__ import absolute_import
 import logging
-from pathlib import Path
-from typing import List

 from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
 from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage

+from pathlib import Path
+import shutil
+from typing import List
+
+from sagemaker.session import Session
+from sagemaker.serve.spec.inference_spec import InferenceSpec
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+    generate_secret_key,
+    compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData

 logger = logging.getLogger(__name__)
@@ -63,3 +74,55 @@ def prepare_mms_js_resources(
     model_path, code_dir = _create_dir_structure(model_path)

     return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
+
+
+def prepare_for_mms(
+    model_path: str,
+    shared_libs: List[str],
+    dependencies: dict,
+    session: Session,
+    image_uri: str,
+    inference_spec: InferenceSpec = None,
+) -> str:
+    """Prepares for InferenceSpec using model_path, writes inference.py,
+    and captures dependencies to generate secret_key.
+
+    Args:
+        model_path (str) : Argument
+        shared_libs (List[]) : Argument
+        dependencies (dict) : Argument
+        session (Session) : Argument
+        inference_spec (InferenceSpec, optional) : Argument
+            (default is None)
+    Returns:
+        ( str ) : secret_key
+    """
+    model_path = Path(model_path)
+    if not model_path.exists():
Reviewer: nit: similar comment
+        model_path.mkdir()
+    elif not model_path.is_dir():
+        raise Exception("model_dir is not a valid directory")
+
+    if inference_spec:
+        inference_spec.prepare(str(model_path))
+
+    code_dir = model_path.joinpath("code")
+    code_dir.mkdir(exist_ok=True)
+
+    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+    logger.info("Finished writing inference.py to code directory")
+
+    shared_libs_dir = model_path.joinpath("shared_libs")
+    shared_libs_dir.mkdir(exist_ok=True)
+    for shared_lib in shared_libs:
+        shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+    capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
+    secret_key = generate_secret_key()
+    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+        buffer = f.read()
+    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+        metadata.write(_MetaData(hash_value).to_json())
+
+    return secret_key
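To make the `secret_key` round trip concrete: the signing step at the end of `prepare_for_mms` pairs with `_pickle_file_integrity_check` in the new inference.py. A sketch under stated assumptions — the directory layout is hypothetical, and the claim that `perform_integrity_check` resolves the key from the container environment is an assumption rather than something this diff shows:

```python
from pathlib import Path

from sagemaker.remote_function.core.serialization import _MetaData
from sagemaker.serve.validations.check_integrity import (
    compute_hash,
    generate_secret_key,
    perform_integrity_check,
)

code_dir = Path("/tmp/model/code")  # hypothetical output of prepare_for_mms

# Signing side, as done at the end of prepare_for_mms above.
secret_key = generate_secret_key()
buffer = code_dir.joinpath("serve.pkl").read_bytes()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
with open(code_dir.joinpath("metadata.json"), "wb") as metadata:
    metadata.write(_MetaData(hash_value).to_json())

# Verification side, as done by _pickle_file_integrity_check on container start.
# Assumption: the secret key reaches the container via environment variables,
# so perform_integrity_check can recompute the hash and compare it.
perform_integrity_check(
    buffer=code_dir.joinpath("serve.pkl").read_bytes(),
    metadata_path=code_dir.joinpath("metadata.json"),
)
```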
Reviewer: What happens if `model` is set with an HF model ID but the environment variable doesn't exist? Does the environment variable get set based on `self.model` before this line is executed?

Author: Yes, the HF model ID gets set before the line is executed. This is where it is called:
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/serve/builder/transformers_builder.py#L248-L261
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/serve/builder/transformers_builder.py#L263
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/serve/builder/transformers_builder.py#L83
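In other words, a condensed paraphrase of the linked builder code (not a verbatim excerpt):

```python
# HF_MODEL_ID is written into env_vars inside _build_transformers_env(), before
# any of the metadata lookups read it back out.
def _resolve_hf_model_id(inference_spec, model):
    if inference_spec is not None:
        return inference_spec.get_model()  # customer-implemented hook
    return model  # plain HF model ID passed to the builder
```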