Skip to content

Commit d10c475

Browse files
author
Joseph Zhang
committed
Add "Auto" speculative decoding ModelProvider option; add validations to differentiate SageMaker/JumpStart draft models.
1 parent c94a78b commit d10c475

File tree

3 files changed

+81
-14
lines changed

3 files changed

+81
-14
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
564564

565565

566566
def _apply_accept_eula_on_model_data_source(
567-
model_data_source: Dict[str, Any],
568-
model_id: str,
569-
region: str,
570-
accept_eula: bool
567+
model_data_source: Dict[str, Any], model_id: str, region: str, accept_eula: bool
571568
):
572569
"""Sets AcceptEula to True for gated speculative decoding models"""
573570

@@ -586,7 +583,8 @@ def _apply_accept_eula_on_model_data_source(
586583
f"'{model_id}' that requires accepting end-user license agreement (EULA). "
587584
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
588585
f"{get_domain_for_region(region)}"
589-
f"/{hosting_eula_key} for terms of use. Please set `accept_eula=True` once acknowledged."
586+
f"/{hosting_eula_key} for terms of use. Please set `accept_draft_model_eula=True` "
587+
f"once acknowledged."
590588
)
591589
)
592590

@@ -608,7 +606,10 @@ def _add_additional_model_data_sources_to_kwargs(
608606
[
609607
camel_case_to_pascal_case(
610608
_apply_accept_eula_on_model_data_source(
611-
data_source.to_json(), kwargs.model_id, kwargs.region, kwargs.accept_draft_model_eula,
609+
data_source.to_json(),
610+
kwargs.model_id,
611+
kwargs.region,
612+
kwargs.accept_draft_model_eula,
612613
)
613614
)
614615
for data_source in speculative_decoding_data_sources

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
SPECULATIVE_DRAFT_MODEL,
5050
_is_inferentia_or_trainium,
5151
_jumpstart_speculative_decoding,
52+
_deployment_config_contains_draft_model,
53+
_is_draft_model_jumpstart_provided,
5254
)
5355
from sagemaker.serve.utils.predictors import (
5456
DjlLocalModePredictor,
@@ -850,7 +852,7 @@ def _set_additional_model_source(
850852

851853
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
852854

853-
if model_provider == "sagemaker":
855+
if model_provider in ["sagemaker", "auto"]:
854856
additional_model_data_sources = (
855857
self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
856858
"AdditionalDataSources"
@@ -863,6 +865,15 @@ def _set_additional_model_source(
863865
speculative_decoding_config
864866
)
865867
if deployment_config:
868+
if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided(
869+
deployment_config
870+
):
871+
raise ValueError(
872+
"No `Sagemaker` provided draft model was found for "
873+
f"{self.model}. Try setting `ModelProvider` "
874+
"to `Auto` instead."
875+
)
876+
866877
try:
867878
self.pysdk_model.set_deployment_config(
868879
config_name=deployment_config.get("DeploymentConfigName"),
@@ -878,12 +889,21 @@ def _set_additional_model_source(
878889
raise ValueError(
879890
"Cannot find deployment config compatible for optimization job."
880891
)
892+
else:
893+
if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided(
894+
self.pysdk_model.deployment_config
895+
):
896+
raise ValueError(
897+
"No `Sagemaker` provided draft model was found for "
898+
f"{self.model}. Try setting `ModelProvider` "
899+
"to `Auto` instead."
900+
)
881901

882902
self.pysdk_model.env.update(
883903
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
884904
)
885905
self.pysdk_model.add_tags(
886-
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
906+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider},
887907
)
888908
elif model_provider == "jumpstart":
889909
_jumpstart_speculative_decoding(
@@ -911,15 +931,17 @@ def _find_compatible_deployment_config(
911931
for deployment_config in self.pysdk_model.list_deployment_configs():
912932
image_uri = deployment_config.get("deployment_config", {}).get("ImageUri")
913933

914-
if _is_image_compatible_with_optimization_job(image_uri):
934+
if _is_image_compatible_with_optimization_job(
935+
image_uri
936+
) and _deployment_config_contains_draft_model(deployment_config):
915937
if (
916-
model_provider == "sagemaker"
938+
model_provider in ["sagemaker", "auto"]
917939
and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources")
918940
) or model_provider == "custom":
919941
return deployment_config
920942

921943
# There's no matching config from jumpstart to add sagemaker draft model location
922-
if model_provider == "sagemaker":
944+
if model_provider in ["sagemaker", "auto"]:
923945
return None
924946

925947
# fall back to the default jumpstart model deployment config for optimization job

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,47 @@ def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool
6060
return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri)
6161

6262

63+
def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -> bool:
64+
"""Checks whether a deployment config contains a speculative decoding draft model.
65+
66+
Args:
67+
deployment_config (Dict): The deployment config to check.
68+
69+
Returns:
70+
bool: Whether the deployment config contains a draft model or not.
71+
"""
72+
if deployment_config is None:
73+
return False
74+
deployment_args = deployment_config.get("DeploymentArgs", {})
75+
additional_data_sources = deployment_args.get("AdditionalDataSources")
76+
if not additional_data_sources:
77+
return False
78+
return additional_data_sources.get("speculative_decoding", False)
79+
80+
81+
def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool:
82+
"""Checks whether a deployment config's draft model is provided by JumpStart.
83+
84+
Args:
85+
deployment_config (Dict): The deployment config to check.
86+
87+
Returns:
88+
bool: Whether the draft model is provided by JumpStart or not.
89+
"""
90+
if deployment_config is None:
91+
return False
92+
93+
additional_model_data_sources = deployment_config.get("DeploymentArgs", {}).get(
94+
"AdditionalDataSources"
95+
)
96+
for source in additional_model_data_sources.get("speculative_decoding", []):
97+
if source["channel_name"] == "draft_model":
98+
if source.get("provider", {}).get("name") == "JumpStart":
99+
return True
100+
continue
101+
return False
102+
103+
63104
def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model:
64105
"""Generates a new optimization model.
65106
@@ -166,15 +207,18 @@ def _extract_speculative_draft_model_provider(
166207
if speculative_decoding_config is None:
167208
return None
168209

169-
if speculative_decoding_config.get("ModelProvider") == "JumpStart":
210+
if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart":
170211
return "jumpstart"
171212

172213
if speculative_decoding_config.get(
173214
"ModelProvider"
174-
) == "Custom" or speculative_decoding_config.get("ModelSource"):
215+
).lower() == "custom" or speculative_decoding_config.get("ModelSource"):
175216
return "custom"
176217

177-
return "sagemaker"
218+
if speculative_decoding_config.get("ModelProvider").lower() == "sagemaker":
219+
return "sagemaker"
220+
221+
return "auto"
178222

179223

180224
def _extract_additional_model_data_source_s3_uri(

0 commit comments

Comments
 (0)