Skip to content

Commit c94a78b

Browse files
author
Joseph Zhang
committed
Tweak draft model EULA validation and messaging. Remove the redundant deployment_config flow validation from optimize_utils in favor of the validation performed directly in jumpstart/factory/model.
1 parent 5512c26 commit c94a78b

File tree

2 files changed

+15
-90
lines changed

2 files changed

+15
-90
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
_custom_speculative_decoding,
4949
SPECULATIVE_DRAFT_MODEL,
5050
_is_inferentia_or_trainium,
51-
_validate_and_set_eula_for_draft_model_sources,
5251
_jumpstart_speculative_decoding,
5352
)
5453
from sagemaker.serve.utils.predictors import (
@@ -837,9 +836,7 @@ def _is_gated_model(self, model=None) -> bool:
837836
return "private" in s3_uri
838837

839838
def _set_additional_model_source(
840-
self,
841-
speculative_decoding_config: Optional[Dict[str, Any]] = None,
842-
accept_eula: Optional[bool] = None,
839+
self, speculative_decoding_config: Optional[Dict[str, Any]] = None
843840
) -> None:
844841
"""Set Additional Model Source to ``this`` model.
845842
@@ -849,6 +846,7 @@ def _set_additional_model_source(
849846
"""
850847
if speculative_decoding_config:
851848
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
849+
accept_draft_model_eula = speculative_decoding_config.get("AcceptEula", False)
852850

853851
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
854852

@@ -865,20 +863,22 @@ def _set_additional_model_source(
865863
speculative_decoding_config
866864
)
867865
if deployment_config:
868-
self.pysdk_model.set_deployment_config(
869-
config_name=deployment_config.get("DeploymentConfigName"),
870-
instance_type=deployment_config.get("InstanceType"),
871-
)
866+
try:
867+
self.pysdk_model.set_deployment_config(
868+
config_name=deployment_config.get("DeploymentConfigName"),
869+
instance_type=deployment_config.get("InstanceType"),
870+
accept_draft_model_eula=accept_draft_model_eula,
871+
)
872+
except ValueError as e:
873+
raise ValueError(
874+
f"{e} If using speculative_decoding_config, "
875+
"accept the EULA by setting `AcceptEula`=True."
876+
)
872877
else:
873878
raise ValueError(
874879
"Cannot find deployment config compatible for optimization job."
875880
)
876881

877-
_validate_and_set_eula_for_draft_model_sources(
878-
pysdk_model=self.pysdk_model,
879-
accept_eula=speculative_decoding_config.get("AcceptEula"),
880-
)
881-
882882
self.pysdk_model.env.update(
883883
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
884884
)
@@ -893,7 +893,7 @@ def _set_additional_model_source(
893893
)
894894
else:
895895
self.pysdk_model = _custom_speculative_decoding(
896-
self.pysdk_model, speculative_decoding_config, accept_eula
896+
self.pysdk_model, speculative_decoding_config, accept_draft_model_eula
897897
)
898898

899899
def _find_compatible_deployment_config(

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def _jumpstart_speculative_decoding(
428428
model_specs=model_specs, region=sagemaker_session.boto_region_name
429429
)
430430
raise ValueError(
431-
f"{eula_message} Please set `AcceptEula` to True in "
431+
f"{eula_message} Set `AcceptEula`=True in "
432432
f"speculative_decoding_config once acknowledged."
433433
)
434434
js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
@@ -446,78 +446,3 @@ def _jumpstart_speculative_decoding(
446446
model.add_tags(
447447
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
448448
)
449-
450-
451-
def _validate_and_set_eula_for_draft_model_sources(
452-
pysdk_model: Model,
453-
accept_eula: bool = False,
454-
):
455-
"""Validates whether the EULA has been accepted for gated additional draft model sources.
456-
457-
If accepted, updates the model data source's model access config.
458-
459-
Args:
460-
pysdk_model (Model): The model whose additional model data sources to check.
461-
accept_eula (bool): EULA acceptance for the draft model.
462-
"""
463-
if not pysdk_model:
464-
return
465-
466-
deployment_config_draft_model_sources = (
467-
pysdk_model.deployment_config.get("DeploymentArgs", {})
468-
.get("AdditionalDataSources", {})
469-
.get("speculative_decoding", [])
470-
if pysdk_model.deployment_config
471-
else None
472-
)
473-
pysdk_model_additional_model_sources = pysdk_model.additional_model_data_sources
474-
475-
if not deployment_config_draft_model_sources or not pysdk_model_additional_model_sources:
476-
return
477-
478-
# Gated/ungated classification is only available through deployment_config.
479-
# Thus we must check each draft model in the deployment_config and see if it is set
480-
# as an additional model data source on the PySDK model itself.
481-
model_access_config_updated = False
482-
for source in deployment_config_draft_model_sources:
483-
if source.get("channel_name") != "draft_model":
484-
continue
485-
486-
if not _is_draft_model_gated(source):
487-
continue
488-
489-
deployment_config_draft_model_source_s3_uri = (
490-
_extract_deployment_config_additional_model_data_source_s3_uri(source)
491-
)
492-
493-
# If EULA is accepted, proceed with modifying the draft model data source
494-
for additional_source in pysdk_model_additional_model_sources:
495-
if additional_source.get("ChannelName") != "draft_model":
496-
continue
497-
498-
# Verify the pysdk model source and deployment config model source match
499-
pysdk_model_source_s3_uri = _extract_additional_model_data_source_s3_uri(
500-
additional_source
501-
)
502-
if deployment_config_draft_model_source_s3_uri not in pysdk_model_source_s3_uri:
503-
continue
504-
505-
if not accept_eula:
506-
raise ValueError(
507-
"Gated draft model requires accepting end-user license agreement (EULA)."
508-
)
509-
510-
# Set ModelAccessConfig.AcceptEula to True
511-
updated_source = additional_source.copy()
512-
updated_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
513-
514-
index = pysdk_model.additional_model_data_sources.index(additional_source)
515-
pysdk_model.additional_model_data_sources[index] = updated_source
516-
517-
model_access_config_updated = True
518-
break
519-
520-
if model_access_config_updated:
521-
break
522-
523-
return

0 commit comments

Comments
 (0)