Skip to content

Commit 1a6e184

Browse files
committed
chore: create JumpStartInternalMetadata class
1 parent 490af37 commit 1a6e184

File tree

3 files changed

+36
-17
lines changed

3 files changed

+36
-17
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
3636
from sagemaker.jumpstart.factory.model import get_default_predictor
3737
from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
38+
from sagemaker.jumpstart.types import JumpStartInternalMetadata
3839
from sagemaker.jumpstart.utils import (
3940
validate_model_id_and_get_type,
4041
resolve_model_sagemaker_config_field,
@@ -534,14 +535,16 @@ def _validate_model_id_and_get_type_hook():
534535
if not self.model_type and not hub_arn:
535536
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
536537

537-
self._specs = verify_model_region_and_return_specs(
538-
region=self.region,
539-
model_id=self.model_id,
540-
version=self.model_version,
541-
hub_arn=self.hub_arn,
542-
model_type=self.model_type,
543-
scope=JumpStartScriptScope.TRAINING,
544-
sagemaker_session=self.sagemaker_session,
538+
self._internal = JumpStartInternalMetadata(
539+
specs=verify_model_region_and_return_specs(
540+
region=self.region,
541+
model_id=self.model_id,
542+
version=self.model_version,
543+
hub_arn=self.hub_arn,
544+
model_type=self.model_type,
545+
scope=JumpStartScriptScope.TRAINING,
546+
sagemaker_session=self.sagemaker_session,
547+
)
545548
)
546549

547550
estimator_init_kwargs = get_init_kwargs(

src/sagemaker/jumpstart/model.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
get_register_kwargs,
3939
)
4040
from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
41-
from sagemaker.jumpstart.types import JumpStartSerializablePayload
41+
from sagemaker.jumpstart.types import JumpStartInternalMetadata, JumpStartSerializablePayload
4242
from sagemaker.jumpstart.utils import (
4343
validate_model_id_and_get_type,
4444
verify_model_region_and_return_specs,
@@ -312,14 +312,16 @@ def _validate_model_id_and_type():
312312
if not self.model_type and not hub_arn:
313313
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
314314

315-
self._specs = verify_model_region_and_return_specs(
316-
region=self.region,
317-
model_id=self.model_id,
318-
version=self.model_version,
319-
hub_arn=self.hub_arn,
320-
model_type=self.model_type,
321-
scope=JumpStartScriptScope.INFERENCE,
322-
sagemaker_session=self.sagemaker_session,
315+
self._internal = JumpStartInternalMetadata(
316+
specs=verify_model_region_and_return_specs(
317+
region=self.region,
318+
model_id=self.model_id,
319+
version=self.model_version,
320+
hub_arn=self.hub_arn,
321+
model_type=self.model_type,
322+
scope=JumpStartScriptScope.INFERENCE,
323+
sagemaker_session=self.sagemaker_session,
324+
)
323325
)
324326

325327
self._model_data_is_set = model_data is not None

src/sagemaker/jumpstart/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,3 +2482,17 @@ def __init__(
24822482
self.source_uri = source_uri
24832483
self.model_card = model_card
24842484
self.accept_eula = accept_eula
2485+
2486+
2487+
class JumpStartInternalMetadata(JumpStartDataHolderType):
2488+
"""Data class for storing internal/private fields for JumpStart models."""
2489+
2490+
slots = ["specs"]
2491+
2492+
def __init__(self, specs: JumpStartModelSpecs):
2493+
"""Initializes a JumpStartInternalMetadata object.
2494+
2495+
Args:
2496+
specs (JumpStartModelSpecs): specs for model.
2497+
"""
2498+
self.specs = specs

0 commit comments

Comments
 (0)