Skip to content

Commit 9eb82a4

Browse files
committed
change: improve jumpstart retrieve uri unit tests, fix logic for image uris
1 parent 182a321 commit 9eb82a4

18 files changed

+692
-339
lines changed

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2525
def test_jumpstart_common_image_uri(patched_get_model_specs):
26-
2726
patched_get_model_specs.side_effect = get_spec_from_base_spec
2827

2928
image_uris.retrieve(

tests/unit/sagemaker/image_uris/test_jumpstart.py

Lines changed: 0 additions & 124 deletions
This file was deleted.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
23+
def test_jumpstart_catboost_model_uri(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
27+
# inference
28+
uri = model_uris.retrieve(
29+
region="us-west-2",
30+
model_scope="inference",
31+
model_id="catboost-classification-model",
32+
model_version="*",
33+
)
34+
assert (
35+
uri == "s3://jumpstart-cache-prod-us-west-2/catboost-infer/"
36+
"infer-catboost-classification-model.tar.gz"
37+
)
38+
39+
# training
40+
uri = model_uris.retrieve(
41+
region="us-west-2",
42+
model_scope="training",
43+
model_id="catboost-classification-model",
44+
model_version="*",
45+
)
46+
assert (
47+
uri == "s3://jumpstart-cache-prod-us-west-2/catboost-training/"
48+
"train-catboost-classification-model.tar.gz"
49+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
23+
def test_jumpstart_huggingface_model_uri(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
27+
# inference
28+
uri = model_uris.retrieve(
29+
region="us-west-2",
30+
model_scope="inference",
31+
model_id="huggingface-spc-bert-base-cased",
32+
model_version="*",
33+
)
34+
assert (
35+
uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
36+
"infer-huggingface-spc-bert-base-cased.tar.gz"
37+
)
38+
39+
# training
40+
uri = model_uris.retrieve(
41+
region="us-west-2",
42+
model_scope="training",
43+
model_id="huggingface-spc-bert-base-cased",
44+
model_version="*",
45+
)
46+
assert (
47+
uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-training/"
48+
"train-huggingface-spc-bert-base-cased.tar.gz"
49+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
23+
def test_jumpstart_lightgbm_model_uri(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
27+
# inference
28+
uri = model_uris.retrieve(
29+
region="us-west-2",
30+
model_scope="inference",
31+
model_id="lightgbm-classification-model",
32+
model_version="*",
33+
)
34+
assert (
35+
uri == "s3://jumpstart-cache-prod-us-west-2/lightgbm-infer/"
36+
"infer-lightgbm-classification-model.tar.gz"
37+
)
38+
39+
# training
40+
uri = model_uris.retrieve(
41+
region="us-west-2",
42+
model_scope="training",
43+
model_id="lightgbm-classification-model",
44+
model_version="*",
45+
)
46+
assert (
47+
uri == "s3://jumpstart-cache-prod-us-west-2/lightgbm-training/"
48+
"train-lightgbm-classification-model.tar.gz"
49+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
23+
def test_jumpstart_mxnet_model_uri(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
27+
# inference
28+
uri = model_uris.retrieve(
29+
region="us-west-2",
30+
model_scope="inference",
31+
model_id="mxnet-semseg-fcn-resnet50-ade",
32+
model_version="*",
33+
)
34+
assert (
35+
uri == "s3://jumpstart-cache-prod-us-west-2/mxnet-infer/"
36+
"infer-mxnet-semseg-fcn-resnet50-ade.tar.gz"
37+
)
38+
39+
# training
40+
uri = model_uris.retrieve(
41+
region="us-west-2",
42+
model_scope="training",
43+
model_id="mxnet-semseg-fcn-resnet50-ade",
44+
model_version="*",
45+
)
46+
assert (
47+
uri == "s3://jumpstart-cache-prod-us-west-2/mxnet-training/"
48+
"train-mxnet-semseg-fcn-resnet50-ade.tar.gz"
49+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
19+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
21+
22+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
23+
def test_jumpstart_pytorch_model_uri(patched_get_model_specs):
24+
25+
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
27+
# inference
28+
uri = model_uris.retrieve(
29+
region="us-west-2",
30+
model_scope="inference",
31+
model_id="pytorch-eqa-bert-base-cased",
32+
model_version="*",
33+
)
34+
assert (
35+
uri == "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/"
36+
"infer-pytorch-eqa-bert-base-cased.tar.gz"
37+
)
38+
39+
# training
40+
uri = model_uris.retrieve(
41+
region="us-west-2",
42+
model_scope="training",
43+
model_id="pytorch-eqa-bert-base-cased",
44+
model_version="*",
45+
)
46+
assert (
47+
uri == "s3://jumpstart-cache-prod-us-west-2/pytorch-training/"
48+
"train-pytorch-eqa-bert-base-cased.tar.gz"
49+
)

0 commit comments

Comments
 (0)