Skip to content

Commit c83e0ce

Browse files
author
Bryannah Hernandez
committed
unit test fir app
1 parent ebfc3e2 commit c83e0ce

File tree

1 file changed

+53
-130
lines changed

1 file changed

+53
-130
lines changed

tests/unit/sagemaker/serve/test_app.py

Lines changed: 53 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -13,140 +13,63 @@
1313
from __future__ import absolute_import
1414

1515
import unittest
16+
1617
from unittest.mock import patch, Mock
18+
from sagemaker.serve.app import InProcessServer
19+
20+
mock_model_id = "mock_model_id"
1721

18-
from sagemaker.serve.mode.in_process_mode import InProcessMode
19-
from sagemaker.serve import SchemaBuilder
20-
from sagemaker.serve.utils.types import ModelServer
21-
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
2222

23+
class TestAppInProcessServer(unittest.TestCase):
24+
@patch("sagemaker.serve.app.threading")
25+
@patch("sagemaker.serve.app.pipeline")
26+
def test_in_process_server_init(self, mock_pipeline, mock_threading):
27+
mock_generator = Mock()
28+
mock_generator.side_effect = None
2329

24-
mock_prompt = "Hello, I'm a language model,"
25-
mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
26-
mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
27-
mock_sample_output = [{"generated_text": mock_response}]
30+
in_process_server = InProcessServer(model_id=mock_model_id)
31+
in_process_server._generator = mock_generator
2832

33+
@patch("sagemaker.serve.app.logger")
34+
@patch("sagemaker.serve.app.threading")
35+
@patch("sagemaker.serve.app.pipeline")
36+
def test_start_server(self, mock_pipeline, mock_threading, mock_logger):
37+
mock_generator = Mock()
38+
mock_generator.side_effect = None
39+
mock_thread = Mock()
40+
mock_threading.Thread.return_value = mock_thread
2941

30-
class TestAppInProcessServer(unittest.TestCase):
42+
in_process_server = InProcessServer(model_id=mock_model_id)
43+
in_process_server._generator = mock_generator
44+
45+
in_process_server.start_server()
46+
47+
mock_logger.info.assert_called()
48+
mock_thread.start.assert_called()
49+
50+
@patch("sagemaker.serve.app.asyncio")
51+
@patch("sagemaker.serve.app.pipeline")
52+
def test_start_run_async_in_thread(self, mock_pipeline, mock_asyncio):
53+
mock_pipeline.side_effect = None
54+
55+
mock_loop = Mock()
56+
mock_asyncio.new_event_loop.side_effect = lambda: mock_loop
57+
58+
in_process_server = InProcessServer(model_id=mock_model_id)
59+
in_process_server._start_run_async_in_thread()
60+
61+
mock_asyncio.set_event_loop.assert_called_once_with(mock_loop)
62+
mock_loop.run_until_complete.assert_called()
63+
64+
@patch("sagemaker.serve.app.pipeline")
65+
async def test_serve(self, mock_pipeline):
66+
mock_pipeline.side_effect = None
67+
68+
mock_server = Mock()
69+
70+
in_process_server = InProcessServer(model_id=mock_model_id)
71+
in_process_server.server = mock_server
72+
73+
await in_process_server._serve()
3174

32-
@patch("sagemaker.server.app.uvicorn")
33-
def test_uvicorn_import(self, mock_uvicorn):
34-
mock_uvicorn.return_value.exists.side_effect = lambda *args, **kwargs: False
35-
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
36-
37-
def test_transformers_import(self):
38-
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
39-
40-
def test_fastapi_import(self):
41-
self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/")
42-
43-
44-
45-
# @patch("sagemaker.serve.mode.in_process_mode.Path")
46-
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
47-
# @patch("sagemaker.session.Session")
48-
# def test_load_ex(self, mock_session, mock_inference_spec, mock_path):
49-
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False
50-
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
51-
52-
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
53-
54-
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
55-
# in_process_mode = InProcessMode(
56-
# model_server=ModelServer.MMS,
57-
# inference_spec=mock_inference_spec,
58-
# schema_builder=mock_schema_builder,
59-
# session=mock_session,
60-
# model_path="model_path",
61-
# )
62-
63-
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
64-
65-
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
66-
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False
67-
68-
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
69-
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
70-
# in_process_mode = InProcessMode(
71-
# model_server=ModelServer.MMS,
72-
# inference_spec=mock_inference_spec,
73-
# schema_builder=mock_schema_builder,
74-
# session=mock_session,
75-
# model_path="model_path",
76-
# )
77-
78-
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
79-
80-
# @patch("sagemaker.serve.mode.in_process_mode.logger")
81-
# @patch("sagemaker.base_predictor.PredictorBase")
82-
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
83-
# @patch("sagemaker.session.Session")
84-
# def test_create_server_happy(
85-
# self, mock_session, mock_inference_spec, mock_predictor, mock_logger
86-
# ):
87-
# mock_start_serving = Mock()
88-
# mock_start_serving.side_effect = lambda *args, **kwargs: (
89-
# True,
90-
# None,
91-
# )
92-
93-
# mock_response = "Fake response"
94-
# mock_multi_model_server_deep_ping = Mock()
95-
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
96-
# True,
97-
# mock_response,
98-
# )
99-
100-
# in_process_mode = InProcessMode(
101-
# model_server=ModelServer.MMS,
102-
# inference_spec=mock_inference_spec,
103-
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
104-
# session=mock_session,
105-
# model_path="model_path",
106-
# )
107-
108-
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
109-
# in_process_mode._start_serving = mock_start_serving
110-
111-
# in_process_mode.create_server(predictor=mock_predictor)
112-
113-
# mock_logger.info.assert_called_once_with(
114-
# "Waiting for model server %s to start up...", ModelServer.MMS
115-
# )
116-
# mock_logger.debug.assert_called_once_with(
117-
# "Ping health check has passed. Returned %s", str(mock_response)
118-
# )
119-
120-
# @patch("sagemaker.base_predictor.PredictorBase")
121-
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
122-
# @patch("sagemaker.session.Session")
123-
# def test_create_server_ex(
124-
# self,
125-
# mock_session,
126-
# mock_inference_spec,
127-
# mock_predictor,
128-
# ):
129-
# mock_start_serving = Mock()
130-
# mock_start_serving.side_effect = lambda *args, **kwargs: (
131-
# True,
132-
# None,
133-
# )
134-
135-
# mock_multi_model_server_deep_ping = Mock()
136-
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
137-
# False,
138-
# None,
139-
# )
140-
141-
# in_process_mode = InProcessMode(
142-
# model_server=ModelServer.MMS,
143-
# inference_spec=mock_inference_spec,
144-
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
145-
# session=mock_session,
146-
# model_path="model_path",
147-
# )
148-
149-
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
150-
# in_process_mode._start_serving = mock_start_serving
151-
152-
# self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor)
75+
mock_server.serve.assert_called()

0 commit comments

Comments
 (0)