2929REGION = "us-west-2"
3030SCRIPT_PATH = "script.py"
3131TIMESTAMP = "2017-10-10-14-14-15"
32+ ECR_PREFIX_FORMAT = "{}.dkr.ecr.mars-south-3.amazonaws.com"
3233
34+ MOCK_ACCOUNT = "520713654638"
3335MOCK_FRAMEWORK = "mlfw"
3436MOCK_REGION = "mars-south-3"
3537MOCK_ACCELERATOR = "eia1.medium"
@@ -165,7 +167,9 @@ def sagemaker_session():
165167 return session_mock
166168
167169
168- def test_create_image_uri_cpu ():
170+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" )
171+ def test_create_image_uri_cpu (ecr_prefix ):
172+ ecr_prefix .return_value = ECR_PREFIX_FORMAT .format ("23" )
169173 image_uri = fw_utils .create_image_uri (
170174 MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2" , "23"
171175 )
@@ -176,20 +180,23 @@ def test_create_image_uri_cpu():
176180 )
177181 assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
178182
183+ ecr_prefix .return_value = "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com"
179184 image_uri = fw_utils .create_image_uri (
180185 "us-gov-west-1" , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2"
181186 )
182187 assert (
183188 image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
184189 )
185190
191+ ecr_prefix .return_value = "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov"
186192 image_uri = fw_utils .create_image_uri (
187193 "us-iso-east-1" , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2"
188194 )
189195 assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2"
190196
191197
192- def test_create_image_uri_no_python ():
198+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
199+ def test_create_image_uri_no_python (ecr_prefix ):
193200 image_uri = fw_utils .create_image_uri (
194201 MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , account = "23"
195202 )
@@ -201,7 +208,8 @@ def test_create_image_uri_bad_python():
201208 fw_utils .create_image_uri (MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py0" )
202209
203210
204- def test_create_image_uri_gpu ():
211+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
212+ def test_create_image_uri_gpu (ecr_prefix ):
205213 image_uri = fw_utils .create_image_uri (
206214 MOCK_REGION , MOCK_FRAMEWORK , "ml.p3.2xlarge" , "1.0rc" , "py3" , "23"
207215 )
@@ -213,7 +221,8 @@ def test_create_image_uri_gpu():
213221 assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
214222
215223
216- def test_create_image_uri_accelerator_tfs ():
224+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
225+ def test_create_image_uri_accelerator_tfs (ecr_prefix ):
217226 image_uri = fw_utils .create_image_uri (
218227 MOCK_REGION ,
219228 "tensorflow-serving" ,
@@ -228,7 +237,11 @@ def test_create_image_uri_accelerator_tfs():
228237 )
229238
230239
231- def test_create_image_uri_default_account ():
240+ @patch (
241+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
242+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
243+ )
244+ def test_create_image_uri_default_account (ecr_prefix ):
232245 image_uri = fw_utils .create_image_uri (
233246 MOCK_REGION , MOCK_FRAMEWORK , "ml.p3.2xlarge" , "1.0rc" , "py3"
234247 )
@@ -511,7 +524,11 @@ def test_create_image_uri_tensorflow(tf_version):
511524 )
512525
513526
514- def test_create_image_uri_accelerator_tf ():
527+ @patch (
528+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
529+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
530+ )
531+ def test_create_image_uri_accelerator_tf (ecr_prefix ):
515532 image_uri = fw_utils .create_image_uri (
516533 MOCK_REGION , "tensorflow" , "ml.p3.2xlarge" , "1.0" , "py3" , accelerator_type = "ml.eia1.medium"
517534 )
@@ -521,7 +538,11 @@ def test_create_image_uri_accelerator_tf():
521538 )
522539
523540
524- def test_create_image_uri_accelerator_mxnet_serving ():
541+ @patch (
542+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
543+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
544+ )
545+ def test_create_image_uri_accelerator_mxnet_serving (ecr_prefix ):
525546 image_uri = fw_utils .create_image_uri (
526547 MOCK_REGION ,
527548 "mxnet-serving" ,
@@ -536,7 +557,11 @@ def test_create_image_uri_accelerator_mxnet_serving():
536557 )
537558
538559
539- def test_create_image_uri_local_sagemaker_notebook_accelerator ():
560+ @patch (
561+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
562+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
563+ )
564+ def test_create_image_uri_local_sagemaker_notebook_accelerator (ecr_prefix ):
540565 image_uri = fw_utils .create_image_uri (
541566 MOCK_REGION ,
542567 "mxnet" ,
@@ -608,7 +633,11 @@ def test_invalid_instance_type():
608633 fw_utils .create_image_uri (MOCK_REGION , MOCK_FRAMEWORK , "p3.2xlarge" , "1.0.0" , "py3" )
609634
610635
611- def test_optimized_family ():
636+ @patch (
637+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
638+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
639+ )
640+ def test_optimized_family (ecr_prefix ):
612641 image_uri = fw_utils .create_image_uri (
613642 MOCK_REGION ,
614643 MOCK_FRAMEWORK ,
@@ -622,7 +651,11 @@ def test_optimized_family():
622651 )
623652
624653
625- def test_unoptimized_cpu_family ():
654+ @patch (
655+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
656+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
657+ )
658+ def test_unoptimized_cpu_family (ecr_prefix ):
626659 image_uri = fw_utils .create_image_uri (
627660 MOCK_REGION , MOCK_FRAMEWORK , "ml.m4.xlarge" , "1.0.0" , "py3" , optimized_families = ["c5" , "p3" ]
628661 )
@@ -631,7 +664,11 @@ def test_unoptimized_cpu_family():
631664 )
632665
633666
634- def test_unoptimized_gpu_family ():
667+ @patch (
668+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
669+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
670+ )
671+ def test_unoptimized_gpu_family (ecr_prefix ):
635672 image_uri = fw_utils .create_image_uri (
636673 MOCK_REGION , MOCK_FRAMEWORK , "ml.p2.xlarge" , "1.0.0" , "py3" , optimized_families = ["c5" , "p3" ]
637674 )
0 commit comments