Skip to content

chore: add private field to js classes for model specs #4759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
13 changes: 13 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
from sagemaker.jumpstart.factory.model import get_default_predictor
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
from sagemaker.jumpstart.types import JumpStartInternalMetadata
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
resolve_model_sagemaker_config_field,
Expand Down Expand Up @@ -534,6 +535,18 @@ def _validate_model_id_and_get_type_hook():
if not self.model_type and not hub_arn:
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))

self._internal = JumpStartInternalMetadata(
specs=verify_model_region_and_return_specs(
region=self.region,
model_id=self.model_id,
version=self.model_version,
hub_arn=self.hub_arn,
model_type=self.model_type,
scope=JumpStartScriptScope.TRAINING,
sagemaker_session=self.sagemaker_session,
)
)

estimator_init_kwargs = get_init_kwargs(
model_id=model_id,
model_version=model_version,
Expand Down
14 changes: 13 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
get_register_kwargs,
)
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.types import JumpStartInternalMetadata, JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
verify_model_region_and_return_specs,
Expand Down Expand Up @@ -312,6 +312,18 @@ def _validate_model_id_and_type():
if not self.model_type and not hub_arn:
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))

self._internal = JumpStartInternalMetadata(
specs=verify_model_region_and_return_specs(
region=self.region,
model_id=self.model_id,
version=self.model_version,
hub_arn=self.hub_arn,
model_type=self.model_type,
scope=JumpStartScriptScope.INFERENCE,
sagemaker_session=self.sagemaker_session,
)
)

self._model_data_is_set = model_data is not None
model_init_kwargs = get_init_kwargs(
model_id=model_id,
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,3 +2482,17 @@ def __init__(
self.source_uri = source_uri
self.model_card = model_card
self.accept_eula = accept_eula


class JumpStartInternalMetadata(JumpStartDataHolderType):
"""Data class for storing internal/private fields for JumpStart models."""

slots = ["specs"]

def __init__(self, specs: JumpStartModelSpecs):
"""Initializes a JumpStartInternalMetadata object.

Args:
specs (JumpStartModelSpecs): specs for model.
"""
self.specs = specs