Skip to content

Commit 6eb825c

Browse files
committed
feat: jumpstart retrieve functions (wip)
1 parent b691d3d commit 6eb825c

File tree

12 files changed

+345
-8
lines changed

12 files changed

+345
-8
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from mock.mock import Mock
2+
import pytest
3+
4+
REGION_NAME = "us-west-2"
5+
BUCKET_NAME = "some-bucket-name"
6+
7+
8+
@pytest.fixture(scope="module")
9+
def session():
10+
boto_mock = Mock(region_name=REGION_NAME)
11+
sms = Mock(
12+
boto_session=boto_mock,
13+
boto_region_name=REGION_NAME,
14+
config=None,
15+
)
16+
sms.default_bucket = Mock(return_value=BUCKET_NAME)
17+
return sms

tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_catboost_image_uri(patched_get_model_specs):
26+
def test_jumpstart_catboost_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -49,6 +49,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs):
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
5151
py_version=model_specs.hosting_ecr_specs.py_version,
52+
sagemaker_session=session,
5253
).serving_image_uri(region, instance_type)
5354

5455
assert uri == framework_class_uri
@@ -71,6 +72,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs):
7172
py_version=model_specs.training_ecr_specs.py_version,
7273
instance_type=instance_type,
7374
instance_count=1,
75+
sagemaker_session=session,
7476
).training_image_uri(region=region)
7577

7678
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_huggingface_image_uri(patched_get_model_specs):
26+
def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -48,6 +48,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs):
4848
transformers_version=model_specs.hosting_ecr_specs.huggingface_transformers_version,
4949
pytorch_version=model_specs.hosting_ecr_specs.framework_version,
5050
py_version=model_specs.hosting_ecr_specs.py_version,
51+
sagemaker_session=session,
5152
).serving_image_uri(region, instance_type)
5253

5354
assert uri == framework_class_uri
@@ -76,6 +77,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs):
7677
pytorch_version=model_specs.training_ecr_specs.framework_version,
7778
instance_type=instance_type,
7879
instance_count=1,
80+
sagemaker_session=session,
7981
).training_image_uri(region=region)
8082

8183
assert (

tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_lightgbm_image_uri(patched_get_model_specs):
26+
def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -49,6 +49,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs):
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
5151
py_version=model_specs.hosting_ecr_specs.py_version,
52+
sagemaker_session=session,
5253
).serving_image_uri(region, instance_type)
5354

5455
assert uri == framework_class_uri
@@ -71,6 +72,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs):
7172
py_version=model_specs.training_ecr_specs.py_version,
7273
instance_type=instance_type,
7374
instance_count=1,
75+
sagemaker_session=session,
7476
).training_image_uri(region=region)
7577

7678
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_mxnet_image_uri(patched_get_model_specs):
26+
def test_jumpstart_mxnet_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -49,6 +49,7 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs):
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
5151
py_version=model_specs.hosting_ecr_specs.py_version,
52+
sagemaker_session=session,
5253
).serving_image_uri(region, instance_type)
5354

5455
assert uri == framework_class_uri
@@ -71,6 +72,7 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs):
7172
py_version=model_specs.training_ecr_specs.py_version,
7273
instance_type=instance_type,
7374
instance_count=1,
75+
sagemaker_session=session,
7476
).training_image_uri(region=region)
7577

7678
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_pytorch_image_uri(patched_get_model_specs):
26+
def test_jumpstart_pytorch_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -49,6 +49,7 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs):
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
5151
py_version=model_specs.hosting_ecr_specs.py_version,
52+
sagemaker_session=session,
5253
).serving_image_uri(region, instance_type)
5354

5455
assert uri == framework_class_uri
@@ -71,6 +72,7 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs):
7172
py_version=model_specs.training_ecr_specs.py_version,
7273
instance_type=instance_type,
7374
instance_count=1,
75+
sagemaker_session=session,
7476
).training_image_uri(region=region)
7577

7678
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
27-
def test_jumpstart_sklearn_image_uri(patched_get_model_specs):
27+
def test_jumpstart_sklearn_image_uri(patched_get_model_specs, session):
2828

2929
patched_get_model_specs.side_effect = get_prototype_model_spec
3030

@@ -50,6 +50,7 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs):
5050
entry_point="mock_entry_point",
5151
framework_version=model_specs.hosting_ecr_specs.framework_version,
5252
py_version=model_specs.hosting_ecr_specs.py_version,
53+
sagemaker_session=session,
5354
).serving_image_uri(region, instance_type)
5455

5556
assert uri == framework_class_uri
@@ -75,6 +76,7 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs):
7576
instance_type=instance_type,
7677
instance_count=1,
7778
image_uri_region=region,
79+
sagemaker_session=session,
7880
).training_image_uri(region=region)
7981

8082
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_tensorflow_image_uri(patched_get_model_specs):
26+
def test_jumpstart_tensorflow_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -48,6 +48,7 @@ def test_jumpstart_tensorflow_image_uri(patched_get_model_specs):
4848
model_data="mock_data",
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
51+
sagemaker_session=session,
5152
).serving_image_uri(region, instance_type)
5253

5354
assert uri == framework_class_uri
@@ -70,6 +71,7 @@ def test_jumpstart_tensorflow_image_uri(patched_get_model_specs):
7071
py_version=model_specs.training_ecr_specs.py_version,
7172
instance_type=instance_type,
7273
instance_count=1,
74+
sagemaker_session=session,
7375
).training_image_uri(region=region)
7476

7577
assert uri == framework_class_uri

tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26-
def test_jumpstart_xgboost_image_uri(patched_get_model_specs):
26+
def test_jumpstart_xgboost_image_uri(patched_get_model_specs, session):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
2929

@@ -49,6 +49,7 @@ def test_jumpstart_xgboost_image_uri(patched_get_model_specs):
4949
entry_point="mock_entry_point",
5050
framework_version=model_specs.hosting_ecr_specs.framework_version,
5151
py_version=model_specs.hosting_ecr_specs.py_version,
52+
sagemaker_session=session,
5253
).serving_image_uri(region, instance_type)
5354

5455
assert uri == framework_class_uri
@@ -72,6 +73,7 @@ def test_jumpstart_xgboost_image_uri(patched_get_model_specs):
7273
instance_type=instance_type,
7374
instance_count=1,
7475
image_uri_region=region,
76+
sagemaker_session=session,
7577
).training_image_uri(region=region)
7678

7779
assert uri == framework_class_uri
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import image_uris
18+
import pytest
19+
20+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
21+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
22+
from sagemaker.jumpstart import constants as sagemaker_constants
23+
24+
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
26+
def test_jumpstart_script_uri(patched_get_model_specs):
27+
28+
patched_get_model_specs.side_effect = get_spec_from_base_spec
29+
uri = image_uris.retrieve(
30+
framework=None,
31+
region="us-west-2",
32+
image_scope="inference",
33+
model_id="pytorch-ic-mobilenet-v2",
34+
model_version="*",
35+
instance_type="ml.p2.xlarge",
36+
)
37+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3"
38+
patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "*")
39+
40+
patched_get_model_specs.reset_mock()
41+
42+
uri = image_uris.retrieve(
43+
framework=None,
44+
region="us-west-2",
45+
image_scope="training",
46+
model_id="pytorch-ic-mobilenet-v2",
47+
model_version="*",
48+
instance_type="ml.p2.xlarge",
49+
)
50+
assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3"
51+
patched_get_model_specs.assert_called_once_with("us-west-2", "pytorch-ic-mobilenet-v2", "*")
52+
patched_get_model_specs.reset_mock()
53+
54+
image_uris.retrieve(
55+
framework=None,
56+
region="us-west-2",
57+
image_scope="training",
58+
model_id="pytorch-ic-mobilenet-v2",
59+
model_version="*",
60+
instance_type="ml.p2.xlarge",
61+
)
62+
patched_get_model_specs.assert_called_once_with(
63+
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*"
64+
)
65+
66+
with pytest.raises(ValueError) as e:
67+
image_uris.retrieve(
68+
framework=None,
69+
region="us-west-2",
70+
image_scope="BAD_SCOPE",
71+
model_id="pytorch-ic-mobilenet-v2",
72+
model_version="*",
73+
instance_type="ml.p2.xlarge",
74+
)
75+
76+
with pytest.raises(ValueError) as e:
77+
image_uris.retrieve(
78+
framework=None,
79+
region="mars-south-1",
80+
image_scope="training",
81+
model_id="pytorch-ic-mobilenet-v2",
82+
model_version="*",
83+
instance_type="ml.p2.xlarge",
84+
)
85+
86+
with pytest.raises(ValueError) as e:
87+
image_uris.retrieve(
88+
framework=None,
89+
region="us-west-2",
90+
model_id="pytorch-ic-mobilenet-v2",
91+
model_version="*",
92+
instance_type="ml.p2.xlarge",
93+
)
94+
95+
with pytest.raises(ValueError) as e:
96+
image_uris.retrieve(
97+
framework=None,
98+
region="us-west-2",
99+
image_scope="training",
100+
model_version="*",
101+
instance_type="ml.p2.xlarge",
102+
)
103+
104+
with pytest.raises(ValueError) as e:
105+
image_uris.retrieve(
106+
region="us-west-2",
107+
framework=None,
108+
image_scope="training",
109+
model_id="pytorch-ic-mobilenet-v2",
110+
instance_type="ml.p2.xlarge",
111+
)

0 commit comments

Comments
 (0)