Skip to content

Commit 748ea4b

Browse files
author
Joseph Zhang
committed
Use correct bucket for SM/JS draft models and minor formatting/validation updates.
1 parent b7b15b8 commit 748ea4b

File tree

6 files changed

+52
-33
lines changed

6 files changed

+52
-33
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@
5454
add_hub_content_arn_tags,
5555
add_jumpstart_model_info_tags,
5656
get_default_jumpstart_session_with_user_agent_suffix,
57-
get_neo_content_bucket,
5857
get_top_ranked_config_name,
5958
update_dict_if_key_not_present,
6059
resolve_model_sagemaker_config_field,
6160
verify_model_region_and_return_specs,
61+
get_draft_model_content_bucket,
6262
)
6363

6464
from sagemaker.jumpstart.factory.utils import (
@@ -76,7 +76,6 @@
7676
name_from_base,
7777
format_tags,
7878
Tags,
79-
get_domain_for_region,
8079
)
8180
from sagemaker.workflow.entities import PipelineVariable
8281
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
@@ -572,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs(
572571
# Append speculative decoding data source from metadata
573572
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
574573
for data_source in speculative_decoding_data_sources:
575-
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
574+
data_source.s3_data_source.set_bucket(
575+
get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
576+
)
576577
api_shape_additional_model_data_sources = (
577578
[
578579
camel_case_to_pascal_case(data_source.to_json())

src/sagemaker/jumpstart/model.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
458458
sagemaker_session=self.sagemaker_session,
459459
)
460460

461-
def set_deployment_config(
462-
self, config_name: str, instance_type: str
463-
) -> None:
461+
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
464462
"""Sets the deployment config to apply to the model.
465463
466464
Args:
@@ -479,7 +477,7 @@ def set_deployment_config(
479477
instance_type=instance_type,
480478
config_name=config_name,
481479
sagemaker_session=self.sagemaker_session,
482-
role=self.role
480+
role=self.role,
483481
)
484482

485483
@property
@@ -766,11 +764,11 @@ def deploy(
766764
(Default: EndpointType.MODEL_BASED).
767765
routing_config (Optional[Dict]): Settings the control how the endpoint routes
768766
incoming traffic to the instances that the endpoint hosts.
769-
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require ModelAccessConfig,
770-
provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` to indicate whether model terms
771-
of use have been accepted. The `accept_eula` value must be explicitly defined as `True` in order to
772-
accept the end-user license agreement (EULA) that some.
773-
(Default: None)
767+
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require
768+
ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }`
769+
to indicate whether model terms of use have been accepted. The `accept_eula` value
770+
must be explicitly defined as `True` in order to accept the end-user license
771+
agreement (EULA) that some. (Default: None)
774772
775773
Raises:
776774
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.

src/sagemaker/jumpstart/types.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,7 @@ def set_bucket(self, bucket: str) -> None:
10821082
class AdditionalModelDataSource(JumpStartDataHolderType):
10831083
"""Data class of additional model data source mirrors CreateModel API."""
10841084

1085-
SERIALIZATION_EXCLUSION_SET: Set[str] = set()
1085+
SERIALIZATION_EXCLUSION_SET = {"provider"}
10861086

10871087
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]
10881088

@@ -1103,6 +1103,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11031103
self.channel_name: str = json_obj["channel_name"]
11041104
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
11051105
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
1106+
self.provider: Dict = json_obj.get("provider", {})
11061107

11071108
def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11081109
"""Returns json representation of AdditionalModelDataSource object."""
@@ -1121,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11211122
class JumpStartModelDataSource(AdditionalModelDataSource):
11221123
"""Data class JumpStart additional model data source."""
11231124

1124-
SERIALIZATION_EXCLUSION_SET = {"artifact_version"}
1125+
SERIALIZATION_EXCLUSION_SET = {
1126+
"artifact_version"
1127+
} | AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET
11251128

11261129
__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__
11271130

@@ -2241,7 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22412244
"config_name",
22422245
"routing_config",
22432246
"specs",
2244-
"model_access_configs"
2247+
"model_access_configs",
22452248
]
22462249

22472250
SERIALIZATION_EXCLUSION_SET = {
@@ -2255,7 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22552258
"sagemaker_session",
22562259
"training_instance_type",
22572260
"config_name",
2258-
"model_access_configs"
2261+
"model_access_configs",
22592262
}
22602263

22612264
def __init__(
@@ -2294,7 +2297,7 @@ def __init__(
22942297
endpoint_type: Optional[EndpointType] = None,
22952298
config_name: Optional[str] = None,
22962299
routing_config: Optional[Dict[str, Any]] = None,
2297-
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None
2300+
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
22982301
) -> None:
22992302
"""Instantiates JumpStartModelDeployKwargs object."""
23002303

src/sagemaker/jumpstart/utils.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
564564

565565

566566
def format_eula_message_from_specs(model_id: str, region: str, hosting_eula_key: str):
567+
"""Returns a formatted EULA message."""
567568
return (
568569
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
569570
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
@@ -1552,21 +1553,25 @@ def _add_model_access_configs_to_model_data_sources(
15521553
hosting_eula_key = model_data_source.get("HostingEulaKey")
15531554
if hosting_eula_key:
15541555
if not model_access_configs or not model_access_configs.get(model_id):
1555-
eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1556+
eula_message_template = (
1557+
"{model_source}{base_eula_message}{model_access_configs_message}"
1558+
)
15561559
model_access_config_entry = (
1557-
"\"{model_id}\":ModelAccessConfig(accept_eula=True)".format(model_id=model_id)
1560+
'"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
15581561
)
1559-
raise ValueError(eula_message_template.format(
1560-
model_source="Draft " if model_data_source.get("ChannelName") else "",
1561-
base_eula_message=format_eula_message_from_specs(
1562-
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
1563-
),
1564-
model_access_configs_message=(
1565-
" Please add a ModelAccessConfig entry:"
1566-
f" {model_access_config_entry} "
1567-
"to model_access_configs to acknowledge the EULA."
1562+
raise ValueError(
1563+
eula_message_template.format(
1564+
model_source="Draft " if model_data_source.get("ChannelName") else "",
1565+
base_eula_message=format_eula_message_from_specs(
1566+
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
1567+
),
1568+
model_access_configs_message=(
1569+
" Please add a ModelAccessConfig entry:"
1570+
f" {model_access_config_entry} "
1571+
"to model_access_configs to acknowledge the EULA."
1572+
),
15681573
)
1569-
))
1574+
)
15701575
acked_model_data_source = model_data_source.copy()
15711576
acked_model_data_source.pop("HostingEulaKey")
15721577
acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
@@ -1576,3 +1581,17 @@ def _add_model_access_configs_to_model_data_sources(
15761581
else:
15771582
acked_model_data_sources.append(model_data_source)
15781583
return acked_model_data_sources
1584+
1585+
1586+
def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
1587+
"""Returns the correct content bucket for a 1p draft model."""
1588+
neo_bucket = get_neo_content_bucket(region=region)
1589+
if not provider:
1590+
return neo_bucket
1591+
provider_name = provider.get("name", "")
1592+
if provider_name == "JumpStart":
1593+
classification = provider.get("classification", "ungated")
1594+
if classification == "gated":
1595+
return get_jumpstart_gated_content_bucket(region=region)
1596+
return get_jumpstart_content_bucket(region=region)
1597+
return neo_bucket

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
504504
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
505505
)
506506

507-
def set_deployment_config(
508-
self, config_name: str, instance_type: str
509-
) -> None:
507+
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
510508
"""Sets the deployment config to apply to the model.
511509
512510
Args:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ def _optimize_for_hf(
13441344
Optional[Dict[str, Any]]: Model optimization job input arguments.
13451345
"""
13461346
if speculative_decoding_config:
1347-
if speculative_decoding_config.get("ModelProvider", "") == "JumpStart":
1347+
if speculative_decoding_config.get("ModelProvider", "").lower() == "jumpstart":
13481348
_jumpstart_speculative_decoding(
13491349
model=self.pysdk_model,
13501350
speculative_decoding_config=speculative_decoding_config,

0 commit comments

Comments
 (0)