|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 |
| -# import unittest |
16 |
| -# from unittest.mock import patch, Mock |
| 15 | +import unittest |
| 16 | +import subprocess |
| 17 | +from unittest.mock import patch, Mock |
17 | 18 |
|
18 |
| -# from sagemaker.serve.mode.in_process_mode import InProcessMode |
19 |
| -# from sagemaker.serve import SchemaBuilder |
20 |
| -# from sagemaker.serve.utils.types import ModelServer |
21 |
| -# from sagemaker.serve.utils.exceptions import LocalDeepPingException |
| 19 | +from sagemaker.serve.mode.in_process_mode import InProcessMode |
| 20 | +from sagemaker.serve.builder.requirements_manager import RequirementsManager |
22 | 21 |
|
| 22 | +class TestRequirementsManager(unittest.TestCase): |
23 | 23 |
|
24 |
| -# mock_prompt = "Hello, I'm a language model," |
25 |
| -# mock_response = "Hello, I'm a language model, and I'm here to help you with your English." |
26 |
| -# mock_sample_input = {"inputs": mock_prompt, "parameters": {}} |
27 |
| -# mock_sample_output = [{"generated_text": mock_response}] |
| 24 | + def test_detect_file_exists_fail(self, mock_dependencies: str = None) -> str: |
| 25 | + mock_dependencies = "mock.ini" |
| 26 | + self.assertRaises(ValueError, RequirementsManager().detect_file_exists(mock_dependencies)) |
28 | 27 |
|
| 28 | + @patch("sagemaker.serve.mode.in_process_mode.logger") |
| 29 | + @patch("sagemaker.session.Session") |
| 30 | + def test_install_requirements_txt(self, mock_logger): |
29 | 31 |
|
30 |
| -# class TestRequirementsManager(unittest.TestCase): |
| 32 | + mock_logger.info.assert_called_once_with("Running command to pip install") |
31 | 33 |
|
32 |
| -# @patch("sagemaker.serve.mode.in_process_mode.Path") |
33 |
| -# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
34 |
| -# @patch("sagemaker.session.Session") |
35 |
| -# def test_load_happy(self, mock_session, mock_inference_spec, mock_path): |
36 |
| -# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True |
37 |
| -# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True |
| 34 | + mock_logger.info.assert_called_once_with("Command ran successfully") |
38 | 35 |
|
39 |
| -# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" |
| 36 | + @patch("sagemaker.serve.mode.in_process_mode.logger") |
| 37 | + @patch("sagemaker.session.Session") |
| 38 | + def test_update_conda_env_in_path(self, mock_logger): |
40 | 39 |
|
41 |
| -# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) |
42 |
| -# in_process_mode = InProcessMode( |
43 |
| -# model_server=ModelServer.MMS, |
44 |
| -# inference_spec=mock_inference_spec, |
45 |
| -# schema_builder=mock_schema_builder, |
46 |
| -# session=mock_session, |
47 |
| -# model_path="model_path", |
48 |
| -# env_vars={"key": "val"}, |
49 |
| -# ) |
| 40 | + mock_logger.info.assert_called_once_with("Updating conda env") |
50 | 41 |
|
51 |
| -# res = in_process_mode.load(model_path="/tmp/model-builder/code/") |
52 | 42 |
|
53 |
| -# self.assertEqual(res, "Dummy load") |
54 |
| -# self.assertEqual(in_process_mode.inference_spec, mock_inference_spec) |
55 |
| -# self.assertEqual(in_process_mode.schema_builder, mock_schema_builder) |
56 |
| -# self.assertEqual(in_process_mode.model_path, "model_path") |
57 |
| -# self.assertEqual(in_process_mode.env_vars, {"key": "val"}) |
| 43 | + # mock_multi_model_server_deep_ping = Mock() |
| 44 | + # mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( |
| 45 | + # True, |
| 46 | + # ) |
58 | 47 |
|
59 |
| -# @patch("sagemaker.serve.mode.in_process_mode.Path") |
60 |
| -# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
61 |
| -# @patch("sagemaker.session.Session") |
62 |
| -# def test_load_ex(self, mock_session, mock_inference_spec, mock_path): |
63 |
| -# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False |
64 |
| -# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True |
| 48 | + # in_process_mode = InProcessMode( |
| 49 | + # model_server=ModelServer.MMS, |
| 50 | + # inference_spec=mock_inference_spec, |
| 51 | + # schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), |
| 52 | + # session=mock_session, |
| 53 | + # model_path="model_path", |
| 54 | + # ) |
65 | 55 |
|
66 |
| -# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" |
| 56 | + # in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping |
67 | 57 |
|
68 |
| -# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) |
69 |
| -# in_process_mode = InProcessMode( |
70 |
| -# model_server=ModelServer.MMS, |
71 |
| -# inference_spec=mock_inference_spec, |
72 |
| -# schema_builder=mock_schema_builder, |
73 |
| -# session=mock_session, |
74 |
| -# model_path="model_path", |
75 |
| -# ) |
| 58 | + # in_process_mode.create_server(predictor=mock_predictor) |
76 | 59 |
|
77 |
| -# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") |
78 |
| - |
79 |
| -# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True |
80 |
| -# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False |
81 |
| - |
82 |
| -# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" |
83 |
| -# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) |
84 |
| -# in_process_mode = InProcessMode( |
85 |
| -# model_server=ModelServer.MMS, |
86 |
| -# inference_spec=mock_inference_spec, |
87 |
| -# schema_builder=mock_schema_builder, |
88 |
| -# session=mock_session, |
89 |
| -# model_path="model_path", |
90 |
| -# ) |
91 |
| - |
92 |
| -# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") |
93 |
| - |
94 |
| -# @patch("sagemaker.serve.mode.in_process_mode.logger") |
95 |
| -# @patch("sagemaker.base_predictor.PredictorBase") |
96 |
| -# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
97 |
| -# @patch("sagemaker.session.Session") |
98 |
| -# def test_create_server_happy( |
99 |
| -# self, mock_session, mock_inference_spec, mock_predictor, mock_logger |
100 |
| -# ): |
101 |
| -# mock_response = "Fake response" |
102 |
| -# mock_multi_model_server_deep_ping = Mock() |
103 |
| -# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( |
104 |
| -# True, |
105 |
| -# mock_response, |
106 |
| -# ) |
107 |
| - |
108 |
| -# in_process_mode = InProcessMode( |
109 |
| -# model_server=ModelServer.MMS, |
110 |
| -# inference_spec=mock_inference_spec, |
111 |
| -# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), |
112 |
| -# session=mock_session, |
113 |
| -# model_path="model_path", |
114 |
| -# ) |
115 |
| - |
116 |
| -# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping |
117 |
| - |
118 |
| -# in_process_mode.create_server(predictor=mock_predictor) |
119 |
| - |
120 |
| -# mock_logger.info.assert_called_once_with( |
121 |
| -# "Waiting for model server %s to start up...", ModelServer.MMS |
122 |
| -# ) |
123 |
| -# mock_logger.debug.assert_called_once_with( |
124 |
| -# "Ping health check has passed. Returned %s", str(mock_response) |
125 |
| -# ) |
126 |
| - |
127 |
| -# @patch("sagemaker.base_predictor.PredictorBase") |
128 |
| -# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") |
129 |
| -# @patch("sagemaker.session.Session") |
130 |
| -# def test_create_server_ex( |
131 |
| -# self, |
132 |
| -# mock_session, |
133 |
| -# mock_inference_spec, |
134 |
| -# mock_predictor, |
135 |
| -# ): |
136 |
| -# mock_multi_model_server_deep_ping = Mock() |
137 |
| -# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( |
138 |
| -# False, |
139 |
| -# None, |
140 |
| -# ) |
141 |
| - |
142 |
| -# in_process_mode = InProcessMode( |
143 |
| -# model_server=ModelServer.MMS, |
144 |
| -# inference_spec=mock_inference_spec, |
145 |
| -# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), |
146 |
| -# session=mock_session, |
147 |
| -# model_path="model_path", |
148 |
| -# ) |
149 |
| - |
150 |
| -# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping |
151 |
| - |
152 |
| -# self.assertRaises(LocalDeepPingException, in_process_mode.create_server, mock_predictor) |
| 60 | + mock_logger.info.assert_called_once_with("Conda env updated successfully") |
0 commit comments