File tree Expand file tree Collapse file tree 3 files changed +38
-3
lines changed Expand file tree Collapse file tree 3 files changed +38
-3
lines changed Original file line number Diff line number Diff line change 1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ import pytest
1415
1516from mock .mock import patch
1617
1718from sagemaker import image_uris
18- import pytest
1919
2020from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
2121from sagemaker .jumpstart import constants as sagemaker_constants
@@ -62,6 +62,20 @@ def test_jumpstart_script_uri(patched_get_model_specs):
6262 sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
6363 )
6464
65+ patched_get_model_specs .reset_mock ()
66+
67+ image_uris .retrieve (
68+ framework = None ,
69+ region = "us-west-2" ,
70+ image_scope = "training" ,
71+ model_id = "pytorch-ic-mobilenet-v2" ,
72+ model_version = "1.*" ,
73+ instance_type = "ml.p2.xlarge" ,
74+ )
75+ patched_get_model_specs .assert_called_once_with (
76+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
77+ )
78+
6579 with pytest .raises (ValueError ):
6680 image_uris .retrieve (
6781 framework = None ,
Original file line number Diff line number Diff line change 1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ import pytest
1415
1516from mock .mock import patch
1617
1718from sagemaker import model_uris
1819from sagemaker .jumpstart import constants as sagemaker_constants
19- import pytest
2020from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
2121from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
2222
@@ -60,6 +60,16 @@ def test_jumpstart_model_uri(patched_get_model_specs):
6060 patched_get_model_specs .assert_called_once_with (
6161 sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
6262 )
63+ patched_get_model_specs .reset_mock ()
64+
65+ model_uris .retrieve (
66+ model_scope = "training" ,
67+ model_id = "pytorch-ic-mobilenet-v2" ,
68+ model_version = "1.*" ,
69+ )
70+ patched_get_model_specs .assert_called_once_with (
71+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
72+ )
6373
6474 with pytest .raises (ValueError ):
6575 model_uris .retrieve (
Original file line number Diff line number Diff line change 1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ import pytest
1415
1516from mock .mock import patch
1617
1718from sagemaker import script_uris
18- import pytest
1919
2020from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
2121from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
@@ -62,6 +62,17 @@ def test_jumpstart_script_uri(patched_get_model_specs):
6262 sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
6363 )
6464
65+ patched_get_model_specs .reset_mock ()
66+
67+ script_uris .retrieve (
68+ script_scope = "training" ,
69+ model_id = "pytorch-ic-mobilenet-v2" ,
70+ model_version = "1.*" ,
71+ )
72+ patched_get_model_specs .assert_called_once_with (
73+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
74+ )
75+
6576 with pytest .raises (ValueError ):
6677 script_uris .retrieve (
6778 region = "us-west-2" ,
You can’t perform that action at this time.
0 commit comments