Skip to content

Commit 779f6d6

Browse files
committed
move the accept eula configurations into deploy flow
1 parent 8fb27a0 commit 779f6d6

File tree

5 files changed

+78
-60
lines changed

5 files changed

+78
-60
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
from typing import Any, Dict, List, Optional, Union
19+
from sagemaker_core.shapes import ModelAccessConfig
1920
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
2021
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2122
from sagemaker.base_deserializers import BaseDeserializer
@@ -58,7 +59,6 @@
5859
update_dict_if_key_not_present,
5960
resolve_model_sagemaker_config_field,
6061
verify_model_region_and_return_specs,
61-
get_jumpstart_content_bucket,
6262
)
6363

6464
from sagemaker.jumpstart.factory.utils import (
@@ -563,37 +563,6 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
563563
return kwargs
564564

565565

566-
def _apply_accept_eula_on_model_data_source(
567-
model_data_source: Dict[str, Any], model_id: str, region: str, accept_eula: bool
568-
):
569-
"""Sets AcceptEula to True for gated speculative decoding models"""
570-
571-
mutable_model_data_source = model_data_source.copy()
572-
573-
hosting_eula_key = mutable_model_data_source.get("hosting_eula_key")
574-
del mutable_model_data_source["hosting_eula_key"]
575-
576-
if not hosting_eula_key:
577-
return mutable_model_data_source
578-
579-
if not accept_eula:
580-
raise ValueError(
581-
(
582-
f"The set deployment config comes optimized with an additional model data source "
583-
f"'{model_id}' that requires accepting end-user license agreement (EULA). "
584-
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
585-
f"{get_domain_for_region(region)}"
586-
f"/{hosting_eula_key} for terms of use. Please set `accept_draft_model_eula=True` "
587-
f"once acknowledged."
588-
)
589-
)
590-
591-
mutable_model_data_source["s3_data_source"]["model_access_config"] = {
592-
"accept_eula": accept_eula
593-
}
594-
return mutable_model_data_source
595-
596-
597566
def _add_additional_model_data_sources_to_kwargs(
598567
kwargs: JumpStartModelInitKwargs,
599568
) -> JumpStartModelInitKwargs:
@@ -606,14 +575,7 @@ def _add_additional_model_data_sources_to_kwargs(
606575
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
607576
api_shape_additional_model_data_sources = (
608577
[
609-
camel_case_to_pascal_case(
610-
_apply_accept_eula_on_model_data_source(
611-
data_source.to_json(),
612-
kwargs.model_id,
613-
kwargs.region,
614-
kwargs.accept_draft_model_eula,
615-
)
616-
)
578+
camel_case_to_pascal_case(data_source.to_json())
617579
for data_source in speculative_decoding_data_sources
618580
]
619581
if specs.get_speculative_decoding_s3_data_sources()
@@ -693,6 +655,7 @@ def get_deploy_kwargs(
693655
training_config_name: Optional[str] = None,
694656
config_name: Optional[str] = None,
695657
routing_config: Optional[Dict[str, Any]] = None,
658+
model_access_configs: Optional[List[ModelAccessConfig]] = None,
696659
) -> JumpStartModelDeployKwargs:
697660
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
698661

@@ -729,6 +692,7 @@ def get_deploy_kwargs(
729692
resources=resources,
730693
config_name=config_name,
731694
routing_config=routing_config,
695+
model_access_configs=model_access_configs,
732696
)
733697
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
734698
deploy_kwargs.specs = verify_model_region_and_return_specs(
@@ -903,7 +867,6 @@ def get_init_kwargs(
903867
resources: Optional[ResourceRequirements] = None,
904868
config_name: Optional[str] = None,
905869
additional_model_data_sources: Optional[Dict[str, Any]] = None,
906-
accept_draft_model_eula: Optional[bool] = None,
907870
) -> JumpStartModelInitKwargs:
908871
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
909872

@@ -938,7 +901,6 @@ def get_init_kwargs(
938901
resources=resources,
939902
config_name=config_name,
940903
additional_model_data_sources=additional_model_data_sources,
941-
accept_draft_model_eula=accept_draft_model_eula,
942904
)
943905
model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
944906
kwargs=model_init_kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pandas as pd
1919
from botocore.exceptions import ClientError
2020

21+
from sagemaker_core.shapes import ModelAccessConfig
2122
from sagemaker import payloads
2223
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2324
from sagemaker.base_deserializers import BaseDeserializer
@@ -51,6 +52,7 @@
5152
add_instance_rate_stats_to_benchmark_metrics,
5253
deployment_config_response_data,
5354
_deployment_config_lru_cache,
55+
_add_model_access_configs_to_model_data_sources,
5456
)
5557
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
5658
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -111,7 +113,6 @@ def __init__(
111113
resources: Optional[ResourceRequirements] = None,
112114
config_name: Optional[str] = None,
113115
additional_model_data_sources: Optional[Dict[str, Any]] = None,
114-
accept_draft_model_eula: Optional[bool] = None,
115116
):
116117
"""Initializes a ``JumpStartModel``.
117118
@@ -302,10 +303,6 @@ def __init__(
302303
optionally applied to the model.
303304
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
304305
of SageMaker model data (default: None).
305-
accept_draft_model_eula (bool): For draft models that require a Model Access Config, specify True or
306-
False to indicate whether model terms of use have been accepted.
307-
The `accept_draft_model_eula` value must be explicitly defined as `True` in order to
308-
accept the end-user license agreement (EULA) that some
309306
Raises:
310307
ValueError: If the model ID is not recognized by JumpStart.
311308
"""
@@ -365,7 +362,6 @@ def _validate_model_id_and_type():
365362
resources=resources,
366363
config_name=config_name,
367364
additional_model_data_sources=additional_model_data_sources,
368-
accept_draft_model_eula=accept_draft_model_eula
369365
)
370366

371367
self.orig_predictor_cls = predictor_cls
@@ -463,7 +459,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
463459
)
464460

465461
def set_deployment_config(
466-
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
462+
self, config_name: str, instance_type: str
467463
) -> None:
468464
"""Sets the deployment config to apply to the model.
469465
@@ -483,8 +479,7 @@ def set_deployment_config(
483479
instance_type=instance_type,
484480
config_name=config_name,
485481
sagemaker_session=self.sagemaker_session,
486-
role=self.role,
487-
accept_draft_model_eula=accept_draft_model_eula,
482+
role=self.role
488483
)
489484

490485
@property
@@ -674,6 +669,7 @@ def deploy(
674669
managed_instance_scaling: Optional[str] = None,
675670
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
676671
routing_config: Optional[Dict[str, Any]] = None,
672+
model_access_configs: Optional[List[ModelAccessConfig]] = None,
677673
) -> PredictorBase:
678674
"""Creates endpoint by calling base ``Model`` class `deploy` method.
679675
@@ -770,6 +766,11 @@ def deploy(
770766
(Default: EndpointType.MODEL_BASED).
771767
routing_config (Optional[Dict]): Settings the control how the endpoint routes
772768
incoming traffic to the instances that the endpoint hosts.
769+
model_access_configs (Optional[List[ModelAccessConfig]]): For models that require Model Access Configs,
770+
provide one or multiple ModelAccessConfig objects to indicate whether model terms of use have been accepted.
771+
The `AcceptEula` value must be explicitly defined as `True` in order to
772+
accept the end-user license agreement (EULA) that some.
773+
(Default: None)
773774
774775
Raises:
775776
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
@@ -810,6 +811,7 @@ def deploy(
810811
model_type=self.model_type,
811812
config_name=self.config_name,
812813
routing_config=routing_config,
814+
model_access_configs=model_access_configs,
813815
)
814816
if (
815817
self.model_type == JumpStartModelType.PROPRIETARY
@@ -819,6 +821,13 @@ def deploy(
819821
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
820822
)
821823

824+
self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
825+
self.additional_model_data_sources,
826+
deploy_kwargs.model_access_configs,
827+
deploy_kwargs.model_id,
828+
deploy_kwargs.region,
829+
)
830+
822831
try:
823832
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
824833
except ClientError as e:
@@ -1058,7 +1067,6 @@ def _get_deployment_configs(
10581067
region=self.region,
10591068
model_version=self.model_version,
10601069
hub_arn=self.hub_arn,
1061-
accept_draft_model_eula=True,
10621070
)
10631071
deploy_kwargs = get_deploy_kwargs(
10641072
model_id=self.model_id,

src/sagemaker/jumpstart/types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from copy import deepcopy
1818
from enum import Enum
1919
from typing import Any, Dict, List, Optional, Set, Union
20+
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
2021
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
2122
from sagemaker.utils import (
2223
S3_PREFIX,
@@ -2117,7 +2118,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
21172118
"hub_content_type",
21182119
"model_reference_arn",
21192120
"specs",
2120-
"accept_draft_model_eula",
21212121
]
21222122

21232123
SERIALIZATION_EXCLUSION_SET = {
@@ -2133,7 +2133,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
21332133
"training_instance_type",
21342134
"config_name",
21352135
"hub_content_type",
2136-
"accept_draft_model_eula",
21372136
}
21382137

21392138
def __init__(
@@ -2168,7 +2167,6 @@ def __init__(
21682167
resources: Optional[ResourceRequirements] = None,
21692168
config_name: Optional[str] = None,
21702169
additional_model_data_sources: Optional[Dict[str, Any]] = None,
2171-
accept_draft_model_eula: Optional[bool] = False
21722170
) -> None:
21732171
"""Instantiates JumpStartModelInitKwargs object."""
21742172

@@ -2202,7 +2200,6 @@ def __init__(
22022200
self.resources = resources
22032201
self.config_name = config_name
22042202
self.additional_model_data_sources = additional_model_data_sources
2205-
self.accept_draft_model_eula = accept_draft_model_eula
22062203

22072204

22082205
class JumpStartModelDeployKwargs(JumpStartKwargs):
@@ -2244,6 +2241,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22442241
"config_name",
22452242
"routing_config",
22462243
"specs",
2244+
"model_access_configs"
22472245
]
22482246

22492247
SERIALIZATION_EXCLUSION_SET = {
@@ -2257,6 +2255,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
22572255
"sagemaker_session",
22582256
"training_instance_type",
22592257
"config_name",
2258+
"model_access_configs"
22602259
}
22612260

22622261
def __init__(
@@ -2295,6 +2294,7 @@ def __init__(
22952294
endpoint_type: Optional[EndpointType] = None,
22962295
config_name: Optional[str] = None,
22972296
routing_config: Optional[Dict[str, Any]] = None,
2297+
model_access_configs: Optional[List[CoreModelAccessConfig]] = None
22982298
) -> None:
22992299
"""Instantiates JumpStartModelDeployKwargs object."""
23002300

@@ -2332,6 +2332,7 @@ def __init__(
23322332
self.endpoint_type = endpoint_type
23332333
self.config_name = config_name
23342334
self.routing_config = routing_config
2335+
self.model_access_configs = model_access_configs
23352336

23362337

23372338
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

src/sagemaker/jumpstart/utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
1516
from copy import copy
1617
import logging
1718
import os
@@ -22,6 +23,7 @@
2223
from botocore.exceptions import ClientError
2324
from packaging.version import Version
2425
import botocore
26+
from sagemaker_core.shapes import ModelAccessConfig
2527
import sagemaker
2628
from sagemaker.config.config_schema import (
2729
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
@@ -55,6 +57,7 @@
5557
TagsDict,
5658
get_instance_rate_per_hour,
5759
get_domain_for_region,
60+
camel_case_to_pascal_case,
5861
)
5962
from sagemaker.workflow import is_pipeline_variable
6063
from sagemaker.user_agent import get_user_agent_extra_suffix
@@ -555,11 +558,17 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
555558
"""Returns EULA message to display if one is available, else empty string."""
556559
if model_specs.hosting_eula_key is None:
557560
return ""
561+
return format_eula_message_from_specs(
562+
model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
563+
)
564+
565+
566+
def format_eula_message_from_specs(model_id: str, region: str, hosting_eula_key: str):
558567
return (
559-
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
568+
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
560569
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
561570
f"{get_domain_for_region(region)}"
562-
f"/{model_specs.hosting_eula_key} for terms of use."
571+
f"/{hosting_eula_key} for terms of use."
563572
)
564573

565574

@@ -1525,3 +1534,41 @@ def wrapped_f(*args, **kwargs):
15251534
if _func is None:
15261535
return wrapper_cache
15271536
return wrapper_cache(_func)
1537+
1538+
1539+
def _add_model_access_configs_to_model_data_sources(
1540+
model_data_sources: List[Dict[str, any]],
1541+
model_access_configs: List[ModelAccessConfig],
1542+
model_id: str,
1543+
region: str,
1544+
):
1545+
"""Sets AcceptEula to True for gated speculative decoding models"""
1546+
1547+
if not model_data_sources:
1548+
return model_data_sources
1549+
1550+
acked_model_data_sources = []
1551+
acked_model_access_configs = 0
1552+
for model_data_source in model_data_sources:
1553+
hosting_eula_key = model_data_source.pop("HostingEulaKey", None)
1554+
if hosting_eula_key:
1555+
if not model_access_configs or acked_model_access_configs == len(model_access_configs):
1556+
eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1557+
raise ValueError(eula_message_template.format(
1558+
model_source="Draft " if model_data_source.get("ChannelName") else "",
1559+
base_eula_message=format_eula_message_from_specs(
1560+
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
1561+
),
1562+
model_access_configs_message=(
1563+
" Please add a ModelAccessConfig with AcceptEula=True"
1564+
" to model_access_configs to acknowledge the EULA."
1565+
)
1566+
))
1567+
acked_model_data_source = model_data_source.copy()
1568+
acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
1569+
camel_case_to_pascal_case(model_access_configs[acked_model_access_configs].model_dump())
1570+
)
1571+
acked_model_data_sources.append(acked_model_data_source)
1572+
else:
1573+
acked_model_data_sources.append(model_data_source)
1574+
return acked_model_data_sources

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
505505
)
506506

507507
def set_deployment_config(
508-
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
508+
self, config_name: str, instance_type: str
509509
) -> None:
510510
"""Sets the deployment config to apply to the model.
511511
@@ -522,7 +522,7 @@ def set_deployment_config(
522522
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
523523
raise Exception("Cannot set deployment config to an uninitialized model.")
524524

525-
self.pysdk_model.set_deployment_config(config_name, instance_type, accept_draft_model_eula)
525+
self.pysdk_model.set_deployment_config(config_name, instance_type)
526526
self.deployment_config_name = config_name
527527

528528
self.instance_type = instance_type

0 commit comments

Comments
 (0)