Skip to content

Commit 719fe7e

Browse files
author
Bryannah Hernandez
committed
feat: formatting and InferenceSpec support for MMS
1 parent aa4a62e commit 719fe7e

File tree

5 files changed

+194
-116
lines changed

5 files changed

+194
-116
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ def build( # pylint: disable=R0911
857857

858858
def _build_validations(self):
859859
"""Validations needed for model server overrides, or auto-detection or fallback"""
860-
if self.mode == Mode.IN_PROCESS:
860+
if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
861861
raise ValueError("IN_PROCESS mode is not supported yet!")
862862

863863
if self.inference_spec and self.model:

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sagemaker.huggingface import HuggingFaceModel
3030
from sagemaker.serve.model_server.multi_model_server.prepare import (
3131
_create_dir_structure,
32-
prepare_for_mms
32+
prepare_for_mms,
3333
)
3434
from sagemaker.serve.detector.image_detector import (
3535
auto_detect_container,
@@ -161,7 +161,7 @@ def _get_hf_metadata_create_model(self) -> Type[Model]:
161161
vpc_config=self.vpc_config,
162162
)
163163

164-
if self.mode == Mode.LOCAL_CONTAINER:
164+
if self.mode == Mode.LOCAL_CONTAINER or self.mode == Mode.IN_PROCESS:
165165
self.image_uri = pysdk_model.serving_image_uri(
166166
self.sagemaker_session.boto_region_name, "local"
167167
)
@@ -227,6 +227,23 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
227227
)
228228
return predictor
229229

230+
if self.mode == Mode.IN_PROCESS:
231+
timeout = kwargs.get("model_data_download_timeout")
232+
233+
predictor = TransformersLocalModePredictor(
234+
self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
235+
)
236+
237+
self.modes[str(Mode.IN_PROCESS)].create_server(
238+
self.image_uri,
239+
timeout if timeout else DEFAULT_TIMEOUT,
240+
None,
241+
predictor,
242+
self.pysdk_model.env,
243+
jumpstart=False,
244+
)
245+
return predictor
246+
230247
if "mode" in kwargs:
231248
del kwargs["mode"]
232249
if "role" in kwargs:
@@ -276,11 +293,11 @@ def _build_transformers_env(self):
276293

277294
self.pysdk_model = self._create_transformers_model()
278295

279-
if self.mode == Mode.LOCAL_CONTAINER:
296+
if self.mode == Mode.LOCAL_CONTAINER or self.mode == Mode.IN_PROCESS:
280297
self._prepare_for_mode()
281298

282299
logger.info("Model configuration %s", self.pysdk_model)
283-
300+
284301
return self.pysdk_model
285302

286303
def _set_instance(self, **kwargs):
@@ -343,7 +360,6 @@ def _auto_detect_container(self):
343360
else:
344361
raise ValueError("Cannot detect required model or inference spec")
345362

346-
347363
def _build_for_transformers(self):
348364
"""Method that triggers model build
349365
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from datetime import datetime, timedelta
from typing import Dict, Type
import base64
import time
import subprocess
import docker

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.logging_agent import pull_logs
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

# Seconds of log streaming between consecutive ping health checks.
_PING_HEALTH_CHECK_INTERVAL_SEC = 5

_PING_HEALTH_CHECK_FAIL_MSG = (
    "Container did not pass the ping health check. "
    + "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy model to a container in local environment"""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        """Initialize the mode with the model server and build-time context.

        Args:
            model_server (ModelServer): The model server to launch (only
                ``ModelServer.MMS`` is wired up in ``create_server``).
            inference_spec (Type[InferenceSpec]): User-supplied spec whose
                ``load`` method materializes the model from ``model_path``.
            schema_builder (Type[SchemaBuilder]): Serialization schema for the model.
            session (Session): SageMaker session; its boto session is used to
                create the ECR client for image pulls.
            model_path (str): Local directory holding model artifacts.
            env_vars (Dict): Environment variables passed to the model server.
        """
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.ecr = session.boto_session.client("ecr")
        self.model_server = model_server
        self.client = None
        self.container = None
        self.secret_key = None
        self._ping_container = None
        self._invoke_serving = None

    def load(self, model_path: str = None):
        """Load the model through the user-provided ``InferenceSpec``.

        Args:
            model_path (str): Directory containing model artifacts; defaults
                to ``self.model_path`` when omitted.

        Returns:
            The object returned by ``self.inference_spec.load``.

        Raises:
            ValueError: If the resolved path does not exist or is not a directory.
        """
        path = Path(model_path if model_path else self.model_path)
        # ValueError (a subclass of Exception) replaces the previous bare
        # Exception so callers can catch the failure narrowly.
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """No-op; artifact preparation is handled elsewhere for this mode."""

    def create_server(
        self,
        image: str,
        container_timeout_seconds: int,
        secret_key: str,
        predictor: PredictorBase,
        env_vars: Dict[str, str] = None,
        model_path: str = None,
    ):
        """Pull the image, start the model server, and block until it is healthy.

        Args:
            image (str): ECR image URI to pull and serve from.
            container_timeout_seconds (int): Deadline for the server to pass
                the ping health check.
            secret_key (str): Key forwarded to the serving process.
            predictor (PredictorBase): Predictor used to issue the deep ping.
            env_vars (Dict[str, str]): Overrides ``self.env_vars`` when given.
            model_path (str): Overrides ``self.model_path`` when given.

        Raises:
            LocalDeepPingException: If the server never passes the health
                check before the deadline.
        """
        self._pull_image(image=image)

        self.destroy_server()

        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._start_serving(
                client=self.client,
                image=image,
                model_path=model_path if model_path else self.model_path,
                secret_key=secret_key,
                env_vars=env_vars if env_vars else self.env_vars,
            )
            self._ping_container = self._multi_model_server_deep_ping

        # allow some time for container to be ready
        time.sleep(10)

        # NOTE(review): self.container and self._ping_container are only set
        # on the MMS path above — for any other model_server this raises
        # AttributeError/TypeError. Confirm MMS is the sole supported server
        # for this mode.
        log_generator = self.container.logs(follow=True, stream=True)
        time_limit = datetime.now() + timedelta(seconds=container_timeout_seconds)
        healthy = False
        while True:
            now = datetime.now()
            final_pull = now > time_limit
            # Stream container logs for one health-check interval, then ping.
            pull_logs(
                (x.decode("UTF-8").rstrip() for x in log_generator),
                log_generator.close,
                datetime.now() + timedelta(seconds=_PING_HEALTH_CHECK_INTERVAL_SEC),
                now > time_limit,
            )

            if final_pull:
                break

            # allow some time for container to be ready
            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)

    def destroy_server(self):
        """Kill the currently tracked container, if any, and clear the handle.

        Raises:
            RuntimeError: If Docker reports a non-4xx API error during cleanup
                (4xx errors — e.g. container already gone — are tolerated).
        """
        if self.container:
            try:
                logger.debug("Stopping currently running container...")
                self.container.kill()
            except docker.errors.APIError as exc:
                if exc.response.status_code < 400 or exc.response.status_code > 499:
                    # RuntimeError (subclass of Exception) replaces the
                    # previous bare Exception for narrower catching.
                    raise RuntimeError(
                        "Error encountered when cleaning up local container"
                    ) from exc
            self.container = None

    def _pull_image(self, image: str):
        """Log in to ECR via the docker CLI and pull ``image``.

        A failed ECR login is logged and tolerated (the image may already be
        cached locally); a missing remote image raises ``ValueError``.
        """
        try:
            encoded_token = (
                self.ecr.get_authorization_token()
                .get("authorizationData")[0]
                .get("authorizationToken")
            )
            # ECR tokens are base64-encoded "username:password" pairs.
            decoded_token = base64.b64decode(encoded_token).decode("utf-8")
            username, password = decoded_token.split(":")
            ecr_uri = image.split("/")[0]
            login_command = ["docker", "login", "-u", username, "-p", password, ecr_uri]
            subprocess.run(login_command, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logger.warning("Unable to login to ecr: %s", e)

        self.client = docker.from_env()
        try:
            logger.info("Pulling image %s from repository...", image)
            self.client.images.pull(image)
        except docker.errors.NotFound as e:
            raise ValueError("Could not find remote image to pull") from e

src/sagemaker/serve/model_server/multi_model_server/inference.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,14 @@ def prepare_mms_js_resources(
7575

7676
return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
7777

78+
7879
def prepare_for_mms(
79-
model_path: str,
80-
shared_libs: List[str],
81-
dependencies: dict,
82-
session: Session,
83-
image_uri: str,
84-
inference_spec: InferenceSpec = None,
80+
model_path: str,
81+
shared_libs: List[str],
82+
dependencies: dict,
83+
session: Session,
84+
image_uri: str,
85+
inference_spec: InferenceSpec = None,
8586
) -> str:
8687
"""This is a one-line summary of the function.
8788
Args:
@@ -124,4 +125,4 @@ def prepare_for_mms(
124125
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
125126
metadata.write(_MetaData(hash_value).to_json())
126127

127-
return secret_key
128+
return secret_key

0 commit comments

Comments
 (0)