
feat: InferenceSpec support for DJL #4846

Draft · wants to merge 2 commits into base: master
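For context, a minimal sketch of the user-facing flow this change enables, assuming the existing ModelBuilder API. MyInferenceSpec, the model id, and the sample payloads below are illustrative and not part of the diff; exact import paths may differ across SDK versions.

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils.types import ModelServer


class MyInferenceSpec(InferenceSpec):
    """Custom spec that this PR pickles into serve.pkl."""

    def get_model(self):
        # _build_for_hf_djl uses this value to populate HF_MODEL_ID
        return "TheBloke/Llama-2-7b-fp16"  # placeholder model id

    def load(self, model_dir: str):
        from transformers import pipeline

        return pipeline("text-generation", model=self.get_model())

    def invoke(self, input_object, model):
        return model(input_object)


sample_input = {"inputs": "Hello, I am a language model,", "parameters": {"max_new_tokens": 128}}
sample_output = [{"generated_text": "Hello, I am a language model, and ..."}]

builder = ModelBuilder(
    inference_spec=MyInferenceSpec(),
    schema_builder=SchemaBuilder(sample_input=sample_input, sample_output=sample_output),
    model_server=ModelServer.DJL_SERVING,
)
model = builder.build()  # dispatches to _build_for_djl below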
79 changes: 75 additions & 4 deletions src/sagemaker/serve/builder/djl_builder.py
@@ -13,7 +13,9 @@
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
import os
from typing import Type
from pathlib import Path
from abc import ABC, abstractmethod
from datetime import datetime, timedelta

@@ -46,7 +48,12 @@
)
from sagemaker.serve.model_server.djl_serving.prepare import (
_create_dir_structure,
prepare_for_djl,
)
from sagemaker.serve.detector.image_detector import (
auto_detect_container,
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
@@ -92,6 +99,8 @@ def __init__(self):
self.nb_instance_type = None
self.ram_usage_model_load = None
self.role_arn = None
self.inference_spec = None
self.shared_libs = None

@abstractmethod
def _prepare_for_mode(self):
@@ -247,17 +256,22 @@ def _build_for_hf_djl(self):

_create_dir_structure(self.model_path)
if not hasattr(self, "pysdk_model"):
self.env_vars.update({"HF_MODEL_ID": self.model})
if self.inference_spec is not None:
self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
else:
self.env_vars.update({"HF_MODEL_ID": self.model})

self.hf_model_config = _get_model_config_properties_from_hf(
self.model, self.env_vars.get("HF_TOKEN")
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN")
)
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
self.model, self.hf_model_config, self.schema_builder
self.env_vars.get("HF_MODEL_ID"), self.hf_model_config, self.schema_builder
)
self.env_vars.update(default_djl_configurations)
self.schema_builder.sample_input["parameters"][
"max_new_tokens"
] = _default_max_new_tokens

self.pysdk_model = self._create_djl_model()

if self.mode == Mode.LOCAL_CONTAINER:
@@ -445,10 +459,67 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):

return self.pysdk_model

def _auto_detect_container(self):
"""Set image_uri by detecting container via model name or inference spec"""
# Auto detect the container image uri
if self.image_uri:
logger.info(
"Skipping auto detection as the image uri is provided %s",
self.image_uri,
)
return

if self.model:
logger.info(
"Auto detect container url for the provided model and on instance %s",
self.nb_instance_type,
)
self.image_uri = auto_detect_container(
self.model, self.sagemaker_session.boto_region_name, self.nb_instance_type
)

elif self.inference_spec:
            # TODO: this won't work for larger models.
            # Fail and let the customer include the image_uri.
            logger.warning(
                "model_path provided with no image_uri. Attempting to auto-detect the image "
                "by loading the model via inference_spec.load()..."
            )
self.image_uri = auto_detect_container(
self.inference_spec.load(self.model_path),
self.sagemaker_session.boto_region_name,
self.nb_instance_type,
)
else:
raise ValueError(
"Cannot detect and set image_uri. Please pass model or inference spec."
)

def _build_for_djl(self):
"""Placeholder docstring"""
"""Checks if inference spec passed and builds DJL server accordingly"""
self._validate_djl_serving_sample_data()
self.secret_key = None
self.model_server = ModelServer.DJL_SERVING

if self.inference_spec:

os.makedirs(self.model_path, exist_ok=True)

code_path = Path(self.model_path).joinpath("code")

save_pkl(code_path, (self.inference_spec, self.schema_builder))
logger.info("PKL file saved to file: %s", code_path)

self._auto_detect_container()

self.secret_key = prepare_for_djl(
model_path=self.model_path,
shared_libs=self.shared_libs,
dependencies=self.dependencies,
session=self.sagemaker_session,
image_uri=self.image_uri,
inference_spec=self.inference_spec,
)

self.pysdk_model = self._build_for_hf_djl()
self.pysdk_model.tune = self._tune_for_hf_djl
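Taken together with prepare_for_djl (further down), the inference_spec branch above produces an artifact layout roughly like this; the requirements file name is assumed from capture_dependencies:

model_path/
├── code/
│   ├── serve.pkl          # pickled (inference_spec, schema_builder), signed with secret_key
│   ├── inference.py       # DJL entry point copied from model_server/djl_serving/
│   ├── metadata.json      # _MetaData(hash_value) written by prepare_for_djl
│   └── requirements.txt   # captured dependencies (assumed file name)
└── shared_libs/           # copies of any shared libraries passed in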
1 change: 1 addition & 0 deletions src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -95,6 +95,7 @@ def prepare(
upload_artifacts = self._upload_djl_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
s3_model_data_url=s3_model_data_url,
image=image,
should_upload_artifacts=True,
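Note: this one-line change threads the secret_key generated by prepare_for_djl through the DJL artifact upload, presumably so it can be surfaced as the SAGEMAKER_SERVE_SECRET_KEY environment variable that the new inference.py (below) reads when validating serve.pkl.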
146 changes: 146 additions & 0 deletions src/sagemaker/serve/model_server/djl_serving/inference.py
@@ -0,0 +1,146 @@
"""This module is for SageMaker inference.py."""

from __future__ import absolute_import
import io
import yaml
import logging

from pathlib import Path
from djl_python import Input
from djl_python import Output


class DJLPythonInference(object):
"""A class for DJL inference"""

def __init__(self) -> None:
self.inference_spec = None
self.model_dir = None
self.model = None
self.schema_builder = None
self.metadata = None
self.default_serializer = None
self.default_deserializer = None
self.initialized = False

def load_yaml(self, path: str):
"""Placeholder docstring"""
with open(path, mode="r") as file:
return yaml.full_load(file)

def load_metadata(self):
"""Placeholder docstring"""
metadata_path = Path(self.model_dir).joinpath("metadata.yaml")
return self.load_yaml(metadata_path)

def load_and_validate_pkl(self, path, hash_tag):
"""Placeholder docstring"""

import os
import hmac
import hashlib
import cloudpickle

with open(path, mode="rb") as file:
buffer = file.read()
secret_key = os.getenv("SAGEMAKER_SERVE_SECRET_KEY")
stored_hash_tag = hmac.new(
secret_key.encode(), msg=buffer, digestmod=hashlib.sha256
).hexdigest()
if not hmac.compare_digest(stored_hash_tag, hash_tag):
raise Exception("Object is not valid: " + path)

with open(path, mode="rb") as file:
return cloudpickle.load(file)

def load(self):
"""Detecting for inference spec and loading model"""
self.metadata = self.load_metadata()
if "InferenceSpec" in self.metadata:
inference_spec_path = (
Path(self.model_dir).joinpath(self.metadata.get("InferenceSpec")).absolute()
)
self.inference_spec = self.load_and_validate_pkl(
inference_spec_path, self.metadata.get("InferenceSpecHMAC")
)

# Load model
if self.inference_spec:
self.model = self.inference_spec.load(self.model_dir)
else:
raise Exception(
"SageMaker model format does not support model type: "
+ self.metadata.get("ModelType")
)

def initialize(self, properties):
"""Initialize SageMaker service, loading model and inferenceSpec"""
self.model_dir = properties.get("model_dir")
self.load()
self.initialized = True
logging.info("SageMaker saved format entry-point is applied, service is initilized")

def preprocess_djl(self, inputs: Input):
"""Placeholder docstring"""
content_type = inputs.get_property("content-type")
logging.info(f"Content-type is: {content_type}")
if self.schema_builder:
logging.info("Customized input deserializer is applied")
try:
if hasattr(self.schema_builder, "custom_input_translator"):
return self.schema_builder.custom_input_translator.deserialize(
io.BytesIO(inputs.get_as_bytes()), content_type
)
else:
raise Exception("No custom input translator in cutomized schema builder.")
except Exception as e:
raise Exception("Encountered error in deserialize_request.") from e
elif self.default_deserializer:
return self.default_deserializer.deserialize(
io.BytesIO(inputs.get_as_bytes()), content_type
)

    def postprocess_djl(self, output):
        """Serializes the response using the schema builder or the default serializer"""
if self.schema_builder:
logging.info("Customized output serializer is applied")
try:
if hasattr(self.schema_builder, "custom_output_translator"):
return self.schema_builder.custom_output_translator.serialize(output)
else:
raise Exception("No custom output translator in cutomized schema builder.")
except Exception as e:
raise Exception("Encountered error in serialize_response.") from e
elif self.default_serializer:
return self.default_serializer.serialize(output)

def inference(self, inputs: Input):
"""Detects if inference spec used, returns output accordingly"""
processed_input = self.preprocess_djl(inputs=inputs)
if self.inference_spec:
output = self.inference_spec.invoke(processed_input, self.model)
else:
raise Exception(
"SageMaker model format does not support model type: "
+ self.metadata.get("ModelType")
)
        processed_output = self.postprocess_djl(output=output)
output_data = Output()
return output_data.add(processed_output)


_service = DJLPythonInference()


def handle(inputs: Input) -> Output:
"""Placeholder docstring"""
if not _service.initialized:
properties = inputs.get_properties()
_service.initialize(properties)

if inputs.is_empty():
# initialization request
return None

return _service.inference(inputs)
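The HMAC check in load_and_validate_pkl above mirrors what prepare_for_djl (below) computes at build time. A self-contained sketch of that round-trip using only the standard library; sign and verify are illustrative names, not SDK APIs:

import hashlib
import hmac
import secrets


def sign(buffer: bytes, secret_key: str) -> str:
    # Build time: the tag computed over the serve.pkl bytes and stored in metadata
    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()


def verify(buffer: bytes, secret_key: str, expected_tag: str) -> bool:
    # Serve time: recompute and compare in constant time before unpickling
    actual = hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
    return hmac.compare_digest(actual, expected_tag)


key = secrets.token_hex(32)            # stand-in for generate_secret_key()
payload = b"pickled inference spec"    # stand-in for the serve.pkl bytes
tag = sign(payload, key)
assert verify(payload, key, tag)
assert not verify(payload + b"tampered", key, tag)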
62 changes: 62 additions & 0 deletions src/sagemaker/serve/model_server/djl_serving/prepare.py
@@ -16,12 +16,21 @@
import json
import tarfile
import logging
import shutil
from typing import List
from pathlib import Path

from sagemaker.utils import _tmpdir, custom_extractall_tarfile
from sagemaker.s3 import S3Downloader
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
from sagemaker.session import Session
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
generate_secret_key,
compute_hash,
)
from sagemaker.remote_function.core.serialization import _MetaData

_SETTING_PROPERTY_STMT = "Setting property: %s to %s"

@@ -109,3 +118,56 @@ def prepare_djl_js_resources(
model_path, code_dir = _create_dir_structure(model_path)

return _copy_jumpstart_artifacts(model_data, js_id, code_dir)


def prepare_for_djl(
model_path: str,
shared_libs: List[str],
dependencies: dict,
session: Session,
image_uri: str,
inference_spec: InferenceSpec = None,
) -> str:
"""Prepares for InferenceSpec using model_path, writes inference.py, and captures dependencies to generate secret_key.

Args:to
model_path (str) : Argument
shared_libs (List[]) : Argument
dependencies (dict) : Argument
session (Session) : Argument
inference_spec (InferenceSpec, optional) : Argument
(default is None)
Returns:
( str ) : secret_key
"""
model_path = Path(model_path)
if not model_path.exists():
model_path.mkdir()
elif not model_path.is_dir():
raise Exception("model_dir is not a valid directory")

if inference_spec:
inference_spec.prepare(str(model_path))

code_dir = model_path.joinpath("code")
code_dir.mkdir(exist_ok=True)

shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

logger.info("Finished writing inference.py to code directory")

shared_libs_dir = model_path.joinpath("shared_libs")
shared_libs_dir.mkdir(exist_ok=True)
for shared_lib in shared_libs:
shutil.copy2(Path(shared_lib), shared_libs_dir)

capture_dependencies(dependencies=dependencies, work_dir=code_dir)

secret_key = generate_secret_key()
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
buffer = f.read()
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
metadata.write(_MetaData(hash_value).to_json())

return secret_key
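An illustrative direct call to prepare_for_djl (ModelBuilder normally calls it from _build_for_djl); the paths and dependency dict are placeholders, and MyInferenceSpec is the hypothetical subclass sketched near the top:

from sagemaker.session import Session
from sagemaker.serve.model_server.djl_serving.prepare import prepare_for_djl

# Assumes model_path/code/serve.pkl was already written via save_pkl,
# since prepare_for_djl signs that file to produce the secret key.
secret_key = prepare_for_djl(
    model_path="/tmp/my_model",
    shared_libs=[],                 # no shared libraries in this example
    dependencies={"auto": False},   # assumed shape accepted by capture_dependencies
    session=Session(),
    image_uri="<djl-image-uri>",    # placeholder; resolved by _auto_detect_container in the normal flow
    inference_spec=MyInferenceSpec(),
)
print("secret key:", secret_key)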