Skip to content

Commit b40f36c

Browse files
author
Bryannah Hernandez
committed
tests for in process mode
1 parent 1b93244 commit b40f36c

File tree

2 files changed

+162
-10
lines changed

2 files changed

+162
-10
lines changed

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sagemaker.serve.utils.exceptions import LocalDeepPingException
1414
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
1515
from sagemaker.session import Session
16+
from datetime import datetime, timedelta
1617

1718
logger = logging.getLogger(__name__)
1819

@@ -45,19 +46,15 @@ def __init__(
4546
self.session = session
4647
self.schema_builder = schema_builder
4748
self.model_server = model_server
48-
self.client = None
49-
self.container = None
50-
self.secret_key = None
51-
self._invoke_serving = None
5249
self._ping_container = None
5350

5451
def load(self, model_path: str = None):
    """Validate the model directory and load it through the inference spec.

    Args:
        model_path (str): Directory holding the model artifacts. When it is
            omitted or empty, ``self.model_path`` is used instead.

    Returns:
        The value returned by ``self.inference_spec.load`` for the resolved path.

    Raises:
        ValueError: If the resolved path does not exist or is not a directory.
    """
    path = Path(model_path if model_path else self.model_path)
    # ValueError (rather than a bare Exception) lets callers catch bad-path
    # errors specifically without swallowing unrelated failures.
    if not path.exists():
        raise ValueError("model_path does not exist")
    if not path.is_dir():
        raise ValueError("model_path is not a valid directory")

    return self.inference_spec.load(str(path))
6360

@@ -69,15 +66,18 @@ def create_server(
6966
predictor: PredictorBase,
7067
):
7168
"""Creating the server and checking ping health."""
72-
73-
# self.destroy_server()
74-
7569
logger.info("Waiting for model server %s to start up...", self.model_server)
7670

7771
if self.model_server == ModelServer.MMS:
7872
self._ping_container = self._multi_model_server_deep_ping
7973

80-
while True:
74+
time_limit = datetime.now() + timedelta(seconds=5)
75+
while self._ping_container is not None:
76+
final_pull = datetime.now() > time_limit
77+
78+
if final_pull:
79+
break
80+
8181
time.sleep(10)
8282

8383
healthy, response = self._ping_container(predictor)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import unittest
16+
from unittest.mock import MagicMock, patch, Mock, mock_open
17+
18+
from sagemaker.serve.mode.in_process_mode import InProcessMode
19+
from sagemaker.serve import SchemaBuilder
20+
from sagemaker.serve.utils.types import ModelServer
21+
from sagemaker.serve.utils.exceptions import LocalDeepPingException
22+
23+
24+
# Shared fixtures: a prompt/response pair and the sample input/output built
# from it, which the tests below pass to SchemaBuilder.
mock_prompt = "Hello, I'm a language model,"
mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
mock_sample_output = [{"generated_text": mock_response}]
28+
29+
30+
class TestInProcessMode(unittest.TestCase):
    """Unit tests for ``InProcessMode.load`` and ``InProcessMode.create_server``."""

    @patch("sagemaker.serve.mode.in_process_mode.Path")
    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
    @patch("sagemaker.session.Session")
    def test_load_happy(self, mock_session, mock_inference_spec, mock_path):
        """load() returns the inference spec's result for an existing directory."""
        mock_path.return_value.exists.return_value = True
        mock_path.return_value.is_dir.return_value = True
        mock_inference_spec.load.return_value = "Dummy load"

        mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
        in_process_mode = InProcessMode(
            model_server=ModelServer.MMS,
            inference_spec=mock_inference_spec,
            schema_builder=mock_schema_builder,
            session=mock_session,
            model_path="model_path",
            env_vars={"key": "val"},
        )

        res = in_process_mode.load(model_path="/tmp/model-builder/code/")

        self.assertEqual(res, "Dummy load")
        self.assertEqual(in_process_mode.inference_spec, mock_inference_spec)
        self.assertEqual(in_process_mode.schema_builder, mock_schema_builder)
        self.assertEqual(in_process_mode.model_path, "model_path")
        self.assertEqual(in_process_mode.env_vars, {"key": "val"})

    @patch("sagemaker.serve.mode.in_process_mode.Path")
    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
    @patch("sagemaker.session.Session")
    def test_load_ex(self, mock_session, mock_inference_spec, mock_path):
        """load() raises ValueError for a missing path and for a non-directory path."""
        mock_inference_spec.load.return_value = "Dummy load"

        # (exists, is_dir) combinations that load() must reject.
        for exists, is_dir in ((False, True), (True, False)):
            mock_path.return_value.exists.return_value = exists
            mock_path.return_value.is_dir.return_value = is_dir

            in_process_mode = InProcessMode(
                model_server=ModelServer.MMS,
                inference_spec=mock_inference_spec,
                schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
                session=mock_session,
                model_path="model_path",
            )

            with self.assertRaises(ValueError):
                in_process_mode.load("/tmp/model-builder/code/")

    @patch("sagemaker.serve.mode.in_process_mode.time.sleep")
    @patch("sagemaker.serve.mode.in_process_mode.logger")
    @patch("sagemaker.base_predictor.PredictorBase")
    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
    @patch("sagemaker.session.Session")
    def test_create_server_happy(
        self, mock_session, mock_inference_spec, mock_predictor, mock_logger, mock_sleep
    ):
        """create_server() logs start-up and the passing ping when the deep ping is healthy."""
        # create_server calls time.sleep(10) before pinging; it is stubbed via
        # the decorator above so this test does not block on a real sleep.
        fake_response = "Fake response"
        deep_ping = Mock(return_value=(True, fake_response))

        in_process_mode = InProcessMode(
            model_server=ModelServer.MMS,
            inference_spec=mock_inference_spec,
            schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
            session=mock_session,
            model_path="model_path",
        )
        in_process_mode._multi_model_server_deep_ping = deep_ping

        in_process_mode.create_server(predictor=mock_predictor)

        mock_logger.info.assert_called_once_with(
            "Waiting for model server %s to start up...", ModelServer.MMS
        )
        mock_logger.debug.assert_called_once_with(
            "Ping health check has passed. Returned %s", str(fake_response)
        )

    @patch("sagemaker.base_predictor.PredictorBase")
    @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
    @patch("sagemaker.session.Session")
    def test_create_server_ex(
        self,
        mock_session,
        mock_inference_spec,
        mock_predictor,
    ):
        """create_server() raises LocalDeepPingException when the deep ping is unhealthy."""
        # NOTE(review): time.sleep is not stubbed here, so this test waits on the
        # real sleep inside create_server. Stubbing it safely would also require
        # controlling the datetime-based time limit -- confirm against the full
        # create_server body before changing.
        deep_ping = Mock(return_value=(False, None))

        in_process_mode = InProcessMode(
            model_server=ModelServer.MMS,
            inference_spec=mock_inference_spec,
            schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
            session=mock_session,
            model_path="model_path",
        )
        in_process_mode._multi_model_server_deep_ping = deep_ping

        with self.assertRaises(LocalDeepPingException):
            in_process_mode.create_server(mock_predictor)

0 commit comments

Comments
 (0)