Skip to content

Commit fcb5092

Browse files
author
EC2 Default User
committed
add accept_draft_model_eula to JumpStartModel when deployment config with gated draft model is selected
1 parent cf70f59 commit fcb5092

File tree

4 files changed

+86
-16
lines changed

4 files changed

+86
-16
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
update_dict_if_key_not_present,
5959
resolve_model_sagemaker_config_field,
6060
verify_model_region_and_return_specs,
61+
get_jumpstart_content_bucket,
6162
)
6263

6364
from sagemaker.jumpstart.factory.utils import (
@@ -70,7 +71,13 @@
7071

7172
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
7273
from sagemaker.session import Session
73-
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
74+
from sagemaker.utils import (
75+
camel_case_to_pascal_case,
76+
name_from_base,
77+
format_tags,
78+
Tags,
79+
get_domain_for_region,
80+
)
7481
from sagemaker.workflow.entities import PipelineVariable
7582
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
7683
from sagemaker import resource_requirements
@@ -556,6 +563,37 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
556563
return kwargs
557564

558565

566+
def _apply_accept_eula_on_model_data_source(
567+
model_data_source: Dict[str, Any],
568+
model_id: str,
569+
region: str,
570+
accept_eula: bool
571+
):
572+
"""Sets AcceptEula to True for gated speculative decoding models"""
573+
574+
mutable_model_data_source = model_data_source.copy()
575+
576+
hosting_eula_key = mutable_model_data_source.get("hosting_eula_key")
577+
del mutable_model_data_source["hosting_eula_key"]
578+
579+
if not hosting_eula_key:
580+
return mutable_model_data_source
581+
582+
if not accept_eula:
583+
raise ValueError(
584+
(
585+
f"The set deployment config comes optimized with an additional model data source "
586+
f"'{model_id}' that requires accepting end-user license agreement (EULA). "
587+
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
588+
f"{get_domain_for_region(region)}"
589+
f"/{hosting_eula_key} for terms of use. Please set `accept_eula=True` once acknowledged."
590+
)
591+
)
592+
593+
mutable_model_data_source["model_access_config"] = {"accept_eula": accept_eula}
594+
return mutable_model_data_source
595+
596+
559597
def _add_additional_model_data_sources_to_kwargs(
560598
kwargs: JumpStartModelInitKwargs,
561599
) -> JumpStartModelInitKwargs:
@@ -568,7 +606,11 @@ def _add_additional_model_data_sources_to_kwargs(
568606
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
569607
api_shape_additional_model_data_sources = (
570608
[
571-
camel_case_to_pascal_case(data_source.to_json())
609+
camel_case_to_pascal_case(
610+
_apply_accept_eula_on_model_data_source(
611+
data_source.to_json(), kwargs.model_id, kwargs.region, kwargs.accept_draft_model_eula,
612+
)
613+
)
572614
for data_source in speculative_decoding_data_sources
573615
]
574616
if specs.get_speculative_decoding_s3_data_sources()
@@ -858,6 +900,7 @@ def get_init_kwargs(
858900
resources: Optional[ResourceRequirements] = None,
859901
config_name: Optional[str] = None,
860902
additional_model_data_sources: Optional[Dict[str, Any]] = None,
903+
accept_draft_model_eula: Optional[bool] = None,
861904
) -> JumpStartModelInitKwargs:
862905
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
863906

@@ -892,6 +935,7 @@ def get_init_kwargs(
892935
resources=resources,
893936
config_name=config_name,
894937
additional_model_data_sources=additional_model_data_sources,
938+
accept_draft_model_eula=accept_draft_model_eula,
895939
)
896940
model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
897941
kwargs=model_init_kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
resources: Optional[ResourceRequirements] = None,
112112
config_name: Optional[str] = None,
113113
additional_model_data_sources: Optional[Dict[str, Any]] = None,
114+
accept_draft_model_eula: Optional[bool] = None,
114115
):
115116
"""Initializes a ``JumpStartModel``.
116117
@@ -301,6 +302,10 @@ def __init__(
301302
optionally applied to the model.
302303
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
303304
of SageMaker model data (default: None).
305+
accept_draft_model_eula (bool): For draft models that require a Model Access Config, specify True or
306+
False to indicate whether model terms of use have been accepted.
307+
The `accept_draft_model_eula` value must be explicitly defined as `True` in order to
308+
accept the end-user license agreement (EULA) that some
304309
Raises:
305310
ValueError: If the model ID is not recognized by JumpStart.
306311
"""
@@ -360,6 +365,7 @@ def _validate_model_id_and_type():
360365
resources=resources,
361366
config_name=config_name,
362367
additional_model_data_sources=additional_model_data_sources,
368+
accept_draft_model_eula=accept_draft_model_eula
363369
)
364370

365371
self.orig_predictor_cls = predictor_cls
@@ -456,7 +462,9 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
456462
sagemaker_session=self.sagemaker_session,
457463
)
458464

459-
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
465+
def set_deployment_config(
466+
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
467+
) -> None:
460468
"""Sets the deployment config to apply to the model.
461469
462470
Args:
@@ -466,6 +474,8 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
466474
instance_type (str):
467475
The instance_type that the model will use after setting
468476
the config.
477+
accept_draft_model_eula (Optional[bool]):
478+
If the config selected comes with a gated additional model data source.
469479
"""
470480
self.__init__(
471481
model_id=self.model_id,
@@ -474,6 +484,7 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
474484
config_name=config_name,
475485
sagemaker_session=self.sagemaker_session,
476486
role=self.role,
487+
accept_draft_model_eula=accept_draft_model_eula,
477488
)
478489

479490
@property
@@ -540,12 +551,16 @@ def attach(
540551
inferred_model_id = inferred_model_version = inferred_inference_component_name = None
541552

542553
if inference_component_name is None or model_id is None or model_version is None:
543-
inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
544-
get_model_info_from_endpoint(
545-
endpoint_name=endpoint_name,
546-
inference_component_name=inference_component_name,
547-
sagemaker_session=sagemaker_session,
548-
)
554+
(
555+
inferred_model_id,
556+
inferred_model_version,
557+
inferred_inference_component_name,
558+
_,
559+
_,
560+
) = get_model_info_from_endpoint(
561+
endpoint_name=endpoint_name,
562+
inference_component_name=inference_component_name,
563+
sagemaker_session=sagemaker_session,
549564
)
550565

551566
model_id = model_id or inferred_model_id
@@ -1016,10 +1031,11 @@ def _get_deployment_configs(
10161031
)
10171032

10181033
if metadata_config.benchmark_metrics:
1019-
err, metadata_config.benchmark_metrics = (
1020-
add_instance_rate_stats_to_benchmark_metrics(
1021-
self.region, metadata_config.benchmark_metrics
1022-
)
1034+
(
1035+
err,
1036+
metadata_config.benchmark_metrics,
1037+
) = add_instance_rate_stats_to_benchmark_metrics(
1038+
self.region, metadata_config.benchmark_metrics
10231039
)
10241040

10251041
config_components = metadata_config.config_components.get(config_name)
@@ -1042,6 +1058,7 @@ def _get_deployment_configs(
10421058
region=self.region,
10431059
model_version=self.model_version,
10441060
hub_arn=self.hub_arn,
1061+
accept_draft_model_eula=True,
10451062
)
10461063
deploy_kwargs = get_deploy_kwargs(
10471064
model_id=self.model_id,

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ class AdditionalModelDataSource(JumpStartDataHolderType):
10831083

10841084
SERIALIZATION_EXCLUSION_SET: Set[str] = set()
10851085

1086-
__slots__ = ["channel_name", "s3_data_source"]
1086+
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]
10871087

10881088
def __init__(self, spec: Dict[str, Any]):
10891089
"""Initializes a AdditionalModelDataSource object.
@@ -1101,6 +1101,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11011101
"""
11021102
self.channel_name: str = json_obj["channel_name"]
11031103
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
1104+
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
11041105

11051106
def to_json(self, exclude_keys=True) -> Dict[str, Any]:
11061107
"""Returns json representation of AdditionalModelDataSource object."""
@@ -2116,6 +2117,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
21162117
"hub_content_type",
21172118
"model_reference_arn",
21182119
"specs",
2120+
"accept_draft_model_eula",
21192121
]
21202122

21212123
SERIALIZATION_EXCLUSION_SET = {
@@ -2131,6 +2133,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
21312133
"training_instance_type",
21322134
"config_name",
21332135
"hub_content_type",
2136+
"accept_draft_model_eula",
21342137
}
21352138

21362139
def __init__(
@@ -2165,6 +2168,7 @@ def __init__(
21652168
resources: Optional[ResourceRequirements] = None,
21662169
config_name: Optional[str] = None,
21672170
additional_model_data_sources: Optional[Dict[str, Any]] = None,
2171+
accept_draft_model_eula: Optional[bool] = False
21682172
) -> None:
21692173
"""Instantiates JumpStartModelInitKwargs object."""
21702174

@@ -2198,6 +2202,7 @@ def __init__(
21982202
self.resources = resources
21992203
self.config_name = config_name
22002204
self.additional_model_data_sources = additional_model_data_sources
2205+
self.accept_draft_model_eula = accept_draft_model_eula
22012206

22022207

22032208
class JumpStartModelDeployKwargs(JumpStartKwargs):

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,9 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
502502
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
503503
)
504504

505-
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
505+
def set_deployment_config(
506+
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
507+
) -> None:
506508
"""Sets the deployment config to apply to the model.
507509
508510
Args:
@@ -512,11 +514,13 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
512514
instance_type (str):
513515
The instance_type that the model will use after setting
514516
the config.
517+
accept_draft_model_eula (Optional[bool]):
518+
If the config selected comes with a gated additional model data source.
515519
"""
516520
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
517521
raise Exception("Cannot set deployment config to an uninitialized model.")
518522

519-
self.pysdk_model.set_deployment_config(config_name, instance_type)
523+
self.pysdk_model.set_deployment_config(config_name, instance_type, accept_draft_model_eula)
520524
self.deployment_config_name = config_name
521525

522526
self.instance_type = instance_type

0 commit comments

Comments
 (0)