feat: InferenceSpec support for MMS and testing #4763

Merged: 19 commits, Jul 10, 2024
Changes from 5 commits
76 changes: 72 additions & 4 deletions src/sagemaker/serve/builder/transformers_builder.py
@@ -13,10 +13,13 @@
"""Transformers build logic with model builder"""
from __future__ import absolute_import
import logging
import os
from abc import ABC, abstractmethod
from typing import Type
from packaging.version import Version

from pathlib import Path

from sagemaker.model import Model
from sagemaker import image_uris
from sagemaker.serve.utils.local_hardware import (
@@ -26,7 +29,12 @@
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serve.model_server.multi_model_server.prepare import (
_create_dir_structure,
prepare_for_mms,
)
from sagemaker.serve.detector.image_detector import (
auto_detect_container,
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
@@ -72,6 +80,8 @@ def __init__(self):
self.pytorch_version = None
self.instance_type = None
self.schema_builder = None
self.inference_spec = None
self.shared_libs = None

@abstractmethod
def _prepare_for_mode(self):
@@ -109,7 +119,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
"""

hf_model_md = get_huggingface_model_metadata(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
Review comment (grenmester, Contributor, Jul 2, 2024): What happens if model is set with an HF model ID but the environment variable doesn't exist? Does the environment variable get set based on self.model before this line is executed?

)
hf_config = image_uris.config_for_framework("huggingface").get("inference")
config = hf_config["versions"]
@@ -246,18 +256,22 @@ def _build_transformers_env(self):

_create_dir_structure(self.model_path)
if not hasattr(self, "pysdk_model"):
self.env_vars.update({"HF_MODEL_ID": self.model})

if self.inference_spec is not None:
self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
Review comment (Contributor): This means that the customer needs to implement get_model, right? Do we need to update InferenceSpec?
Reply (Contributor, author): Will make that update now, good catch.

else:
self.env_vars.update({"HF_MODEL_ID": self.model})

logger.info(self.env_vars)

# TODO: Move to a helper function
if hasattr(self.env_vars, "HF_API_TOKEN"):
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HF_API_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
)
else:
self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
)

self.pysdk_model = self._create_transformers_model()
@@ -293,6 +307,40 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
versions_to_return.append(base_fw_version)
return sorted(versions_to_return, reverse=True)[0]

def _auto_detect_container(self):
"""Placeholder docstring"""
Review comment (Contributor): replace with actual docstring
Reply (Contributor, author): Replaced, thank you

# Auto detect the container image uri
if self.image_uri:
logger.info(
"Skipping auto detection as the image uri is provided %s",
self.image_uri,
)
return

if self.model:
logger.info(
"Auto detect container url for the provided model and on instance %s",
self.instance_type,
)
self.image_uri = auto_detect_container(
self.model, self.sagemaker_session.boto_region_name, self.instance_type
)

elif self.inference_spec:
# TODO: this won't work for larger image.
# Fail and let the customer include the image uri
logger.warning(
"model_path provided with no image_uri. Attempting to autodetect the image "
"by loading the model using inference_spec.load()..."
)
self.image_uri = auto_detect_container(
self.inference_spec.load(self.model_path),
self.sagemaker_session.boto_region_name,
self.instance_type,
)
else:
raise ValueError("Cannot detect required model or inference spec")
Review comment (Contributor): Can we add more details on how to fix this error, like what parameter the customer needs to pass to fix it?
Reply (Contributor, author): Yes, it has just been updated. Thank you.


def _build_for_transformers(self):
"""Method that triggers model build

@@ -301,6 +349,26 @@
self.secret_key = None
self.model_server = ModelServer.MMS

if not os.path.exists(self.model_path):
os.makedirs(self.model_path)

code_path = Path(self.model_path).joinpath("code")
# save the model or inference spec in cloud pickle format
if self.inference_spec:
save_pkl(code_path, (self.inference_spec, self.schema_builder))
logger.info("PKL file saved to file: {}".format(code_path))

self._auto_detect_container()

self.secret_key = prepare_for_mms(
model_path=self.model_path,
shared_libs=self.shared_libs,
dependencies=self.dependencies,
session=self.sagemaker_session,
image_uri=self.image_uri,
inference_spec=self.inference_spec,
)

self._build_transformers_env()

return self.pysdk_model
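
For context on the get_model() thread above, here is a minimal sketch of what a user-defined InferenceSpec might look like once get_model() is part of the interface. MyHFSpec, the model ID, and the pipeline task are illustrative placeholders, and the exact abstract interface depends on the follow-up change the author mentions:

from sagemaker.serve.spec.inference_spec import InferenceSpec


class MyHFSpec(InferenceSpec):
    """Illustrative spec for a Hugging Face hub model."""

    def get_model(self):
        # _build_transformers_env copies this value into HF_MODEL_ID.
        return "distilbert-base-uncased"  # hypothetical model ID

    def load(self, model_dir: str):
        # Called by _auto_detect_container above, and by model_fn in the
        # generated inference.py, to materialize the model object.
        from transformers import pipeline
        return pipeline("text-classification", model=self.get_model())

    def invoke(self, input_object, model):
        # model_fn binds this via functools.partial(invoke, model=...).
        return model(input_object)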
103 changes: 103 additions & 0 deletions src/sagemaker/serve/model_server/multi_model_server/inference.py
@@ -0,0 +1,103 @@
"""This module is for SageMaker inference.py."""

from __future__ import absolute_import
import os
import io
import cloudpickle
import shutil
import platform
from pathlib import Path
from functools import partial
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.validations.check_integrity import perform_integrity_check
import logging

logger = logging.getLogger(__name__)

inference_spec = None
schema_builder = None
SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")


def model_fn(model_dir):
"""Placeholder docstring"""
shared_libs_path = Path(model_dir + "/shared_libs")

if shared_libs_path.exists():
# before importing, place dynamic linked libraries in shared lib path
shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True)

serve_path = Path(__file__).parent.joinpath("serve.pkl")
with open(str(serve_path), mode="rb") as file:
global inference_spec, schema_builder
obj = cloudpickle.load(file)
if isinstance(obj[0], InferenceSpec):
inference_spec, schema_builder = obj

logger.info("in model_fn")
Review comment (samruds, Collaborator, Jul 4, 2024): This log statement can be removed
Reply (Contributor, author): It has been removed


if inference_spec:
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))


def input_fn(input_data, content_type):
"""Placeholder docstring"""
try:
if hasattr(schema_builder, "custom_input_translator"):
return schema_builder.custom_input_translator.deserialize(
io.BytesIO(input_data), content_type
)
else:
return schema_builder.input_deserializer.deserialize(
io.BytesIO(input_data), content_type[0]
)
except Exception as e:
logger.error("Encountered error: %s in deserialize_response." % e)
raise Exception("Encountered error in deserialize_request.") from e


def predict_fn(input_data, predict_callable):
"""Placeholder docstring"""
Review comment (Collaborator): Can we update docstring here with what the method does?
Reply (Contributor, author): Updated, thank you

logger.info("in predict_fn")
Review comment (samruds, Collaborator, Jul 4, 2024): This log statement can be removed
Reply (Contributor, author): It's now removed

return predict_callable(input_data)


def output_fn(predictions, accept_type):
"""Placeholder docstring"""
Review comment (Collaborator): Nit: Update doc string
Reply (Contributor, author): It's now updated

try:
if hasattr(schema_builder, "custom_output_translator"):
return schema_builder.custom_output_translator.serialize(predictions, accept_type)
else:
return schema_builder.output_serializer.serialize(predictions)
except Exception as e:
logger.error("Encountered error: %s in serialize_response." % e)
raise Exception("Encountered error in serialize_response.") from e


def _run_preflight_diagnostics():
_py_vs_parity_check()
_pickle_file_integrity_check()


def _py_vs_parity_check():
container_py_vs = platform.python_version()
local_py_vs = os.getenv("LOCAL_PYTHON")

if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
logger.warning(
f"The local python version {local_py_vs} differs from the python version "
f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
)


def _pickle_file_integrity_check():
with open(SERVE_PATH, "rb") as f:
buffer = f.read()

perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)


# on import, execute
_run_preflight_diagnostics()
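
Taken together, the handlers above form the standard SageMaker handler chain. A rough sketch of how a single request would flow through them; the exact content-type shape the server hands to input_fn is simplified here:

def handle_request(model_dir, raw_bytes, content_type, accept):
    # model_fn unpickles serve.pkl and binds the model via functools.partial.
    predict_callable = model_fn(model_dir)
    # input_fn deserializes the payload with the unpickled schema_builder.
    data = input_fn(raw_bytes, content_type)
    # predict_fn simply calls through to inference_spec.invoke(...).
    predictions = predict_fn(data, predict_callable)
    # output_fn serializes the result for the response body.
    return output_fn(predictions, accept)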
67 changes: 65 additions & 2 deletions src/sagemaker/serve/model_server/multi_model_server/prepare.py
@@ -14,12 +14,23 @@

from __future__ import absolute_import
import logging
import shutil
from pathlib import Path
from typing import List

from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage

from sagemaker.session import Session
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData

logger = logging.getLogger(__name__)


@@ -63,3 +74,55 @@ def prepare_mms_js_resources(
model_path, code_dir = _create_dir_structure(model_path)

return _copy_jumpstart_artifacts(model_data, js_id, code_dir)


def prepare_for_mms(
model_path: str,
shared_libs: List[str],
dependencies: dict,
session: Session,
image_uri: str,
inference_spec: InferenceSpec = None,
) -> str:
"""Prepares for InferenceSpec using model_path, writes inference.py, and captures dependencies to generate secret_key.
Args:to
model_path (str) : Argument
shared_libs (List[]) : Argument
dependencies (dict) : Argument
session (Session) : Argument
inference_spec (InferenceSpec, optional) : Argument
(default is None)
Returns:
( str ) : secret_key
"""
model_path = Path(model_path)
if not model_path.exists():
Review comment (Contributor): nit: similar comment

model_path.mkdir()
elif not model_path.is_dir():
raise Exception("model_dir is not a valid directory")

if inference_spec:
inference_spec.prepare(str(model_path))

code_dir = model_path.joinpath("code")
code_dir.mkdir(exist_ok=True)

shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

logger.info("Finished writing inference.py to code directory")

shared_libs_dir = model_path.joinpath("shared_libs")
shared_libs_dir.mkdir(exist_ok=True)
for shared_lib in shared_libs:
shutil.copy2(Path(shared_lib), shared_libs_dir)

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
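
A hedged sketch of calling prepare_for_mms directly; the paths, the dependencies dict, and my_spec are placeholders, and in practice _build_for_transformers writes code/serve.pkl via save_pkl before this call so the hashing step has something to sign:

from sagemaker.session import Session
from sagemaker.serve.model_server.multi_model_server.prepare import prepare_for_mms

secret_key = prepare_for_mms(
    model_path="/tmp/my-model",      # created if it does not exist
    shared_libs=[],                  # optional native libraries to ship
    dependencies={"auto": False},    # forwarded to capture_dependencies
    session=Session(),
    image_uri="<serving-image-uri>",
    inference_spec=my_spec,          # e.g. the InferenceSpec sketch earlier
)

# Resulting layout, per the function body above:
# /tmp/my-model/
#     code/
#         inference.py     copied from this module
#         serve.pkl        written earlier by the builder
#         metadata.json    integrity hash of serve.pkl keyed by secret_key
#     shared_libs/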
18 changes: 17 additions & 1 deletion src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -4,6 +4,7 @@

import requests
import logging
import platform
from pathlib import Path
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -42,7 +43,13 @@ def _start_serving(
"mode": "rw",
},
},
environment=_update_env_vars(env_vars),
environment={
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
**env_vars,
},
)

def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
@@ -116,6 +123,15 @@ def _upload_server_artifacts(
"S3Uri": model_data_url + "/",
}
}

env_vars = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
"LOCAL_PYTHON": platform.python_version(),
}

return model_data, _update_env_vars(env_vars)
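
One detail worth noting in _start_serving above: because **env_vars is expanded last in the inline environment dict, user-supplied values override the serving defaults. A tiny sketch of the merge semantics, with illustrative values:

defaults = {"SAGEMAKER_PROGRAM": "inference.py"}
env_vars = {"SAGEMAKER_PROGRAM": "my_handler.py", "HF_TASK": "text-classification"}
merged = {**defaults, **env_vars}
assert merged["SAGEMAKER_PROGRAM"] == "my_handler.py"  # user-supplied value wins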

