Skip to content

Commit 5512c26

Browse files
author
Joseph Zhang
committed
Add ModelBuilder support for JumpStart-provided draft models.
1 parent 9489b8d commit 5512c26

File tree

3 files changed

+100
-31
lines changed

3 files changed

+100
-31
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SPECULATIVE_DRAFT_MODEL,
5050
_is_inferentia_or_trainium,
5151
_validate_and_set_eula_for_draft_model_sources,
52+
_jumpstart_speculative_decoding,
5253
)
5354
from sagemaker.serve.utils.predictors import (
5455
DjlLocalModePredictor,
@@ -503,7 +504,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
503504
)
504505

505506
def set_deployment_config(
506-
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
507+
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
507508
) -> None:
508509
"""Sets the deployment config to apply to the model.
509510
@@ -735,6 +736,10 @@ def _optimize_for_jumpstart(
735736
optimization_config, quantization_override_env, compilation_override_env = (
736737
_extract_optimization_config_and_env(quantization_config, compilation_config)
737738
)
739+
740+
if not optimization_config:
741+
optimization_config = {}
742+
738743
if (
739744
not optimization_config or not optimization_config.get("ModelCompilationConfig")
740745
) and is_compilation:
@@ -844,6 +849,7 @@ def _set_additional_model_source(
844849
"""
845850
if speculative_decoding_config:
846851
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
852+
847853
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
848854

849855
if model_provider == "sagemaker":
@@ -868,17 +874,23 @@ def _set_additional_model_source(
868874
"Cannot find deployment config compatible for optimization job."
869875
)
870876

871-
_validate_and_set_eula_for_draft_model_sources(
872-
pysdk_model=self.pysdk_model,
873-
accept_eula=speculative_decoding_config.get("AcceptEula"),
874-
)
877+
_validate_and_set_eula_for_draft_model_sources(
878+
pysdk_model=self.pysdk_model,
879+
accept_eula=speculative_decoding_config.get("AcceptEula"),
880+
)
875881

876882
self.pysdk_model.env.update(
877-
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"}
883+
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
878884
)
879885
self.pysdk_model.add_tags(
880886
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
881887
)
888+
elif model_provider == "jumpstart":
889+
_jumpstart_speculative_decoding(
890+
model=self.pysdk_model,
891+
speculative_decoding_config=speculative_decoding_config,
892+
sagemaker_session=self.sagemaker_session,
893+
)
882894
else:
883895
self.pysdk_model = _custom_speculative_decoding(
884896
self.pysdk_model, speculative_decoding_config, accept_eula

src/sagemaker/serve/builder/model_builder.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
_is_s3_uri,
7777
_custom_speculative_decoding,
7878
_extract_speculative_draft_model_provider,
79+
_jumpstart_speculative_decoding,
7980
)
8081
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
8182
from sagemaker.serve.utils.hardware_detector import (
@@ -99,7 +100,6 @@
99100
validate_image_uri_and_hardware,
100101
)
101102
from sagemaker.utils import Tags
102-
from sagemaker.serve.utils.optimize_utils import _validate_and_set_eula_for_draft_model_sources
103103
from sagemaker.workflow.entities import PipelineVariable
104104
from sagemaker.huggingface.llm_utils import (
105105
get_huggingface_model_metadata,
@@ -590,21 +590,6 @@ def _model_builder_deploy_wrapper(
590590
model_server=self.model_server,
591591
)
592592

593-
if self.deployment_config:
594-
accept_draft_model_eula = kwargs.get("accept_draft_model_eula", False)
595-
try:
596-
_validate_and_set_eula_for_draft_model_sources(
597-
pysdk_model=self,
598-
accept_eula=accept_draft_model_eula,
599-
)
600-
except ValueError as e:
601-
logger.error(
602-
"This deployment tried to use a gated draft model but the EULA was not "
603-
"accepted. Please review the EULA, set accept_draft_model_eula to True, "
604-
"and try again."
605-
)
606-
raise e
607-
608593
if "endpoint_logging" not in kwargs:
609594
kwargs["endpoint_logging"] = True
610595
predictor = self._original_deploy(
@@ -1358,9 +1343,17 @@ def _optimize_for_hf(
13581343
Returns:
13591344
Optional[Dict[str, Any]]: Model optimization job input arguments.
13601345
"""
1361-
self.pysdk_model = _custom_speculative_decoding(
1362-
self.pysdk_model, speculative_decoding_config, False
1363-
)
1346+
if speculative_decoding_config:
1347+
if speculative_decoding_config.get("ModelProvider", "") == "JumpStart":
1348+
_jumpstart_speculative_decoding(
1349+
model=self.pysdk_model,
1350+
speculative_decoding_config=speculative_decoding_config,
1351+
sagemaker_session=self.sagemaker_session,
1352+
)
1353+
else:
1354+
self.pysdk_model = _custom_speculative_decoding(
1355+
self.pysdk_model, speculative_decoding_config, False
1356+
)
13641357

13651358
if quantization_config or compilation_config:
13661359
create_optimization_job_args = {

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import logging
1818
from typing import Dict, Any, Optional, Union, List, Tuple
1919

20-
from sagemaker import Model
20+
from sagemaker import Model, Session
2121
from sagemaker.enums import Tag
22+
from sagemaker.jumpstart.utils import accessors, get_eula_message
23+
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -164,6 +166,9 @@ def _extract_speculative_draft_model_provider(
164166
if speculative_decoding_config is None:
165167
return None
166168

169+
if speculative_decoding_config.get("ModelProvider") == "JumpStart":
170+
return "jumpstart"
171+
167172
if speculative_decoding_config.get(
168173
"ModelProvider"
169174
) == "Custom" or speculative_decoding_config.get("ModelSource"):
@@ -292,7 +297,7 @@ def _generate_additional_model_data_sources(
292297
},
293298
}
294299
if accept_eula:
295-
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True}
300+
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True}
296301

297302
return [additional_model_data_source]
298303

@@ -327,10 +332,10 @@ def _extract_optimization_config_and_env(
327332
"""
328333
optimization_config = {}
329334
quantization_override_env = (
330-
quantization_config.get("OverrideEnvironment", {}) if quantization_config else None
335+
quantization_config.get("OverrideEnvironment") if quantization_config else None
331336
)
332337
compilation_override_env = (
333-
compilation_config.get("OverrideEnvironment", {}) if compilation_config else None
338+
compilation_config.get("OverrideEnvironment") if compilation_config else None
334339
)
335340

336341
if quantization_config is not None:
@@ -343,7 +348,7 @@ def _extract_optimization_config_and_env(
343348
if optimization_config:
344349
return optimization_config, quantization_override_env, compilation_override_env
345350

346-
return {}, None, None
351+
return None, None, None
347352

348353

349354
def _custom_speculative_decoding(
@@ -364,7 +369,7 @@ def _custom_speculative_decoding(
364369
speculative_decoding_config
365370
)
366371

367-
accept_eula = speculative_decoding_config.get("AcceptEula", False)
372+
accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula)
368373

369374
if _is_s3_uri(additional_model_source):
370375
channel_name = _generate_channel_name(model.additional_model_data_sources)
@@ -384,6 +389,65 @@ def _custom_speculative_decoding(
384389
return model
385390

386391

392+
def _jumpstart_speculative_decoding(
393+
model=Model,
394+
speculative_decoding_config: Optional[Dict[str, Any]] = None,
395+
sagemaker_session: Optional[Session] = None,
396+
):
397+
"""Modifies the given model for speculative decoding config with JumpStart provider.
398+
399+
Args:
400+
model (Model): The model.
401+
speculative_decoding_config (Optional[Dict]): The speculative decoding config.
402+
sagemaker_session (Optional[Session]): Sagemaker session for execution.
403+
"""
404+
if speculative_decoding_config:
405+
js_id = speculative_decoding_config.get("ModelID")
406+
if not js_id:
407+
raise ValueError(
408+
"`ModelID` is a required field in `speculative_decoding_config` when "
409+
"using JumpStart as draft model provider."
410+
)
411+
model_version = speculative_decoding_config.get("ModelVersion", "*")
412+
accept_eula = speculative_decoding_config.get("AcceptEula", False)
413+
channel_name = _generate_channel_name(model.additional_model_data_sources)
414+
415+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
416+
model_id=js_id,
417+
version=model_version,
418+
region=sagemaker_session.boto_region_name,
419+
sagemaker_session=sagemaker_session,
420+
)
421+
model_spec_json = model_specs.to_json()
422+
423+
js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()
424+
425+
if model_spec_json.get("gated_bucket", False):
426+
if not accept_eula:
427+
eula_message = get_eula_message(
428+
model_specs=model_specs, region=sagemaker_session.boto_region_name
429+
)
430+
raise ValueError(
431+
f"{eula_message} Please set `AcceptEula` to True in "
432+
f"speculative_decoding_config once acknowledged."
433+
)
434+
js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
435+
436+
key_prefix = model_spec_json.get("hosting_prepacked_artifact_key")
437+
model.additional_model_data_sources = _generate_additional_model_data_sources(
438+
f"s3://{js_bucket}/{key_prefix}",
439+
channel_name,
440+
accept_eula,
441+
)
442+
443+
model.env.update(
444+
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"}
445+
)
446+
model.add_tags(
447+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"},
448+
)
449+
450+
387451
def _validate_and_set_eula_for_draft_model_sources(
388452
pysdk_model: Model,
389453
accept_eula: bool = False,

0 commit comments

Comments
 (0)