Skip to content

Commit 81be29a

Browse files
author
Bryannah Hernandez
committed
djl inference spec
1 parent 97a6be3 commit 81be29a

File tree

5 files changed

+372
-14
lines changed

5 files changed

+372
-14
lines changed

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
"""Holds mixin logic to support deployment of Model ID"""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Type
18+
from pathlib import Path
1719
from abc import ABC, abstractmethod
1820
from datetime import datetime, timedelta
1921

@@ -46,7 +48,12 @@
4648
)
4749
from sagemaker.serve.model_server.djl_serving.prepare import (
4850
_create_dir_structure,
51+
prepare_for_djl,
4952
)
53+
from sagemaker.serve.detector.image_detector import (
54+
auto_detect_container,
55+
)
56+
from sagemaker.serve.detector.pickler import save_pkl
5057
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
5158
from sagemaker.serve.utils.types import ModelServer
5259
from sagemaker.serve.mode.function_pointers import Mode
@@ -92,6 +99,8 @@ def __init__(self):
9299
self.nb_instance_type = None
93100
self.ram_usage_model_load = None
94101
self.role_arn = None
102+
self.inference_spec = None
103+
self.shared_libs = None
95104

96105
@abstractmethod
97106
def _prepare_for_mode(self):
@@ -247,17 +256,22 @@ def _build_for_hf_djl(self):
247256

248257
_create_dir_structure(self.model_path)
249258
if not hasattr(self, "pysdk_model"):
250-
self.env_vars.update({"HF_MODEL_ID": self.model})
259+
if self.inference_spec is not None:
260+
self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
261+
else:
262+
self.env_vars.update({"HF_MODEL_ID": self.model})
263+
251264
self.hf_model_config = _get_model_config_properties_from_hf(
252-
self.model, self.env_vars.get("HF_TOKEN")
265+
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN")
253266
)
254267
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
255-
self.model, self.hf_model_config, self.schema_builder
268+
self.env_vars.get("HF_MODEL_ID"), self.hf_model_config, self.schema_builder
256269
)
257270
self.env_vars.update(default_djl_configurations)
258271
self.schema_builder.sample_input["parameters"][
259272
"max_new_tokens"
260273
] = _default_max_new_tokens
274+
261275
self.pysdk_model = self._create_djl_model()
262276

263277
if self.mode == Mode.LOCAL_CONTAINER:
@@ -445,10 +459,67 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
445459

446460
return self.pysdk_model
447461

462+
def _auto_detect_container(self):
    """Set ``self.image_uri`` by auto-detecting a container image.

    Detection is skipped when an image URI was already supplied. Otherwise
    the image is detected from ``self.model`` or, failing that, from the
    model object returned by ``self.inference_spec.load()``.

    Raises:
        ValueError: If neither a model nor an inference spec is available.
    """
    # Auto detect the container image uri
    if self.image_uri:
        logger.info(
            "Skipping auto detection as the image uri is provided %s",
            self.image_uri,
        )
        return

    if self.model:
        logger.info(
            "Auto detect container url for the provided model and on instance %s",
            self.instance_type,
        )
        self.image_uri = auto_detect_container(
            self.model, self.sagemaker_session.boto_region_name, self.instance_type
        )

    elif self.inference_spec:
        # TODO: this won't work for larger image.
        # Fail and let the customer include the image uri
        # The original message used a backslash continuation inside the
        # string literal, which embedded a long run of spaces in the log
        # output; implicit concatenation keeps the message clean.
        logger.warning(
            "model_path provided with no image_uri. Attempting to autodetect the "
            "image by loading the model using inference_spec.load()..."
        )
        self.image_uri = auto_detect_container(
            self.inference_spec.load(self.model_path),
            self.sagemaker_session.boto_region_name,
            self.instance_type,
        )
    else:
        raise ValueError(
            "Cannot detect and set image_uri. Please pass model or inference spec."
        )
497+
448498
def _build_for_djl(self):
449-
"""Placeholder docstring"""
499+
"""Checks if inference spec passed and builds DJL server accordingly"""
450500
self._validate_djl_serving_sample_data()
451501
self.secret_key = None
502+
self.model_server = ModelServer.DJL_SERVING
503+
504+
if self.inference_spec:
505+
506+
os.makedirs(self.model_path, exist_ok=True)
507+
508+
code_path = Path(self.model_path).joinpath("code")
509+
510+
save_pkl(code_path, (self.inference_spec, self.schema_builder))
511+
logger.info("PKL file saved to file: %s", code_path)
512+
513+
self._auto_detect_container()
514+
515+
self.secret_key = prepare_for_djl(
516+
model_path=self.model_path,
517+
shared_libs=self.shared_libs,
518+
dependencies=self.dependencies,
519+
session=self.sagemaker_session,
520+
image_uri=self.image_uri,
521+
inference_spec=self.inference_spec,
522+
)
452523

453524
self.pysdk_model = self._build_for_hf_djl()
454525
self.pysdk_model.tune = self._tune_for_hf_djl
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""This module is for SageMaker inference.py."""
2+
3+
from __future__ import absolute_import
4+
import io
5+
import yaml
6+
import logging
7+
8+
from pathlib import Path
9+
from djl_python import Input
10+
from djl_python import Output
11+
12+
13+
class DJLPythonInference(object):
    """DJL Python engine entry-point service.

    Loads a pickled ``InferenceSpec`` from the model directory — validated
    against an HMAC recorded in ``metadata.yaml`` — and routes DJL ``Input``
    requests through optional schema-builder translators.
    """

    def __init__(self) -> None:
        # InferenceSpec deserialized from the signed pickle in the model dir.
        self.inference_spec = None
        # Model directory supplied by the DJL runtime via initialize().
        self.model_dir = None
        # Model object returned by inference_spec.load().
        self.model = None
        # Optional schema builder providing custom (de)serializers.
        self.schema_builder = None
        self.metadata = None
        self.default_serializer = None
        self.default_deserializer = None
        self.initialized = False

    def load_yaml(self, path: str):
        """Load and parse a YAML file from *path*."""
        # NOTE(review): yaml.full_load can construct Python-tagged objects;
        # metadata.yaml is SDK-generated, but safe_load would be stricter if
        # the file could ever be user-supplied — confirm before hardening.
        with open(path, mode="r") as file:
            return yaml.full_load(file)

    def load_metadata(self):
        """Load ``metadata.yaml`` from the model directory."""
        metadata_path = Path(self.model_dir).joinpath("metadata.yaml")
        return self.load_yaml(metadata_path)

    def load_and_validate_pkl(self, path, hash_tag):
        """Unpickle *path* after verifying its HMAC against *hash_tag*.

        The key is read from the ``SAGEMAKER_SERVE_SECRET_KEY`` environment
        variable, which the builder sets when it signs the artifact.

        Raises:
            Exception: If the computed HMAC does not match *hash_tag*.
        """
        import os
        import hmac
        import hashlib
        import cloudpickle

        with open(path, mode="rb") as file:
            buffer = file.read()
            secret_key = os.getenv("SAGEMAKER_SERVE_SECRET_KEY")
            stored_hash_tag = hmac.new(
                secret_key.encode(), msg=buffer, digestmod=hashlib.sha256
            ).hexdigest()
            # compare_digest avoids timing side channels on the comparison.
            if not hmac.compare_digest(stored_hash_tag, hash_tag):
                raise Exception("Object is not valid: " + path)

        with open(path, mode="rb") as file:
            return cloudpickle.load(file)

    def load(self):
        """Detect an inference spec in the metadata and load the model."""
        self.metadata = self.load_metadata()
        if "InferenceSpec" in self.metadata:
            inference_spec_path = (
                Path(self.model_dir).joinpath(self.metadata.get("InferenceSpec")).absolute()
            )
            # NOTE(review): the builder pickles (inference_spec, schema_builder)
            # as a tuple into serve.pkl; confirm the artifact referenced by
            # "InferenceSpec" holds only the spec, otherwise self.inference_spec
            # receives a tuple and self.schema_builder is never restored.
            self.inference_spec = self.load_and_validate_pkl(
                inference_spec_path, self.metadata.get("InferenceSpecHMAC")
            )

        # Load model
        if self.inference_spec:
            self.model = self.inference_spec.load(self.model_dir)
        else:
            raise Exception(
                "SageMaker model format does not support model type: "
                + self.metadata.get("ModelType")
            )

    def initialize(self, properties):
        """Initialize SageMaker service, loading model and inference spec."""
        self.model_dir = properties.get("model_dir")
        self.load()
        self.initialized = True
        logging.info("SageMaker saved format entry-point is applied, service is initialized")

    def preprocess_djl(self, inputs: Input):
        """Deserialize the request payload via the configured deserializer.

        Raises:
            Exception: If deserialization fails or no deserializer is set.
        """
        content_type = inputs.get_property("content-type")
        # Lazy %-style args: no f-string formatting cost when INFO is disabled.
        logging.info("Content-type is: %s", content_type)
        if self.schema_builder:
            logging.info("Customized input deserializer is applied")
            try:
                if hasattr(self.schema_builder, "custom_input_translator"):
                    return self.schema_builder.custom_input_translator.deserialize(
                        io.BytesIO(inputs.get_as_bytes()), content_type
                    )
                else:
                    raise Exception("No custom input translator in customized schema builder.")
            except Exception as e:
                raise Exception("Encountered error in deserialize_request.") from e
        elif self.default_deserializer:
            return self.default_deserializer.deserialize(
                io.BytesIO(inputs.get_as_bytes()), content_type
            )
        # Previously fell through and returned None, silently invoking the
        # model with no input; fail loudly instead.
        raise Exception("No input deserializer is configured to handle the request.")

    def postprocess_djl(self, output):
        """Serialize the model output via the configured serializer.

        Raises:
            Exception: If serialization fails or no serializer is set.
        """
        if self.schema_builder:
            logging.info("Customized output serializer is applied")
            try:
                if hasattr(self.schema_builder, "custom_output_translator"):
                    return self.schema_builder.custom_output_translator.serialize(output)
                else:
                    raise Exception("No custom output translator in customized schema builder.")
            except Exception as e:
                raise Exception("Encountered error in serialize_response.") from e
        elif self.default_serializer:
            return self.default_serializer.serialize(output)
        raise Exception("No output serializer is configured to handle the response.")

    # Backward-compatible alias for the original (misspelled) method name.
    postproces_djl = postprocess_djl

    def inference(self, inputs: Input):
        """Run one request: deserialize, invoke the inference spec, serialize."""
        processed_input = self.preprocess_djl(inputs=inputs)
        if self.inference_spec:
            output = self.inference_spec.invoke(processed_input, self.model)
        else:
            raise Exception(
                "SageMaker model format does not support model type: "
                + self.metadata.get("ModelType")
            )
        processed_output = self.postprocess_djl(output=output)
        output_data = Output()
        return output_data.add(processed_output)
131+
132+
133+
_service = DJLPythonInference()


def handle(inputs: Input) -> Output:
    """DJL entry point: lazily initialize the service, then run inference."""
    if not _service.initialized:
        _service.initialize(inputs.get_properties())

    # An empty request is the engine's warm-up/initialization ping; there is
    # nothing to infer, so return None.
    if inputs.is_empty():
        return None

    return _service.inference(inputs)

src/sagemaker/serve/model_server/djl_serving/prepare.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@
1616
import json
1717
import tarfile
1818
import logging
19+
import shutil
1920
from typing import List
2021
from pathlib import Path
2122

2223
from sagemaker.utils import _tmpdir, custom_extractall_tarfile
2324
from sagemaker.s3 import S3Downloader
2425
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
26+
from sagemaker.session import Session
27+
from sagemaker.serve.spec.inference_spec import InferenceSpec
28+
from sagemaker.serve.detector.dependency_manager import capture_dependencies
29+
from sagemaker.serve.validations.check_integrity import (
30+
generate_secret_key,
31+
compute_hash,
32+
)
33+
from sagemaker.remote_function.core.serialization import _MetaData
2534

2635
_SETTING_PROPERTY_STMT = "Setting property: %s to %s"
2736

@@ -109,3 +118,56 @@ def prepare_djl_js_resources(
109118
model_path, code_dir = _create_dir_structure(model_path)
110119

111120
return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
121+
122+
123+
def prepare_for_djl(
    model_path: str,
    shared_libs: List[str],
    dependencies: dict,
    session: Session,
    image_uri: str,
    inference_spec: InferenceSpec = None,
) -> str:
    """Stage a DJL model directory for an InferenceSpec deployment.

    Runs ``inference_spec.prepare`` on the model directory, copies this
    server's ``inference.py`` entry point and any shared libraries into it,
    captures dependencies, and signs the pickled spec (``code/serve.pkl``)
    with a freshly generated secret key, recording the HMAC in
    ``code/metadata.json``.

    Args:
        model_path (str): Root directory for the model artifacts; created if missing.
        shared_libs (List[str]): Paths of shared libraries to copy into ``shared_libs/``.
        dependencies (dict): Dependency configuration passed to ``capture_dependencies``.
        session (Session): SageMaker session (unused here; kept for parity with
            the other ``prepare_*`` helpers).
        image_uri (str): Container image URI (unused here).
        inference_spec (InferenceSpec, optional): Spec whose ``prepare`` hook is
            invoked on the model directory. Defaults to None.

    Returns:
        str: The generated secret key used to HMAC-sign ``serve.pkl``.

    Raises:
        Exception: If ``model_path`` exists but is not a directory.
    """
    model_path = Path(model_path)
    if not model_path.exists():
        # parents=True so a nested path does not fail on a missing ancestor.
        model_path.mkdir(parents=True)
    elif not model_path.is_dir():
        raise Exception("model_dir is not a valid directory")

    if inference_spec:
        inference_spec.prepare(str(model_path))

    code_dir = model_path.joinpath("code")
    code_dir.mkdir(exist_ok=True)

    # Ship the DJL serving entry point alongside the model code.
    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

    logger.info("Finished writing inference.py to code directory")

    shared_libs_dir = model_path.joinpath("shared_libs")
    shared_libs_dir.mkdir(exist_ok=True)
    for shared_lib in shared_libs:
        shutil.copy2(Path(shared_lib), shared_libs_dir)

    capture_dependencies(dependencies=dependencies, work_dir=code_dir)

    # serve.pkl is written by the builder before this call; sign it so the
    # container can verify artifact integrity at load time.
    secret_key = generate_secret_key()
    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()
    hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())

    return secret_key

0 commit comments

Comments
 (0)