Skip to content

Commit 182a321

Browse files
committed
change: improve jumpstart retrieve fx impl, cleanup tests, comments, and code
1 parent cfc0df3 commit 182a321

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

tests/unit/sagemaker/image_uris/test_jumpstart.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import pytest
1415

1516
from mock.mock import patch
1617

1718
from sagemaker import image_uris
18-
import pytest
1919

2020
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2121
from sagemaker.jumpstart import constants as sagemaker_constants
@@ -62,6 +62,20 @@ def test_jumpstart_script_uri(patched_get_model_specs):
6262
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*"
6363
)
6464

65+
patched_get_model_specs.reset_mock()
66+
67+
image_uris.retrieve(
68+
framework=None,
69+
region="us-west-2",
70+
image_scope="training",
71+
model_id="pytorch-ic-mobilenet-v2",
72+
model_version="1.*",
73+
instance_type="ml.p2.xlarge",
74+
)
75+
patched_get_model_specs.assert_called_once_with(
76+
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*"
77+
)
78+
6579
with pytest.raises(ValueError):
6680
image_uris.retrieve(
6781
framework=None,

tests/unit/sagemaker/test_model_uris.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import pytest
1415

1516
from mock.mock import patch
1617

1718
from sagemaker import model_uris
1819
from sagemaker.jumpstart import constants as sagemaker_constants
19-
import pytest
2020
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2121
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2222

@@ -60,6 +60,16 @@ def test_jumpstart_model_uri(patched_get_model_specs):
6060
patched_get_model_specs.assert_called_once_with(
6161
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*"
6262
)
63+
patched_get_model_specs.reset_mock()
64+
65+
model_uris.retrieve(
66+
model_scope="training",
67+
model_id="pytorch-ic-mobilenet-v2",
68+
model_version="1.*",
69+
)
70+
patched_get_model_specs.assert_called_once_with(
71+
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*"
72+
)
6373

6474
with pytest.raises(ValueError):
6575
model_uris.retrieve(

tests/unit/sagemaker/test_script_uris.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import pytest
1415

1516
from mock.mock import patch
1617

1718
from sagemaker import script_uris
18-
import pytest
1919

2020
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2121
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
@@ -62,6 +62,17 @@ def test_jumpstart_script_uri(patched_get_model_specs):
6262
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "*"
6363
)
6464

65+
patched_get_model_specs.reset_mock()
66+
67+
script_uris.retrieve(
68+
script_scope="training",
69+
model_id="pytorch-ic-mobilenet-v2",
70+
model_version="1.*",
71+
)
72+
patched_get_model_specs.assert_called_once_with(
73+
sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, "pytorch-ic-mobilenet-v2", "1.*"
74+
)
75+
6576
with pytest.raises(ValueError):
6677
script_uris.retrieve(
6778
region="us-west-2",

0 commit comments

Comments
 (0)