diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 29a903e00b..5a4be3f53f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -372,10 +372,18 @@ def _get_json_file( object and None when reading from the local file system. """ if self._is_local_metadata_mode(): - file_content, etag = self._get_json_file_from_local_override(key, filetype), None - else: - file_content, etag = self._get_json_file_and_etag_from_s3(key) - return file_content, etag + if filetype in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + }: + return self._get_json_file_from_local_override(key, filetype), None + else: + JUMPSTART_LOGGER.warning( + "Local metadata mode is enabled, but the file type %s is not supported " + "for local override. Falling back to s3.", + filetype, + ) + return self._get_json_file_and_etag_from_s3(key) def _get_json_md5_hash(self, key: str): """Retrieves md5 object hash for s3 objects, using `s3.head_object`. diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 12eb30daaf..051cda0f4a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -54,9 +54,9 @@ from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, + JUMPSTART_MODEL_HUB_NAME, TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, - JUMPSTART_MODEL_HUB_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model @@ -634,10 +634,10 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE """Sets model uri in kwargs based on default or override, returns full kwargs.""" # hub_arn is by default None unless the user specifies the hub_name # If no hub_name is specified, it is assumed the public hub + # Training platform enforces that private hub models must use model channel is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False - if ( - _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)) - or is_private_hub + if is_private_hub or _model_supports_training_model_uri( + **get_model_info_default_kwargs(kwargs) ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0cd4bcc902..5b45b21bd8 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - # gated model never use training model artifact - if self.gated_bucket: + # old models with this environment variable present don't use model channel + if any( + self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( + instance_type + ) + for instance_type in self.supported_training_instance_types + ): + return False + + # even older models with training model package artifact uris present also don't use model channel + if len(self.training_model_package_artifact_uris or {}) > 0: return False - # otherwise, return true is a training model package is not set - return len(self.training_model_package_artifact_uris or {}) == 0 + return getattr(self, "training_artifact_key", None) is not None def is_gated_model(self) -> bool: """Returns True if the model has a EULA key or the model bucket is gated.""" diff --git a/tests/unit/sagemaker/jumpstart/factory/__init__.py b/tests/unit/sagemaker/jumpstart/factory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/factory/test_estimator.py b/tests/unit/sagemaker/jumpstart/factory/test_estimator.py new file mode 100644 index 0000000000..fd59961f09 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/factory/test_estimator.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from unittest.mock import patch +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.factory.estimator import ( + _add_model_uri_to_kwargs, + get_model_info_default_kwargs, +) +from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs +from sagemaker.jumpstart.enums import JumpStartScriptScope + + +class TestAddModelUriToKwargs: + @pytest.fixture + def mock_kwargs(self): + return JumpStartEstimatorInitKwargs( + model_id="test-model", + model_version="1.0.0", + instance_type="ml.m5.large", + model_uri=None, + ) + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_default_uri( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test adding default model URI when none is provided.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + mock_retrieve.return_value = default_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_retrieve.assert_called_once_with( + model_scope=JumpStartScriptScope.TRAINING, + instance_type=mock_kwargs.instance_type, + **get_model_info_default_kwargs(mock_kwargs), + ) + assert result.model_uri == default_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_incremental_training", + return_value=True, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_custom_uri_with_incremental( + self, mock_retrieve, mock_supports_incremental, mock_supports_training, mock_kwargs + ): + """Test using custom model URI with incremental training support.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + custom_uri = "s3://custom-bucket/my-model" + mock_retrieve.return_value = default_uri + mock_kwargs.model_uri = custom_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_supports_incremental.assert_called_once() + assert result.model_uri == custom_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_incremental_training", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + @patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") + def test_add_model_uri_to_kwargs_custom_uri_without_incremental( + self, + mock_warning, + mock_retrieve, + mock_supports_incremental, + mock_supports_training, + mock_kwargs, + ): + """Test using custom model URI without incremental training support logs warning.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + custom_uri = "s3://custom-bucket/my-model" + mock_retrieve.return_value = default_uri + mock_kwargs.model_uri = custom_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_supports_incremental.assert_called_once() + mock_warning.assert_called_once() + assert "does not support incremental training" in mock_warning.call_args[0][0] + assert result.model_uri == custom_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=False, + ) + def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_training, mock_kwargs): + """Test when model doesn't support training model URI.""" + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + assert result.model_uri is None + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_private_hub( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test when model is from a private hub.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + mock_retrieve.return_value = default_uri + mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub" + + result = _add_model_uri_to_kwargs(mock_kwargs) + + # Should not check if model supports training model URI for private hub + mock_supports_training.assert_not_called() + mock_retrieve.assert_called_once() + assert result.model_uri == default_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_public_hub( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test when model is from the public hub.""" + mock_kwargs.hub_arn = ( + f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}" + ) + + result = _add_model_uri_to_kwargs(mock_kwargs) + + # Should check if model supports training model URI for public hub + mock_supports_training.assert_called_once() + mock_retrieve.assert_not_called() + assert result.model_uri is None diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 17996f4f15..a652a11f4e 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -1288,3 +1288,78 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_func assert_key = JumpStartVersionedModelId("test-model", "abc") assert result == assert_key + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_from_s3(): + """Test _get_json_file retrieves from S3 in normal mode.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + test_etag = "test-etag-123" + + with patch.object( + JumpStartModelsCache, + "_get_json_file_and_etag_from_s3", + return_value=(test_json_data, test_etag), + ) as mock_s3_get: + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + + mock_s3_get.assert_called_once_with(test_key) + assert result == test_json_data + assert etag == test_etag + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_from_local_supported_type(): + """Test _get_json_file retrieves from local override for supported file types.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + + with ( + patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True), + patch.object( + JumpStartModelsCache, "_get_json_file_from_local_override", return_value=test_json_data + ) as mock_local_get, + ): + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + + mock_local_get.assert_called_once_with(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + assert result == test_json_data + assert etag is None + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_local_mode_unsupported_type(): + """Test _get_json_file falls back to S3 for unsupported file types in local mode.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + test_etag = "test-etag-123" + + with ( + patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True), + patch.object( + JumpStartModelsCache, + "_get_json_file_and_etag_from_s3", + return_value=(test_json_data, test_etag), + ) as mock_s3_get, + patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") as mock_warning, + ): + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.PROPRIETARY_MANIFEST) + + mock_s3_get.assert_called_once_with(test_key) + mock_warning.assert_called_once() + assert "not supported for local override" in mock_warning.call_args[0][0] + assert result == test_json_data + assert etag == test_etag diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 0b5ef63947..03a85fee44 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -39,6 +39,8 @@ INIT_KWARGS, ) +from unittest.mock import Mock + INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( { "regional_aliases": { @@ -329,14 +331,67 @@ def test_jumpstart_model_header(): assert header1 == header3 -def test_use_training_model_artifact(): - specs1 = JumpStartModelSpecs(BASE_SPEC) - assert specs1.use_training_model_artifact() - specs1.gated_bucket = True - assert not specs1.use_training_model_artifact() - specs1.gated_bucket = False - specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} - assert not specs1.use_training_model_artifact() +class TestUseTrainingModelArtifact: + @pytest.fixture + def mock_specs(self): + specs = Mock(spec=JumpStartModelSpecs) + specs.training_instance_type_variants = Mock() + specs.supported_training_instance_types = ["ml.p3.2xlarge", "ml.g4dn.xlarge"] + specs.training_model_package_artifact_uris = {} + specs.training_artifact_key = None + return specs + + def test_use_training_model_artifact_with_env_var(self, mock_specs): + """Test when instance type variants have env var values.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.side_effect = [ + "some-value", + None, + ] + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.assert_any_call( + "ml.p3.2xlarge" + ) + + def test_use_training_model_artifact_with_package_uris(self, mock_specs): + """Test when model has training package artifact URIs.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = { + "ml.p3.2xlarge": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/" + "llama2-13b-e155a2e0347b323fb882f1875851c5d3" + } + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False + + def test_use_training_model_artifact_with_artifact_key(self, mock_specs): + """Test when model has training artifact key.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = {} + mock_specs.training_artifact_key = "some-key" + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is True + + def test_use_training_model_artifact_without_artifact_key(self, mock_specs): + """Test when model has no training artifact key.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = {} + mock_specs.training_artifact_key = None + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False def test_jumpstart_model_specs():