@@ -186,3 +186,54 @@ def test_jumpstart_common_image_uri(
186
186
model_id = "pytorch-ic-mobilenet-v2" ,
187
187
instance_type = "ml.m5.xlarge" ,
188
188
)
189
+
190
+
191
+ @patch ("sagemaker.image_uris.JUMPSTART_LOGGER.info" )
192
+ @patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" )
193
+ @patch ("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs" )
194
+ @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
195
+ def test_jumpstart_image_uri_logging_extra_fields (
196
+ patched_get_model_specs ,
197
+ patched_verify_model_region_and_return_specs ,
198
+ patched_validate_model_id_and_get_type ,
199
+ patched_info_log ,
200
+ ):
201
+
202
+ patched_verify_model_region_and_return_specs .side_effect = verify_model_region_and_return_specs
203
+ patched_get_model_specs .side_effect = get_spec_from_base_spec
204
+ patched_validate_model_id_and_get_type .return_value = JumpStartModelType .OPEN_WEIGHTS
205
+
206
+ region = "us-west-2"
207
+ mock_client = boto3 .client ("s3" )
208
+ mock_session = Mock (s3_client = mock_client , boto_region_name = region )
209
+
210
+ image_uris .retrieve (
211
+ framework = None ,
212
+ region = "us-west-2" ,
213
+ image_scope = "training" ,
214
+ model_id = "pytorch-ic-mobilenet-v2" ,
215
+ model_version = "*" ,
216
+ instance_type = "ml.m5.xlarge" ,
217
+ sagemaker_session = mock_session ,
218
+ )
219
+
220
+ patched_info_log .assert_not_called ()
221
+
222
+ image_uris .retrieve (
223
+ framework = "framework" ,
224
+ container_version = "1.2.3" ,
225
+ region = "us-west-2" ,
226
+ image_scope = "training" ,
227
+ model_id = "pytorch-ic-mobilenet-v2" ,
228
+ model_version = "*" ,
229
+ instance_type = "ml.m5.xlarge" ,
230
+ sagemaker_session = mock_session ,
231
+ )
232
+
233
+ patched_info_log .assert_called_once_with (
234
+ "Ignoring the following fields "
235
+ "when retriving image uri for "
236
+ "JumpStart model id '%s': %s" ,
237
+ "pytorch-ic-mobilenet-v2" ,
238
+ "{'framework': 'framework', 'container_version': '1.2.3'}" ,
239
+ )
0 commit comments