diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 7d600ddfbc..39932be859 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -35,6 +35,7 @@ from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job +from sagemaker.jumpstart.types import JumpStartModelInternalConfig from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, resolve_model_sagemaker_config_field, @@ -607,8 +608,20 @@ def _validate_model_id_and_get_type_hook(): self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation - super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) + self._internal_config = JumpStartModelInternalConfig( + specs=verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + hub_arn=self.hub_arn, + model_type=self.model_type, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + tolerate_deprecated_model=self.tolerate_deprecated_model, + ) + ) def fit( self, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b482d4fefd..853de47e73 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -38,7 +38,7 @@ get_register_kwargs, ) from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.types import JumpStartModelInternalConfig, JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, @@ -361,10 +361,22 @@ def _validate_model_id_and_type(): self.log_subscription_warning() model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict() - super(JumpStartModel, self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn + self._internal_config = JumpStartModelInternalConfig( + specs=verify_model_region_and_return_specs( + region=self.region, + model_id=self.model_id, + version=self.model_version, + hub_arn=self.hub_arn, + model_type=self.model_type, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=self.sagemaker_session, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + tolerate_deprecated_model=self.tolerate_deprecated_model, + ) + ) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 171d9ce8a1..1e576526ef 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2482,3 +2482,17 @@ def __init__( self.source_uri = source_uri self.model_card = model_card self.accept_eula = accept_eula + + +class JumpStartModelInternalConfig(JumpStartDataHolderType): + """Data class for storing internal/private fields for JumpStart models.""" + + slots = ["specs"] + + def __init__(self, specs: JumpStartModelSpecs): + """Initializes a JumpStartModelInternalConfig object. + + Args: + specs (JumpStartModelSpecs): specs for model. + """ + self.specs = specs diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index 00c87fac1b..ae38cffa1e 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.types import JumpStartModelSpecs import tests from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, @@ -169,6 +170,9 @@ def test_gated_model_training_v2(setup): sagemaker_session=get_sm_session(), ) + assert isinstance(attached_estimator._internal_config.specs, JumpStartModelSpecs) + assert attached_estimator._internal_config.specs.model_id == model_id + # uses ml.g5.2xlarge instance predictor = attached_estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 6bc0a5c996..77acf20464 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -17,6 +17,7 @@ import pytest from sagemaker.enums import EndpointType +from sagemaker.jumpstart.types import JumpStartModelSpecs from sagemaker.predictor import retrieve_default import tests.integ @@ -223,6 +224,8 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): assert model.model_id == model_id assert model.endpoint_name == predictor.endpoint_name assert model.inference_component_name == predictor.component_name + assert isinstance(model._internal_config.specs, JumpStartModelSpecs) + assert model._internal_config.specs.model_id == model_id @mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 062209e3a0..e973da10f2 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1171,8 +1171,10 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.estimator.verify_model_region_and_return_specs") def test_validate_model_id_and_get_type( self, + verify_model_region_and_return_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 15c2c43bf0..2837b8465b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -758,8 +758,10 @@ def test_jumpstart_model_kwargs_match_parent_class(self): @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.model.verify_model_region_and_return_specs") def test_validate_model_id_and_get_type( self, + verify_model_region_and_return_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock,