Skip to content

Commit ebfc3e2

Browse files
author
Bryannah Hernandez
committed
ut for app.py
1 parent 009a90e commit ebfc3e2

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import unittest
16+
from unittest.mock import patch, Mock
17+
18+
from sagemaker.serve.mode.in_process_mode import InProcessMode
19+
from sagemaker.serve import SchemaBuilder
20+
from sagemaker.serve.utils.types import ModelServer
21+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
22+
23+
24+
mock_prompt = "Hello, I'm a language model,"
25+
mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
26+
mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
27+
mock_sample_output = [{"generated_text": mock_response}]
28+
29+
30+
class TestAppInProcessServer(unittest.TestCase):
31+
32+
@patch("sagemaker.server.app.uvicorn")
33+
def test_uvicorn_import(self, mock_uvicorn):
34+
mock_uvicorn.return_value.exists.side_effect = lambda *args, **kwargs: False
35+
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
36+
37+
def test_transformers_import(self):
38+
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
39+
40+
def test_fastapi_import(self):
41+
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
42+
43+
44+
45+
# @patch("sagemaker.serve.mode.in_process_mode.Path")
46+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
47+
# @patch("sagemaker.session.Session")
48+
# def test_load_ex(self, mock_session, mock_inference_spec, mock_path):
49+
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False
50+
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
51+
52+
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
53+
54+
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
55+
# in_process_mode = InProcessMode(
56+
# model_server=ModelServer.MMS,
57+
# inference_spec=mock_inference_spec,
58+
# schema_builder=mock_schema_builder,
59+
# session=mock_session,
60+
# model_path="model_path",
61+
# )
62+
63+
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
64+
65+
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
66+
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False
67+
68+
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
69+
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
70+
# in_process_mode = InProcessMode(
71+
# model_server=ModelServer.MMS,
72+
# inference_spec=mock_inference_spec,
73+
# schema_builder=mock_schema_builder,
74+
# session=mock_session,
75+
# model_path="model_path",
76+
# )
77+
78+
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
79+
80+
# @patch("sagemaker.serve.mode.in_process_mode.logger")
81+
# @patch("sagemaker.base_predictor.PredictorBase")
82+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
83+
# @patch("sagemaker.session.Session")
84+
# def test_create_server_happy(
85+
# self, mock_session, mock_inference_spec, mock_predictor, mock_logger
86+
# ):
87+
# mock_start_serving = Mock()
88+
# mock_start_serving.side_effect = lambda *args, **kwargs: (
89+
# True,
90+
# None,
91+
# )
92+
93+
# mock_response = "Fake response"
94+
# mock_multi_model_server_deep_ping = Mock()
95+
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
96+
# True,
97+
# mock_response,
98+
# )
99+
100+
# in_process_mode = InProcessMode(
101+
# model_server=ModelServer.MMS,
102+
# inference_spec=mock_inference_spec,
103+
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
104+
# session=mock_session,
105+
# model_path="model_path",
106+
# )
107+
108+
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
109+
# in_process_mode._start_serving = mock_start_serving
110+
111+
# in_process_mode.create_server(predictor=mock_predictor)
112+
113+
# mock_logger.info.assert_called_once_with(
114+
# "Waiting for model server %s to start up...", ModelServer.MMS
115+
# )
116+
# mock_logger.debug.assert_called_once_with(
117+
# "Ping health check has passed. Returned %s", str(mock_response)
118+
# )
119+
120+
# @patch("sagemaker.base_predictor.PredictorBase")
121+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
122+
# @patch("sagemaker.session.Session")
123+
# def test_create_server_ex(
124+
# self,
125+
# mock_session,
126+
# mock_inference_spec,
127+
# mock_predictor,
128+
# ):
129+
# mock_start_serving = Mock()
130+
# mock_start_serving.side_effect = lambda *args, **kwargs: (
131+
# True,
132+
# None,
133+
# )
134+
135+
# mock_multi_model_server_deep_ping = Mock()
136+
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
137+
# False,
138+
# None,
139+
# )
140+
141+
# in_process_mode = InProcessMode(
142+
# model_server=ModelServer.MMS,
143+
# inference_spec=mock_inference_spec,
144+
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
145+
# session=mock_session,
146+
# model_path="model_path",
147+
# )
148+
149+
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
150+
# in_process_mode._start_serving = mock_start_serving
151+
152+
# self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor)

0 commit comments

Comments
 (0)