Skip to content

Commit 10dba2c

Browse files
author
Malav Shastri
committed
fix: fix _add_tags_to_kwargs to use hub_content_arn instead of hub_arn
1 parent 3fe2774 commit 10dba2c

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_retrieve_model_package_model_artifact_s3_uri,
3030
)
3131
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
32+
from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs
3233
from sagemaker.session import Session
3334
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
3435
from sagemaker.base_deserializers import BaseDeserializer
@@ -52,6 +53,7 @@
5253
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
5354
from sagemaker.jumpstart.factory import model
5455
from sagemaker.jumpstart.types import (
56+
HubContentType,
5557
JumpStartEstimatorDeployKwargs,
5658
JumpStartEstimatorFitKwargs,
5759
JumpStartEstimatorInitKwargs,
@@ -201,6 +203,11 @@ def get_init_kwargs(
201203
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
202204
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
203205
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
206+
if hub_arn:
207+
estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs)
208+
else:
209+
estimator_init_kwargs.model_reference_arn = None
210+
estimator_init_kwargs.hub_content_type = None
204211
estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs)
205212
estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs)
206213
estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs)
@@ -511,7 +518,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
511518
)
512519

513520
if kwargs.hub_arn:
514-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
521+
if kwargs.model_reference_arn:
522+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
523+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
524+
)
525+
else:
526+
hub_content_arn = construct_hub_model_arn_from_inputs(
527+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
528+
)
529+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
515530

516531
return kwargs
517532

@@ -534,6 +549,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
534549

535550
return kwargs
536551

552+
def _add_model_reference_arn_to_kwargs(
553+
kwargs: JumpStartEstimatorInitKwargs,
554+
) -> JumpStartEstimatorInitKwargs:
555+
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
556+
557+
hub_content_type = verify_model_region_and_return_specs(
558+
model_id=kwargs.model_id,
559+
version=kwargs.model_version,
560+
hub_arn=kwargs.hub_arn,
561+
scope=JumpStartScriptScope.TRAINING,
562+
region=kwargs.region,
563+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
564+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
565+
sagemaker_session=kwargs.sagemaker_session,
566+
model_type=kwargs.model_type,
567+
).hub_content_type
568+
kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None
569+
570+
if hub_content_type == HubContentType.MODEL_REFERENCE:
571+
kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
572+
hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
573+
)
574+
else:
575+
kwargs.model_reference_arn = None
576+
return kwargs
577+
537578

538579
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
539580
"""Sets model uri in kwargs based on default or override, returns full kwargs."""

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def _add_model_reference_arn_to_kwargs(
268268
kwargs: JumpStartModelInitKwargs,
269269
) -> JumpStartModelInitKwargs:
270270
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
271+
271272
hub_content_type = verify_model_region_and_return_specs(
272273
model_id=kwargs.model_id,
273274
version=kwargs.model_version,

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,6 +2026,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
20262026
"enable_infra_check",
20272027
"enable_remote_debug",
20282028
"enable_session_tag_chaining",
2029+
"hub_content_type",
2030+
"model_reference_arn",
20292031
]
20302032

20312033
SERIALIZATION_EXCLUSION_SET = {
@@ -2036,6 +2038,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
20362038
"model_version",
20372039
"hub_arn",
20382040
"model_type",
2041+
"hub_content_type",
20392042
}
20402043

20412044
def __init__(

0 commit comments

Comments
 (0)