Skip to content

chore: add private field to js classes for model specs #4759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
15 changes: 14 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
from sagemaker.jumpstart.types import JumpStartModelInternalConfig
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
resolve_model_sagemaker_config_field,
Expand Down Expand Up @@ -607,8 +608,20 @@ def _validate_model_id_and_get_type_hook():
self.role = estimator_init_kwargs.role
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation

super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
self._internal_config = JumpStartModelInternalConfig(
specs=verify_model_region_and_return_specs(
region=self.region,
model_id=self.model_id,
version=self.model_version,
hub_arn=self.hub_arn,
model_type=self.model_type,
scope=JumpStartScriptScope.TRAINING,
sagemaker_session=self.sagemaker_session,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
tolerate_deprecated_model=self.tolerate_deprecated_model,
)
)

def fit(
self,
Expand Down
16 changes: 14 additions & 2 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
get_register_kwargs,
)
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.types import JumpStartModelInternalConfig, JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
verify_model_region_and_return_specs,
Expand Down Expand Up @@ -361,10 +361,22 @@ def _validate_model_id_and_type():
self.log_subscription_warning()

model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict()

super(JumpStartModel, self).__init__(**model_init_kwargs_dict)

self.model_package_arn = model_init_kwargs.model_package_arn
self._internal_config = JumpStartModelInternalConfig(
specs=verify_model_region_and_return_specs(
region=self.region,
model_id=self.model_id,
version=self.model_version,
hub_arn=self.hub_arn,
model_type=self.model_type,
scope=JumpStartScriptScope.INFERENCE,
sagemaker_session=self.sagemaker_session,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
tolerate_deprecated_model=self.tolerate_deprecated_model,
)
)

def log_subscription_warning(self) -> None:
"""Log message prompting the customer to subscribe to the proprietary model."""
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,3 +2482,17 @@ def __init__(
self.source_uri = source_uri
self.model_card = model_card
self.accept_eula = accept_eula


class JumpStartModelInternalConfig(JumpStartDataHolderType):
"""Data class for storing internal/private fields for JumpStart models."""

slots = ["specs"]

def __init__(self, specs: JumpStartModelSpecs):
"""Initializes a JumpStartModelInternalConfig object.

Args:
specs (JumpStartModelSpecs): specs for model.
"""
self.specs = specs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME

from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.types import JumpStartModelSpecs
import tests
from tests.integ.sagemaker.jumpstart.constants import (
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
Expand Down Expand Up @@ -169,6 +170,9 @@ def test_gated_model_training_v2(setup):
sagemaker_session=get_sm_session(),
)

assert isinstance(attached_estimator._internal_config.specs, JumpStartModelSpecs)
assert attached_estimator._internal_config.specs.model_id == model_id

# uses ml.g5.2xlarge instance
predictor = attached_estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytest
from sagemaker.enums import EndpointType
from sagemaker.jumpstart.types import JumpStartModelSpecs
from sagemaker.predictor import retrieve_default

import tests.integ
Expand Down Expand Up @@ -223,6 +224,8 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
assert model.model_id == model_id
assert model.endpoint_name == predictor.endpoint_name
assert model.inference_component_name == predictor.component_name
assert isinstance(model._internal_config.specs, JumpStartModelSpecs)
assert model._internal_config.specs.model_id == model_id


@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,8 +1171,10 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
@mock.patch("sagemaker.jumpstart.estimator.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.estimator.verify_model_region_and_return_specs")
def test_validate_model_id_and_get_type(
self,
verify_model_region_and_return_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_init: mock.Mock,
mock_get_init_kwargs: mock.Mock,
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,10 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.verify_model_region_and_return_specs")
def test_validate_model_id_and_get_type(
self,
verify_model_region_and_return_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_init: mock.Mock,
mock_get_init_kwargs: mock.Mock,
Expand Down
Loading