diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 25804bba74..8ac813a6c5 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -293,7 +293,8 @@ def _model_id_retrieval_function( raise KeyError(error_msg) error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " - error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " + error_msg += "Specify a different model ID or try a different AWS Region. " + error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. " other_model_id_version = None if model_type == JumpStartModelType.OPEN_WEIGHTS: diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 742a6b8d3f..13994c2ed9 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -32,8 +32,8 @@ ) INVALID_MODEL_ID_ERROR_MSG = ( - "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " + "Invalid model ID: '{model_id}'. Specify a different model ID or try a different AWS Region. " + f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index c881bce482..4d784c8275 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -50,7 +50,12 @@ ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour +from sagemaker.utils import ( + resolve_value_from_config, + TagsDict, + get_instance_rate_per_hour, + get_domain_for_region, +) from sagemaker.workflow import is_pipeline_variable from sagemaker.user_agent import get_user_agent_extra_suffix @@ -553,7 +558,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: return ( f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." - f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}" + f"{get_domain_for_region(region)}" f"/{model_specs.hosting_eula_key} for terms of use." ) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index cc42896cf5..3f640bbe33 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -52,6 +52,15 @@ from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string from sagemaker.workflow.entities import PipelineVariable +ALTERNATE_DOMAINS = { + "cn-north-1": "amazonaws.com.cn", + "cn-northwest-1": "amazonaws.com.cn", + "us-iso-east-1": "c2s.ic.gov", + "us-isob-east-1": "sc2s.sgov.gov", + "us-isof-south-1": "csp.hci.ic.gov", + "us-isof-east-1": "csp.hci.ic.gov", +} + ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MODEL_PACKAGE_ARN_PATTERN = ( r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)" @@ -1905,3 +1914,12 @@ def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]: if len(updated_tags) == 1: return updated_tags[0] return updated_tags + + +def get_domain_for_region(region: str) -> str: + """Returns the domain for the given region. + + Args: + region (str): AWS region name. + """ + return ALTERNATE_DOMAINS.get(region, "amazonaws.com") diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 3cf900565b..01e4d4991f 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -12,14 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -ALTERNATE_DOMAINS = { - "cn-north-1": "amazonaws.com.cn", - "cn-northwest-1": "amazonaws.com.cn", - "us-iso-east-1": "c2s.ic.gov", - "us-isob-east-1": "sc2s.sgov.gov", - "us-isof-south-1": "csp.hci.ic.gov", - "us-isof-east-1": "csp.hci.ic.gov", -} +from sagemaker.utils import ALTERNATE_DOMAINS + DOMAIN = "amazonaws.com" IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}" MONITOR_URI_FORMAT = "{}.dkr.ecr.{}.{}/sagemaker-model-monitor-analyzer" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index c97e6ba895..da20debc6a 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -205,8 +205,11 @@ def test_jumpstart_cache_get_header(): ) assert ( "Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with " - "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-" + "version '3.*'. Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " + "Consider using model ID " + "'pytorch-ic-imagenet-inception-v3-" "classification-4' with version '2.0.0'." ) in str(e.value) @@ -214,8 +217,9 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="pytorch-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'pytorch-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?" ) in str(e.value) @@ -223,8 +227,9 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'tensorflow-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. For a list " + "of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'tensorflow-ic-imagenet-inception-" "v3-classification-4'?" ) in str(e.value) @@ -237,8 +242,9 @@ def test_jumpstart_cache_get_header(): ) assert ( "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'ai21-summarization'?" ) in str(e.value) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index cbf918dee8..fe2ba749cd 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -2150,3 +2150,21 @@ def test_has_instance_rate_stat(stats, expected): def test_deployment_config_response_data(data, expected): out = utils.deployment_config_response_data(data) assert out == expected + + +class TestGetEulaMessage(TestCase): + mock_model_specs = Mock(model_id="some-model-id", hosting_eula_key="some-eula-key") + + def test_get_domain_for_region(self): + self.assertEqual( + utils.get_eula_message(self.mock_model_specs, "us-west-2"), + "Model 'some-model-id' requires accepting end-user license agreement (EULA). See" + " https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/some-eula-key " + "for terms of use.", + ) + self.assertEqual( + utils.get_eula_message(self.mock_model_specs, "cn-north-1"), + "Model 'some-model-id' requires accepting end-user license agreement (EULA). See" + " https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key " + "for terms of use.", + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 3284d966e2..f243bf1635 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -38,6 +38,7 @@ camel_case_to_pascal_case, deep_override_dict, flatten_dict, + get_domain_for_region, get_instance_type_family, retry_with_backoff, check_and_get_run_experiment_config, @@ -2231,3 +2232,15 @@ def test_remove_non_existent_tag(self): def test_remove_only_tag(self): original_tags = [{"Key": "Tag1", "Value": "Value1"}] self.assertIsNone(remove_tag_with_key("Tag1", original_tags)) + + +class TestGetDomainForRegion(TestCase): + def test_get_domain_for_region(self): + self.assertEqual(get_domain_for_region("us-west-2"), "amazonaws.com") + self.assertEqual(get_domain_for_region("eu-west-1"), "amazonaws.com") + self.assertEqual(get_domain_for_region("ap-northeast-1"), "amazonaws.com") + self.assertEqual(get_domain_for_region("us-gov-west-1"), "amazonaws.com") + self.assertEqual(get_domain_for_region("cn-northwest-1"), "amazonaws.com.cn") + self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov") + self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov") + self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")