5050mock_secret_key = "mock_secret_key"
5151mock_instance_type = "mock instance type"
5252
53- supported_model_server = {
53+ supported_model_servers = {
5454 ModelServer .TORCHSERVE ,
5555 ModelServer .TRITON ,
5656 ModelServer .DJL_SERVING ,
5757 ModelServer .TENSORFLOW_SERVING ,
58+ ModelServer .MMS ,
59+ ModelServer .TGI ,
60+ ModelServer .TEI ,
5861}
5962
6063mock_session = MagicMock ()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
7881 builder = ModelBuilder (inference_spec = "some value" , model = Mock (spec = object ))
7982 self .assertRaisesRegex (
8083 Exception ,
81- "Cannot have both the Model and Inference spec in the builder " ,
84+ "Can only set one of the following: model, inference_spec. " ,
8285 builder .build ,
8386 Mode .SAGEMAKER_ENDPOINT ,
8487 mock_role_arn ,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
9194 self .assertRaisesRegex (
9295 Exception ,
9396 "%s is not supported yet! Supported model servers: %s"
94- % (builder .model_server , supported_model_server ),
97+ % (builder .model_server , supported_model_servers ),
9598 builder .build ,
9699 Mode .SAGEMAKER_ENDPOINT ,
97100 mock_role_arn ,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
104107 self .assertRaisesRegex (
105108 Exception ,
106109 "Model_server must be set when non-first-party image_uri is set. "
107- + "Supported model servers: %s" % supported_model_server ,
110+ + "Supported model servers: %s" % supported_model_servers ,
108111 builder .build ,
109112 Mode .SAGEMAKER_ENDPOINT ,
110113 mock_role_arn ,
@@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
125128 mock_session ,
126129 )
127130
131+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
132+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl" )
133+ def test_model_server_override_djl_with_model (self , mock_build_for_djl , mock_serve_settings ):
134+ mock_setting_object = mock_serve_settings .return_value
135+ mock_setting_object .role_arn = mock_role_arn
136+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
137+
138+ builder = ModelBuilder (model_server = ModelServer .DJL_SERVING , model = "gpt_llm_burt" )
139+ builder .build (sagemaker_session = mock_session )
140+
141+ mock_build_for_djl .assert_called_once ()
142+
143+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
144+ def test_model_server_override_djl_without_model_or_mlflow (self , mock_serve_settings ):
145+ builder = ModelBuilder (
146+ model_server = ModelServer .DJL_SERVING , model = None , inference_spec = None
147+ )
148+ self .assertRaisesRegex (
149+ Exception ,
150+ "Missing required parameter `model` or 'ml_flow' path" ,
151+ builder .build ,
152+ Mode .SAGEMAKER_ENDPOINT ,
153+ mock_role_arn ,
154+ mock_session ,
155+ )
156+
157+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
158+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve" )
159+ def test_model_server_override_torchserve_with_model (
160+ self , mock_build_for_ts , mock_serve_settings
161+ ):
162+ mock_setting_object = mock_serve_settings .return_value
163+ mock_setting_object .role_arn = mock_role_arn
164+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
165+
166+ builder = ModelBuilder (model_server = ModelServer .TORCHSERVE , model = "gpt_llm_burt" )
167+ builder .build (sagemaker_session = mock_session )
168+
169+ mock_build_for_ts .assert_called_once ()
170+
171+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
172+ def test_model_server_override_torchserve_without_model_or_mlflow (self , mock_serve_settings ):
173+ builder = ModelBuilder (model_server = ModelServer .TORCHSERVE )
174+ self .assertRaisesRegex (
175+ Exception ,
176+ "Missing required parameter `model` or 'ml_flow' path" ,
177+ builder .build ,
178+ Mode .SAGEMAKER_ENDPOINT ,
179+ mock_role_arn ,
180+ mock_session ,
181+ )
182+
183+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
184+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton" )
185+ def test_model_server_override_triton_with_model (self , mock_build_for_ts , mock_serve_settings ):
186+ mock_setting_object = mock_serve_settings .return_value
187+ mock_setting_object .role_arn = mock_role_arn
188+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
189+
190+ builder = ModelBuilder (model_server = ModelServer .TRITON , model = "gpt_llm_burt" )
191+ builder .build (sagemaker_session = mock_session )
192+
193+ mock_build_for_ts .assert_called_once ()
194+
195+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
196+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving" )
197+ def test_model_server_override_tensor_with_model (self , mock_build_for_ts , mock_serve_settings ):
198+ mock_setting_object = mock_serve_settings .return_value
199+ mock_setting_object .role_arn = mock_role_arn
200+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
201+
202+ builder = ModelBuilder (model_server = ModelServer .TENSORFLOW_SERVING , model = "gpt_llm_burt" )
203+ builder .build (sagemaker_session = mock_session )
204+
205+ mock_build_for_ts .assert_called_once ()
206+
207+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
208+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei" )
209+ def test_model_server_override_tei_with_model (self , mock_build_for_ts , mock_serve_settings ):
210+ mock_setting_object = mock_serve_settings .return_value
211+ mock_setting_object .role_arn = mock_role_arn
212+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
213+
214+ builder = ModelBuilder (model_server = ModelServer .TEI , model = "gpt_llm_burt" )
215+ builder .build (sagemaker_session = mock_session )
216+
217+ mock_build_for_ts .assert_called_once ()
218+
219+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
220+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi" )
221+ def test_model_server_override_tgi_with_model (self , mock_build_for_ts , mock_serve_settings ):
222+ mock_setting_object = mock_serve_settings .return_value
223+ mock_setting_object .role_arn = mock_role_arn
224+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
225+
226+ builder = ModelBuilder (model_server = ModelServer .TGI , model = "gpt_llm_burt" )
227+ builder .build (sagemaker_session = mock_session )
228+
229+ mock_build_for_ts .assert_called_once ()
230+
231+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
232+ @patch ("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers" )
233+ def test_model_server_override_transformers_with_model (
234+ self , mock_build_for_ts , mock_serve_settings
235+ ):
236+ mock_setting_object = mock_serve_settings .return_value
237+ mock_setting_object .role_arn = mock_role_arn
238+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
239+
240+ builder = ModelBuilder (model_server = ModelServer .MMS , model = "gpt_llm_burt" )
241+ builder .build (sagemaker_session = mock_session )
242+
243+ mock_build_for_ts .assert_called_once ()
244+
128245 @patch ("os.makedirs" , Mock ())
129246 @patch ("sagemaker.serve.builder.model_builder._detect_framework_and_version" )
130247 @patch ("sagemaker.serve.builder.model_builder.prepare_for_torchserve" )
0 commit comments