Skip to content

Commit d97c261

Browse files
author
Bryannah Hernandez
committed
unit test
1 parent d158f95 commit d97c261

File tree

2 files changed

+154
-6
lines changed

2 files changed

+154
-6
lines changed

src/sagemaker/serve/builder/requirements_manager.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ def detect_file_exists(self, dependencies: str = None) -> str:
3636
"""
3737
dependencies = self._capture_from_local_runtime()
3838

39-
# No additional dependencies specified
40-
# if dependencies is None:
41-
# return None
42-
4339
# Dependencies specified as either req.txt or conda_env.yml
4440
if dependencies.endswith(".txt"):
4541
self._install_requirements_txt()
@@ -56,9 +52,9 @@ def _install_requirements_txt(self):
5652

5753
def _update_conda_env_in_path(self):
5854
"""Update conda env using conda yml file"""
59-
print("Updating conda env")
55+
logger.info("Updating conda env")
6056
subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
61-
print("Conda env updated successfully")
57+
logger.info("Conda env updated successfully")
6258

6359
def _get_active_conda_env_name(self) -> str:
6460
"""Returns the conda environment name from the set environment variable. None otherwise."""
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
# import unittest
16+
# from unittest.mock import patch, Mock
17+
18+
# from sagemaker.serve.mode.in_process_mode import InProcessMode
19+
# from sagemaker.serve import SchemaBuilder
20+
# from sagemaker.serve.utils.types import ModelServer
21+
# from sagemaker.serve.utils.exceptions import LocalDeepPingException
22+
23+
24+
# mock_prompt = "Hello, I'm a language model,"
25+
# mock_response = "Hello, I'm a language model, and I'm here to help you with your English."
26+
# mock_sample_input = {"inputs": mock_prompt, "parameters": {}}
27+
# mock_sample_output = [{"generated_text": mock_response}]
28+
29+
30+
# class TestRequirementsManager(unittest.TestCase):
31+
32+
# @patch("sagemaker.serve.mode.in_process_mode.Path")
33+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
34+
# @patch("sagemaker.session.Session")
35+
# def test_load_happy(self, mock_session, mock_inference_spec, mock_path):
36+
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
37+
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
38+
39+
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
40+
41+
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
42+
# in_process_mode = InProcessMode(
43+
# model_server=ModelServer.MMS,
44+
# inference_spec=mock_inference_spec,
45+
# schema_builder=mock_schema_builder,
46+
# session=mock_session,
47+
# model_path="model_path",
48+
# env_vars={"key": "val"},
49+
# )
50+
51+
# res = in_process_mode.load(model_path="/tmp/model-builder/code/")
52+
53+
# self.assertEqual(res, "Dummy load")
54+
# self.assertEqual(in_process_mode.inference_spec, mock_inference_spec)
55+
# self.assertEqual(in_process_mode.schema_builder, mock_schema_builder)
56+
# self.assertEqual(in_process_mode.model_path, "model_path")
57+
# self.assertEqual(in_process_mode.env_vars, {"key": "val"})
58+
59+
# @patch("sagemaker.serve.mode.in_process_mode.Path")
60+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
61+
# @patch("sagemaker.session.Session")
62+
# def test_load_ex(self, mock_session, mock_inference_spec, mock_path):
63+
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False
64+
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True
65+
66+
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
67+
68+
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
69+
# in_process_mode = InProcessMode(
70+
# model_server=ModelServer.MMS,
71+
# inference_spec=mock_inference_spec,
72+
# schema_builder=mock_schema_builder,
73+
# session=mock_session,
74+
# model_path="model_path",
75+
# )
76+
77+
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
78+
79+
# mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True
80+
# mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False
81+
82+
# mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load"
83+
# mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output)
84+
# in_process_mode = InProcessMode(
85+
# model_server=ModelServer.MMS,
86+
# inference_spec=mock_inference_spec,
87+
# schema_builder=mock_schema_builder,
88+
# session=mock_session,
89+
# model_path="model_path",
90+
# )
91+
92+
# self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/")
93+
94+
# @patch("sagemaker.serve.mode.in_process_mode.logger")
95+
# @patch("sagemaker.base_predictor.PredictorBase")
96+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
97+
# @patch("sagemaker.session.Session")
98+
# def test_create_server_happy(
99+
# self, mock_session, mock_inference_spec, mock_predictor, mock_logger
100+
# ):
101+
# mock_response = "Fake response"
102+
# mock_multi_model_server_deep_ping = Mock()
103+
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
104+
# True,
105+
# mock_response,
106+
# )
107+
108+
# in_process_mode = InProcessMode(
109+
# model_server=ModelServer.MMS,
110+
# inference_spec=mock_inference_spec,
111+
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
112+
# session=mock_session,
113+
# model_path="model_path",
114+
# )
115+
116+
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
117+
118+
# in_process_mode.create_server(predictor=mock_predictor)
119+
120+
# mock_logger.info.assert_called_once_with(
121+
# "Waiting for model server %s to start up...", ModelServer.MMS
122+
# )
123+
# mock_logger.debug.assert_called_once_with(
124+
# "Ping health check has passed. Returned %s", str(mock_response)
125+
# )
126+
127+
# @patch("sagemaker.base_predictor.PredictorBase")
128+
# @patch("sagemaker.serve.spec.inference_spec.InferenceSpec")
129+
# @patch("sagemaker.session.Session")
130+
# def test_create_server_ex(
131+
# self,
132+
# mock_session,
133+
# mock_inference_spec,
134+
# mock_predictor,
135+
# ):
136+
# mock_multi_model_server_deep_ping = Mock()
137+
# mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: (
138+
# False,
139+
# None,
140+
# )
141+
142+
# in_process_mode = InProcessMode(
143+
# model_server=ModelServer.MMS,
144+
# inference_spec=mock_inference_spec,
145+
# schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output),
146+
# session=mock_session,
147+
# model_path="model_path",
148+
# )
149+
150+
# in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
151+
152+
# self.assertRaises(LocalDeepPingException, in_process_mode.create_server, mock_predictor)

0 commit comments

Comments
 (0)