Skip to content

Commit fb28458

Browse files
Author: Bryannah Hernandez
Committed: feat: InferenceSpec support for MMS and testing
1 parent b25295a commit fb28458

File tree

4 files changed

+14
-28
lines changed

4 files changed

+14
-28
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,8 +881,8 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710
881881
if self.model_metadata:
882882
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
883883

884-
if not self.model and not mlflow_path:
885-
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
884+
if not self.model and not mlflow_path and not self.inference_spec:
885+
raise ValueError("Missing required parameter `model` or 'ml_flow' path or inf_spec")
886886

887887
if self.model_server == ModelServer.TORCHSERVE:
888888
return self._build_for_torchserve()

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(self):
7272
self.pytorch_version = None
7373
self.instance_type = None
7474
self.schema_builder = None
75+
self.inference_spec = None
7576

7677
@abstractmethod
7778
def _prepare_for_mode(self):
@@ -109,7 +110,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
109110
"""
110111

111112
hf_model_md = get_huggingface_model_metadata(
112-
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
113+
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
113114
)
114115
hf_config = image_uris.config_for_framework("huggingface").get("inference")
115116
config = hf_config["versions"]
@@ -246,18 +247,22 @@ def _build_transformers_env(self):
246247

247248
_create_dir_structure(self.model_path)
248249
if not hasattr(self, "pysdk_model"):
249-
self.env_vars.update({"HF_MODEL_ID": self.model})
250+
251+
if self.inference_spec is not None:
252+
self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
253+
else:
254+
self.env_vars.update({"HF_MODEL_ID": self.model})
250255

251256
logger.info(self.env_vars)
252257

253258
# TODO: Move to a helper function
254259
if hasattr(self.env_vars, "HF_API_TOKEN"):
255260
self.hf_model_config = _get_model_config_properties_from_hf(
256-
self.model, self.env_vars.get("HF_API_TOKEN")
261+
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN")
257262
)
258263
else:
259264
self.hf_model_config = _get_model_config_properties_from_hf(
260-
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
265+
self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
261266
)
262267

263268
self.pysdk_model = self._create_transformers_model()

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
11
"""This module is for SageMaker inference.py."""
22

33
from __future__ import absolute_import
4-
import os
54
import io
65
import cloudpickle
76
import shutil
8-
import platform
9-
import importlib
107
from pathlib import Path
118
from functools import partial
12-
from sagemaker.serve.validations.check_integrity import perform_integrity_check
139
from sagemaker.serve.spec.inference_spec import InferenceSpec
14-
from sagemaker.serve.detector.image_detector import _detect_framework_and_version, _get_model_base
15-
from sagemaker.serve.detector.pickler import load_xgboost_from_json
1610
import logging
1711

1812
logger = logging.getLogger(__name__)
1913

2014
inference_spec = None
21-
native_model = None
2215
schema_builder = None
2316

2417

@@ -32,20 +25,12 @@ def model_fn(model_dir):
3225

3326
serve_path = Path(__file__).parent.joinpath("serve.pkl")
3427
with open(str(serve_path), mode="rb") as file:
35-
global inference_spec, native_model, schema_builder
28+
global inference_spec, schema_builder
3629
obj = cloudpickle.load(file)
3730
if isinstance(obj[0], InferenceSpec):
3831
inference_spec, schema_builder = obj
39-
else:
40-
native_model, schema_builder = obj
41-
if native_model:
42-
framework, _ = _detect_framework_and_version(
43-
model_base=str(_get_model_base(model=native_model))
44-
)
45-
if framework == "pytorch":
46-
native_model.eval()
47-
return native_model if callable(native_model) else native_model.predict
48-
elif inference_spec:
32+
33+
if inference_spec:
4934
return partial(inference_spec.invoke, model=inference_spec.load(model_dir))
5035

5136

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def _start_serving(
4343
"mode": "rw",
4444
},
4545
},
46-
4746
environment={
4847
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
4948
"SAGEMAKER_PROGRAM": "inference.py",
@@ -88,12 +87,10 @@ class SageMakerMultiModelServer:
8887
def _upload_server_artifacts(
8988
self,
9089
model_path: str,
91-
secret_key: str,
9290
sagemaker_session: Session,
9391
s3_model_data_url: str = None,
9492
image: str = None,
9593
env_vars: dict = None,
96-
9794
):
9895
if s3_model_data_url:
9996
bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
@@ -132,7 +129,6 @@ def _upload_server_artifacts(
132129
"SAGEMAKER_PROGRAM": "inference.py",
133130
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
134131
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
135-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
136132
"LOCAL_PYTHON": platform.python_version(),
137133
}
138134

0 commit comments

Comments (0)