File tree Expand file tree Collapse file tree 3 files changed +38
-3
lines changed Expand file tree Collapse file tree 3 files changed +38
-3
lines changed Original file line number Diff line number Diff line change 11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import pytest
14
15
15
16
from mock .mock import patch
16
17
17
18
from sagemaker import image_uris
18
- import pytest
19
19
20
20
from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
21
21
from sagemaker .jumpstart import constants as sagemaker_constants
@@ -62,6 +62,20 @@ def test_jumpstart_script_uri(patched_get_model_specs):
62
62
sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
63
63
)
64
64
65
+ patched_get_model_specs .reset_mock ()
66
+
67
+ image_uris .retrieve (
68
+ framework = None ,
69
+ region = "us-west-2" ,
70
+ image_scope = "training" ,
71
+ model_id = "pytorch-ic-mobilenet-v2" ,
72
+ model_version = "1.*" ,
73
+ instance_type = "ml.p2.xlarge" ,
74
+ )
75
+ patched_get_model_specs .assert_called_once_with (
76
+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
77
+ )
78
+
65
79
with pytest .raises (ValueError ):
66
80
image_uris .retrieve (
67
81
framework = None ,
Original file line number Diff line number Diff line change 11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import pytest
14
15
15
16
from mock .mock import patch
16
17
17
18
from sagemaker import model_uris
18
19
from sagemaker .jumpstart import constants as sagemaker_constants
19
- import pytest
20
20
from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
21
21
from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
22
22
@@ -60,6 +60,16 @@ def test_jumpstart_model_uri(patched_get_model_specs):
60
60
patched_get_model_specs .assert_called_once_with (
61
61
sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
62
62
)
63
+ patched_get_model_specs .reset_mock ()
64
+
65
+ model_uris .retrieve (
66
+ model_scope = "training" ,
67
+ model_id = "pytorch-ic-mobilenet-v2" ,
68
+ model_version = "1.*" ,
69
+ )
70
+ patched_get_model_specs .assert_called_once_with (
71
+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
72
+ )
63
73
64
74
with pytest .raises (ValueError ):
65
75
model_uris .retrieve (
Original file line number Diff line number Diff line change 11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import pytest
14
15
15
16
from mock .mock import patch
16
17
17
18
from sagemaker import script_uris
18
- import pytest
19
19
20
20
from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
21
21
from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
@@ -62,6 +62,17 @@ def test_jumpstart_script_uri(patched_get_model_specs):
62
62
sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "*"
63
63
)
64
64
65
+ patched_get_model_specs .reset_mock ()
66
+
67
+ script_uris .retrieve (
68
+ script_scope = "training" ,
69
+ model_id = "pytorch-ic-mobilenet-v2" ,
70
+ model_version = "1.*" ,
71
+ )
72
+ patched_get_model_specs .assert_called_once_with (
73
+ sagemaker_constants .JUMPSTART_DEFAULT_REGION_NAME , "pytorch-ic-mobilenet-v2" , "1.*"
74
+ )
75
+
65
76
with pytest .raises (ValueError ):
66
77
script_uris .retrieve (
67
78
region = "us-west-2" ,
You can’t perform that action at this time.
0 commit comments