Skip to content

feat: optimization technique related validations. #4921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7ec16e6
Enable quantization and compilation in the same optimization job via …
Sep 18, 2024
cf70f59
Require EULA acceptance when using a gated 1p draft model via ModelBu…
Nov 8, 2024
fcb5092
add accept_draft_model_eula to JumpStartModel when deployment config …
Nov 8, 2024
9489b8d
add map of valid optimization combinations
Nov 8, 2024
5512c26
Add ModelBuilder support for JumpStart-provided draft models.
Nov 9, 2024
c94a78b
Tweak draft model EULA validations and messaging. Remove redundant de…
Nov 9, 2024
d10c475
Add "Auto" speculative decoding ModelProvider option; add validations…
Nov 11, 2024
8fb27a0
Fix JumpStartModel.AdditionalModelDataSource model access config assi…
Nov 12, 2024
779f6d6
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
aef3a90
Merge branch 'master' into QuicksilverV2
gwang111 Nov 12, 2024
b7b15b8
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
748ea4b
Use correct bucket for SM/JS draft models and minor formatting/valida…
Nov 13, 2024
a7feb54
Remove obsolete docstring.
Nov 13, 2024
694b4f2
remove references to accept_draft_model_eula
gwang111 Nov 13, 2024
7b6aef1
renaming of eula fn and error msg
gwang111 Nov 13, 2024
ce47be5
Merge branch 'master' into QuicksilverV2
gwang111 Nov 13, 2024
1f75072
fix: pin testing deps (#4925)
benieric Nov 13, 2024
277e0b1
Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926)
Captainia Nov 13, 2024
8f0083b
fix naming and messaging
gwang111 Nov 14, 2024
8b73f34
ModelBuilder speculative decoding UTs and minor fixes.
Nov 14, 2024
c06aef0
Merge branch 'master' into QuicksilverV2
gwang111 Nov 14, 2024
09a54dc
Fix set union.
Nov 14, 2024
3b147cd
add UTs for JumpStart deployment
gwang111 Nov 15, 2024
65cb5b3
fix formatting issues
gwang111 Nov 15, 2024
4d1e12b
address validation comments
gwang111 Nov 15, 2024
bf706ad
fix doc strings
gwang111 Nov 15, 2024
f121eb0
Add TRTLLM compilation + speculative decoding validation.
Nov 15, 2024
9148e70
address nits
gwang111 Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
get_jumpstart_content_bucket,
)

from sagemaker.jumpstart.factory.utils import (
Expand All @@ -70,7 +71,13 @@

from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
from sagemaker.utils import (
camel_case_to_pascal_case,
name_from_base,
format_tags,
Tags,
get_domain_for_region,
)
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
Expand Down Expand Up @@ -556,6 +563,37 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
return kwargs


def _apply_accept_eula_on_model_data_source(
model_data_source: Dict[str, Any],
model_id: str,
region: str,
accept_eula: bool
):
"""Sets AcceptEula to True for gated speculative decoding models"""

mutable_model_data_source = model_data_source.copy()

hosting_eula_key = mutable_model_data_source.get("hosting_eula_key")
del mutable_model_data_source["hosting_eula_key"]

if not hosting_eula_key:
return mutable_model_data_source

if not accept_eula:
raise ValueError(
(
f"The set deployment config comes optimized with an additional model data source "
f"'{model_id}' that requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"{get_domain_for_region(region)}"
f"/{hosting_eula_key} for terms of use. Please set `accept_eula=True` once acknowledged."
)
)

mutable_model_data_source["model_access_config"] = {"accept_eula": accept_eula}
return mutable_model_data_source


def _add_additional_model_data_sources_to_kwargs(
kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
Expand All @@ -568,7 +606,11 @@ def _add_additional_model_data_sources_to_kwargs(
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
api_shape_additional_model_data_sources = (
[
camel_case_to_pascal_case(data_source.to_json())
camel_case_to_pascal_case(
_apply_accept_eula_on_model_data_source(
data_source.to_json(), kwargs.model_id, kwargs.region, kwargs.accept_draft_model_eula,
)
)
for data_source in speculative_decoding_data_sources
]
if specs.get_speculative_decoding_s3_data_sources()
Expand Down Expand Up @@ -858,6 +900,7 @@ def get_init_kwargs(
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
accept_draft_model_eula: Optional[bool] = None,
) -> JumpStartModelInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -892,6 +935,7 @@ def get_init_kwargs(
resources=resources,
config_name=config_name,
additional_model_data_sources=additional_model_data_sources,
accept_draft_model_eula=accept_draft_model_eula,
)
model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
kwargs=model_init_kwargs
Expand Down
39 changes: 28 additions & 11 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
accept_draft_model_eula: Optional[bool] = None,
):
"""Initializes a ``JumpStartModel``.

Expand Down Expand Up @@ -301,6 +302,10 @@ def __init__(
optionally applied to the model.
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
of SageMaker model data (default: None).
accept_draft_model_eula (bool): For draft models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_draft_model_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
Raises:
ValueError: If the model ID is not recognized by JumpStart.
"""
Expand Down Expand Up @@ -360,6 +365,7 @@ def _validate_model_id_and_type():
resources=resources,
config_name=config_name,
additional_model_data_sources=additional_model_data_sources,
accept_draft_model_eula=accept_draft_model_eula
)

self.orig_predictor_cls = predictor_cls
Expand Down Expand Up @@ -456,7 +462,9 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
sagemaker_session=self.sagemaker_session,
)

def set_deployment_config(self, config_name: str, instance_type: str) -> None:
def set_deployment_config(
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
) -> None:
"""Sets the deployment config to apply to the model.

Args:
Expand All @@ -466,6 +474,8 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
instance_type (str):
The instance_type that the model will use after setting
the config.
accept_draft_model_eula (Optional[bool]):
If the config selected comes with a gated additional model data source.
"""
self.__init__(
model_id=self.model_id,
Expand All @@ -474,6 +484,7 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
config_name=config_name,
sagemaker_session=self.sagemaker_session,
role=self.role,
accept_draft_model_eula=accept_draft_model_eula,
)

@property
Expand Down Expand Up @@ -540,12 +551,16 @@ def attach(
inferred_model_id = inferred_model_version = inferred_inference_component_name = None

if inference_component_name is None or model_id is None or model_version is None:
inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)
(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
_,
_,
) = get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)

model_id = model_id or inferred_model_id
Expand Down Expand Up @@ -1016,10 +1031,11 @@ def _get_deployment_configs(
)

if metadata_config.benchmark_metrics:
err, metadata_config.benchmark_metrics = (
add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)
(
err,
metadata_config.benchmark_metrics,
) = add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)

config_components = metadata_config.config_components.get(config_name)
Expand All @@ -1042,6 +1058,7 @@ def _get_deployment_configs(
region=self.region,
model_version=self.model_version,
hub_arn=self.hub_arn,
accept_draft_model_eula=True,
)
deploy_kwargs = get_deploy_kwargs(
model_id=self.model_id,
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ class AdditionalModelDataSource(JumpStartDataHolderType):

SERIALIZATION_EXCLUSION_SET: Set[str] = set()

__slots__ = ["channel_name", "s3_data_source"]
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a AdditionalModelDataSource object.
Expand All @@ -1101,6 +1101,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""
self.channel_name: str = json_obj["channel_name"]
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")

def to_json(self, exclude_keys=True) -> Dict[str, Any]:
"""Returns json representation of AdditionalModelDataSource object."""
Expand Down Expand Up @@ -2116,6 +2117,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"hub_content_type",
"model_reference_arn",
"specs",
"accept_draft_model_eula",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -2131,6 +2133,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"training_instance_type",
"config_name",
"hub_content_type",
"accept_draft_model_eula",
}

def __init__(
Expand Down Expand Up @@ -2165,6 +2168,7 @@ def __init__(
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
accept_draft_model_eula: Optional[bool] = False
) -> None:
"""Instantiates JumpStartModelInitKwargs object."""

Expand Down Expand Up @@ -2198,6 +2202,7 @@ def __init__(
self.resources = resources
self.config_name = config_name
self.additional_model_data_sources = additional_model_data_sources
self.accept_draft_model_eula = accept_draft_model_eula


class JumpStartModelDeployKwargs(JumpStartKwargs):
Expand Down
54 changes: 39 additions & 15 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
_custom_speculative_decoding,
SPECULATIVE_DRAFT_MODEL,
_is_inferentia_or_trainium,
_validate_and_set_eula_for_draft_model_sources,
)
from sagemaker.serve.utils.predictors import (
DjlLocalModePredictor,
Expand Down Expand Up @@ -501,7 +502,9 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def set_deployment_config(self, config_name: str, instance_type: str) -> None:
def set_deployment_config(
self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False
) -> None:
"""Sets the deployment config to apply to the model.

Args:
Expand All @@ -511,11 +514,13 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
instance_type (str):
The instance_type that the model will use after setting
the config.
accept_draft_model_eula (Optional[bool]):
If the config selected comes with a gated additional model data source.
"""
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
raise Exception("Cannot set deployment config to an uninitialized model.")

self.pysdk_model.set_deployment_config(config_name, instance_type)
self.pysdk_model.set_deployment_config(config_name, instance_type, accept_draft_model_eula)
self.deployment_config_name = config_name

self.instance_type = instance_type
Expand Down Expand Up @@ -718,24 +723,32 @@ def _optimize_for_jumpstart(
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
)

is_compilation = (not quantization_config) and (
(compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium(
instance_type
)

pysdk_model_env_vars = dict()
if is_compilation:
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

optimization_config, override_env = _extract_optimization_config_and_env(
quantization_config, compilation_config
# optimization_config can contain configs for both quantization and compilation
optimization_config, quantization_override_env, compilation_override_env = (
_extract_optimization_config_and_env(quantization_config, compilation_config)
)
if not optimization_config and is_compilation:
override_env = override_env or pysdk_model_env_vars
optimization_config = {
"ModelCompilationConfig": {
"OverrideEnvironment": override_env,
}
}
if (
not optimization_config or not optimization_config.get("ModelCompilationConfig")
) and is_compilation:
# Fallback to default if override_env is None or empty
if not compilation_override_env:
compilation_override_env = pysdk_model_env_vars

# Update optimization_config with ModelCompilationConfig
override_compilation_config = (
{"OverrideEnvironment": compilation_override_env}
if compilation_override_env
else {}
)
optimization_config["ModelCompilationConfig"] = override_compilation_config

if speculative_decoding_config:
self._set_additional_model_source(speculative_decoding_config)
Expand Down Expand Up @@ -766,7 +779,7 @@ def _optimize_for_jumpstart(
"OptimizationJobName": job_name,
"ModelSource": model_source,
"DeploymentInstanceType": self.instance_type,
"OptimizationConfigs": [optimization_config],
"OptimizationConfigs": [{k: v} for k, v in optimization_config.items()],
"OutputConfig": output_config,
"RoleArn": self.role_arn,
}
Expand All @@ -789,7 +802,13 @@ def _optimize_for_jumpstart(
"AcceptEula": True
}

optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
optimization_env_vars = _update_environment_variables(
optimization_env_vars,
{
**(quantization_override_env or {}),
**(compilation_override_env or {}),
},
)
if optimization_env_vars:
self.pysdk_model.env.update(optimization_env_vars)
if quantization_config or is_compilation:
Expand Down Expand Up @@ -849,6 +868,11 @@ def _set_additional_model_source(
"Cannot find deployment config compatible for optimization job."
)

_validate_and_set_eula_for_draft_model_sources(
pysdk_model=self.pysdk_model,
accept_eula=speculative_decoding_config.get("AcceptEula"),
)

self.pysdk_model.env.update(
{"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"}
)
Expand Down
Loading
Loading