Skip to content

Commit b7b15b8

Browse files
committed
move the accept eula configurations into deploy flow
1 parent aef3a90 commit b7b15b8

File tree

4 files changed

+16
-12
lines changed

4 files changed

+16
-12
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def get_deploy_kwargs(
655655
training_config_name: Optional[str] = None,
656656
config_name: Optional[str] = None,
657657
routing_config: Optional[Dict[str, Any]] = None,
658-
model_access_configs: Optional[List[ModelAccessConfig]] = None,
658+
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
659659
) -> JumpStartModelDeployKwargs:
660660
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
661661

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -766,9 +766,9 @@ def deploy(
766766
(Default: EndpointType.MODEL_BASED).
767767
routing_config (Optional[Dict]): Settings the control how the endpoint routes
768768
incoming traffic to the instances that the endpoint hosts.
769-
model_access_configs (Optional[List[ModelAccessConfig]]): For models that require Model Access Configs,
770-
provide one or multiple ModelAccessConfig objects to indicate whether model terms of use have been accepted.
771-
The `AcceptEula` value must be explicitly defined as `True` in order to
769+
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require ModelAccessConfig,
770+
provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` to indicate whether model terms
771+
of use have been accepted. The `accept_eula` value must be explicitly defined as `True` in order to
772772
accept the end-user license agreement (EULA) that some.
773773
(Default: None)
774774

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,7 @@ def __init__(
22942294
endpoint_type: Optional[EndpointType] = None,
22952295
config_name: Optional[str] = None,
22962296
routing_config: Optional[Dict[str, Any]] = None,
2297-
model_access_configs: Optional[List[CoreModelAccessConfig]] = None
2297+
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None
22982298
) -> None:
22992299
"""Instantiates JumpStartModelDeployKwargs object."""
23002300

src/sagemaker/jumpstart/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,7 @@ def wrapped_f(*args, **kwargs):
15381538

15391539
def _add_model_access_configs_to_model_data_sources(
15401540
model_data_sources: List[Dict[str, any]],
1541-
model_access_configs: List[ModelAccessConfig],
1541+
model_access_configs: Dict[str, ModelAccessConfig],
15421542
model_id: str,
15431543
region: str,
15441544
):
@@ -1548,25 +1548,29 @@ def _add_model_access_configs_to_model_data_sources(
15481548
return model_data_sources
15491549

15501550
acked_model_data_sources = []
1551-
acked_model_access_configs = 0
15521551
for model_data_source in model_data_sources:
1553-
hosting_eula_key = model_data_source.pop("HostingEulaKey", None)
1552+
hosting_eula_key = model_data_source.get("HostingEulaKey")
15541553
if hosting_eula_key:
1555-
if not model_access_configs or acked_model_access_configs == len(model_access_configs):
1554+
if not model_access_configs or not model_access_configs.get(model_id):
15561555
eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1556+
model_access_config_entry = (
1557+
"\"{model_id}\":ModelAccessConfig(accept_eula=True)".format(model_id=model_id)
1558+
)
15571559
raise ValueError(eula_message_template.format(
15581560
model_source="Draft " if model_data_source.get("ChannelName") else "",
15591561
base_eula_message=format_eula_message_from_specs(
15601562
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
15611563
),
15621564
model_access_configs_message=(
1563-
" Please add a ModelAccessConfig with AcceptEula=True"
1564-
" to model_access_configs to acknowledge the EULA."
1565+
" Please add a ModelAccessConfig entry:"
1566+
f" {model_access_config_entry} "
1567+
"to model_access_configs to acknowledge the EULA."
15651568
)
15661569
))
15671570
acked_model_data_source = model_data_source.copy()
1571+
acked_model_data_source.pop("HostingEulaKey")
15681572
acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
1569-
camel_case_to_pascal_case(model_access_configs[acked_model_access_configs].model_dump())
1573+
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
15701574
)
15711575
acked_model_data_sources.append(acked_model_data_source)
15721576
else:

0 commit comments

Comments
 (0)