|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 | 15 | import unittest
|
| 16 | + |
16 | 17 | from unittest.mock import patch, Mock
|
| 18 | +from sagemaker.serve.app import InProcessServer |
| 19 | + |
| 20 | +mock_model_id = "mock_model_id" |
17 | 21 |
|
18 |
| -from sagemaker.serve.mode.in_process_mode import InProcessMode |
19 |
| -from sagemaker.serve import SchemaBuilder |
20 |
| -from sagemaker.serve.utils.types import ModelServer |
21 |
| -from sagemaker.serve.utils.exceptions import InProcessDeepPingException |
22 | 22 |
|
| 23 | +class TestAppInProcessServer(unittest.TestCase): |
| 24 | + @patch("sagemaker.serve.app.threading") |
| 25 | + @patch("sagemaker.serve.app.pipeline") |
| 26 | + def test_in_process_server_init(self, mock_pipeline, mock_threading): |
| 27 | + mock_generator = Mock() |
| 28 | + mock_generator.side_effect = None |
23 | 29 |
|
24 |
| -mock_prompt = "Hello, I'm a language model," |
25 |
| -mock_response = "Hello, I'm a language model, and I'm here to help you with your English." |
26 |
| -mock_sample_input = {"inputs": mock_prompt, "parameters": {}} |
27 |
| -mock_sample_output = [{"generated_text": mock_response}] |
| 30 | + in_process_server = InProcessServer(model_id=mock_model_id) |
| 31 | + in_process_server._generator = mock_generator |
28 | 32 |
|
| 33 | + @patch("sagemaker.serve.app.logger") |
| 34 | + @patch("sagemaker.serve.app.threading") |
| 35 | + @patch("sagemaker.serve.app.pipeline") |
| 36 | + def test_start_server(self, mock_pipeline, mock_threading, mock_logger): |
| 37 | + mock_generator = Mock() |
| 38 | + mock_generator.side_effect = None |
| 39 | + mock_thread = Mock() |
| 40 | + mock_threading.Thread.return_value = mock_thread |
29 | 41 |
|
30 |
| -class TestAppInProcessServer(unittest.TestCase): |
| 42 | + in_process_server = InProcessServer(model_id=mock_model_id) |
| 43 | + in_process_server._generator = mock_generator |
| 44 | + |
| 45 | + in_process_server.start_server() |
| 46 | + |
| 47 | + mock_logger.info.assert_called() |
| 48 | + mock_thread.start.assert_called() |
| 49 | + |
| 50 | + @patch("sagemaker.serve.app.asyncio") |
| 51 | + @patch("sagemaker.serve.app.pipeline") |
| 52 | + def test_start_run_async_in_thread(self, mock_pipeline, mock_asyncio): |
| 53 | + mock_pipeline.side_effect = None |
| 54 | + |
| 55 | + mock_loop = Mock() |
| 56 | + mock_asyncio.new_event_loop.side_effect = lambda: mock_loop |
| 57 | + |
| 58 | + in_process_server = InProcessServer(model_id=mock_model_id) |
| 59 | + in_process_server._start_run_async_in_thread() |
| 60 | + |
| 61 | + mock_asyncio.set_event_loop.assert_called_once_with(mock_loop) |
| 62 | + mock_loop.run_until_complete.assert_called() |
| 63 | + |
| 64 | + @patch("sagemaker.serve.app.pipeline") |
| 65 | + async def test_serve(self, mock_pipeline): |
| 66 | + mock_pipeline.side_effect = None |
| 67 | + |
| 68 | + mock_server = Mock() |
| 69 | + |
| 70 | + in_process_server = InProcessServer(model_id=mock_model_id) |
| 71 | + in_process_server.server = mock_server |
| 72 | + |
| 73 | + await in_process_server._serve() |
31 | 74 |
|
32 |
| - @patch("sagemaker.server.app.uvicorn") |
33 |
| - def test_uvicorn_import(self, mock_uvicorn): |
34 |
| - mock_uvicorn.return_value.exists.side_effect = lambda *args, **kwargs: False |
35 |
| - self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/") |
36 |
| - |
37 |
| - def test_transformers_import(self): |
38 |
| - self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/") |
39 |
| - |
40 |
| - def test_fastapi_import(self): |
41 |
| - self.assertRaises(ImportError, in_process_mode.load, "/tmp/model-builder/code/") |
42 |
| - |
43 |
| - |
44 |
| - |
45 |
| - # @patch("sagemaker.serve.mode.in_process_mode.Path") |
46 |
| - # @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
47 |
| - # @patch("sagemaker.session.Session") |
48 |
| - # def test_load_ex(self, mock_session, mock_inference_spec, mock_path): |
49 |
| - # mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False |
50 |
| - # mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True |
51 |
| - |
52 |
| - # mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" |
53 |
| - |
54 |
| - # mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) |
55 |
| - # in_process_mode = InProcessMode( |
56 |
| - # model_server=ModelServer.MMS, |
57 |
| - # inference_spec=mock_inference_spec, |
58 |
| - # schema_builder=mock_schema_builder, |
59 |
| - # session=mock_session, |
60 |
| - # model_path="model_path", |
61 |
| - # ) |
62 |
| - |
63 |
| - # self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") |
64 |
| - |
65 |
| - # mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True |
66 |
| - # mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False |
67 |
| - |
68 |
| - # mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" |
69 |
| - # mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) |
70 |
| - # in_process_mode = InProcessMode( |
71 |
| - # model_server=ModelServer.MMS, |
72 |
| - # inference_spec=mock_inference_spec, |
73 |
| - # schema_builder=mock_schema_builder, |
74 |
| - # session=mock_session, |
75 |
| - # model_path="model_path", |
76 |
| - # ) |
77 |
| - |
78 |
| - # self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") |
79 |
| - |
80 |
| - # @patch("sagemaker.serve.mode.in_process_mode.logger") |
81 |
| - # @patch("sagemaker.base_predictor.PredictorBase") |
82 |
| - # @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
83 |
| - # @patch("sagemaker.session.Session") |
84 |
| - # def test_create_server_happy( |
85 |
| - # self, mock_session, mock_inference_spec, mock_predictor, mock_logger |
86 |
| - # ): |
87 |
| - # mock_start_serving = Mock() |
88 |
| - # mock_start_serving.side_effect = lambda *args, **kwargs: ( |
89 |
| - # True, |
90 |
| - # None, |
91 |
| - # ) |
92 |
| - |
93 |
| - # mock_response = "Fake response" |
94 |
| - # mock_multi_model_server_deep_ping = Mock() |
95 |
| - # mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( |
96 |
| - # True, |
97 |
| - # mock_response, |
98 |
| - # ) |
99 |
| - |
100 |
| - # in_process_mode = InProcessMode( |
101 |
| - # model_server=ModelServer.MMS, |
102 |
| - # inference_spec=mock_inference_spec, |
103 |
| - # schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), |
104 |
| - # session=mock_session, |
105 |
| - # model_path="model_path", |
106 |
| - # ) |
107 |
| - |
108 |
| - # in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping |
109 |
| - # in_process_mode._start_serving = mock_start_serving |
110 |
| - |
111 |
| - # in_process_mode.create_server(predictor=mock_predictor) |
112 |
| - |
113 |
| - # mock_logger.info.assert_called_once_with( |
114 |
| - # "Waiting for model server %s to start up...", ModelServer.MMS |
115 |
| - # ) |
116 |
| - # mock_logger.debug.assert_called_once_with( |
117 |
| - # "Ping health check has passed. Returned %s", str(mock_response) |
118 |
| - # ) |
119 |
| - |
120 |
| - # @patch("sagemaker.base_predictor.PredictorBase") |
121 |
| - # @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
122 |
| - # @patch("sagemaker.session.Session") |
123 |
| - # def test_create_server_ex( |
124 |
| - # self, |
125 |
| - # mock_session, |
126 |
| - # mock_inference_spec, |
127 |
| - # mock_predictor, |
128 |
| - # ): |
129 |
| - # mock_start_serving = Mock() |
130 |
| - # mock_start_serving.side_effect = lambda *args, **kwargs: ( |
131 |
| - # True, |
132 |
| - # None, |
133 |
| - # ) |
134 |
| - |
135 |
| - # mock_multi_model_server_deep_ping = Mock() |
136 |
| - # mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( |
137 |
| - # False, |
138 |
| - # None, |
139 |
| - # ) |
140 |
| - |
141 |
| - # in_process_mode = InProcessMode( |
142 |
| - # model_server=ModelServer.MMS, |
143 |
| - # inference_spec=mock_inference_spec, |
144 |
| - # schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), |
145 |
| - # session=mock_session, |
146 |
| - # model_path="model_path", |
147 |
| - # ) |
148 |
| - |
149 |
| - # in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping |
150 |
| - # in_process_mode._start_serving = mock_start_serving |
151 |
| - |
152 |
| - # self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor) |
| 75 | + mock_server.serve.assert_called() |
0 commit comments