Skip to content

feat: optimization technique related validations. #4921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7ec16e6
Enable quantization and compilation in the same optimization job via …
Sep 18, 2024
cf70f59
Require EULA acceptance when using a gated 1p draft model via ModelBu…
Nov 8, 2024
fcb5092
add accept_draft_model_eula to JumpStartModel when deployment config …
Nov 8, 2024
9489b8d
add map of valid optimization combinations
Nov 8, 2024
5512c26
Add ModelBuilder support for JumpStart-provided draft models.
Nov 9, 2024
c94a78b
Tweak draft model EULA validations and messaging. Remove redundant de…
Nov 9, 2024
d10c475
Add "Auto" speculative decoding ModelProvider option; add validations…
Nov 11, 2024
8fb27a0
Fix JumpStartModel.AdditionalModelDataSource model access config assi…
Nov 12, 2024
779f6d6
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
aef3a90
Merge branch 'master' into QuicksilverV2
gwang111 Nov 12, 2024
b7b15b8
move the accept eula configurations into deploy flow
gwang111 Nov 12, 2024
748ea4b
Use correct bucket for SM/JS draft models and minor formatting/valida…
Nov 13, 2024
a7feb54
Remove obsolete docstring.
Nov 13, 2024
694b4f2
remove references to accept_draft_model_eula
gwang111 Nov 13, 2024
7b6aef1
renaming of eula fn and error msg
gwang111 Nov 13, 2024
ce47be5
Merge branch 'master' into QuicksilverV2
gwang111 Nov 13, 2024
1f75072
fix: pin testing deps (#4925)
benieric Nov 13, 2024
277e0b1
Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926)
Captainia Nov 13, 2024
8f0083b
fix naming and messaging
gwang111 Nov 14, 2024
8b73f34
ModelBuilder speculative decoding UTs and minor fixes.
Nov 14, 2024
c06aef0
Merge branch 'master' into QuicksilverV2
gwang111 Nov 14, 2024
09a54dc
Fix set union.
Nov 14, 2024
3b147cd
add UTs for JumpStart deployment
gwang111 Nov 15, 2024
65cb5b3
fix formatting issues
gwang111 Nov 15, 2024
4d1e12b
address validation comments
gwang111 Nov 15, 2024
bf706ad
fix doc strings
gwang111 Nov 15, 2024
f121eb0
Add TRTLLM compilation + speculative decoding validation.
Nov 15, 2024
9148e70
address nits
gwang111 Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def deploy(
managed_instance_scaling: Optional[str] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[List[ModelAccessConfig]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.

Expand Down Expand Up @@ -766,7 +766,7 @@ def deploy(
ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }`
to indicate whether model terms of use have been accepted. The `accept_eula` value
must be explicitly defined as `True` in order to accept the end-user license
agreement (EULA) that some. (Default: None)
agreement (EULA) that some models require. (Default: None)

Raises:
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
Expand Down Expand Up @@ -817,12 +817,14 @@ def deploy(
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
)

print(self.additional_model_data_sources)
self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
self.additional_model_data_sources,
deploy_kwargs.model_access_configs,
deploy_kwargs.model_id,
deploy_kwargs.region,
)
print(self.additional_model_data_sources)

try:
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,9 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
class JumpStartModelDataSource(AdditionalModelDataSource):
"""Data class JumpStart additional model data source."""

SERIALIZATION_EXCLUSION_SET = {
"artifact_version"
} | AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET
SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
{"artifact_version"}
)

__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

Expand Down
45 changes: 32 additions & 13 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
"""Returns EULA message to display if one is available, else empty string."""
if model_specs.hosting_eula_key is None:
return ""
return format_eula_message_template(
return get_formatted_eula_message_template(
model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
)


def format_eula_message_template(model_id: str, region: str, hosting_eula_key: str):
def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
"""Returns a formatted EULA message."""
return (
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
Expand Down Expand Up @@ -1542,17 +1542,32 @@ def _add_model_access_configs_to_model_data_sources(
model_access_configs: Dict[str, ModelAccessConfig],
model_id: str,
region: str,
):
"""Sets AcceptEula to True for gated speculative decoding models"""
) -> List[Dict[str, any]]:
"""Iterate over the accept EULA configs to ensure all channels are matched

Args:
model_data_sources (List[Dict[str, Any]]): Model data sources that will be updated.
model_access_configs (Dict[str, ModelAccessConfig]): Configs holding the accept_eula field.
model_id (str): JumpStart model id.
region (str): Region the user is operating in.
Returns:
List[Dict[str, Any]]: List of model data sources with accept EULA configs applied
Raises:
ValueError: If at least one channel that requires EULA acceptance was not passed.
"""
if not model_data_sources:
return model_data_sources

acked_model_data_sources = []
for model_data_source in model_data_sources:
hosting_eula_key = model_data_source.get("HostingEulaKey")
mutable_model_data_source = model_data_source.copy()
if hosting_eula_key:
if not model_access_configs or not model_access_configs.get(model_id):
if (
not model_access_configs
or not model_access_configs.get(model_id)
or not model_access_configs.get(model_id).accept_eula
):
eula_message_template = (
"{model_source}{base_eula_message}{model_access_configs_message}"
)
Expand All @@ -1562,24 +1577,28 @@ def _add_model_access_configs_to_model_data_sources(
raise ValueError(
eula_message_template.format(
model_source="Additional " if model_data_source.get("ChannelName") else "",
base_eula_message=format_eula_message_template(
base_eula_message=get_formatted_eula_message_template(
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
),
model_access_configs_message=(
" Please add a ModelAccessConfig entry:"
"Please add a ModelAccessConfig entry:"
f" {model_access_config_entry} "
"to model_access_configs to acknowledge the EULA."
"to model_access_configs to accept the EULA."
),
)
)
acked_model_data_source = model_data_source.copy()
acked_model_data_source.pop("HostingEulaKey")
acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is applied
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
)
acked_model_data_sources.append(acked_model_data_source)
acked_model_data_sources.append(mutable_model_data_source)
else:
acked_model_data_sources.append(model_data_source)
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is not applicable
acked_model_data_sources.append(mutable_model_data_source)
return acked_model_data_sources


Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,7 @@ def _optimize_for_jumpstart(
if not optimization_config:
optimization_config = {}

if (
not optimization_config or not optimization_config.get("ModelCompilationConfig")
) and is_compilation:
if not optimization_config.get("ModelCompilationConfig") and is_compilation:
# Fallback to default if override_env is None or empty
if not compilation_override_env:
compilation_override_env = pysdk_model_env_vars
Expand Down Expand Up @@ -907,7 +905,9 @@ def _set_additional_model_source(
)
else:
self.pysdk_model = _custom_speculative_decoding(
self.pysdk_model, speculative_decoding_config, speculative_decoding_config.get("AcceptEula", False)
self.pysdk_model,
speculative_decoding_config,
speculative_decoding_config.get("AcceptEula", False),
)

def _find_compatible_deployment_config(
Expand Down
16 changes: 12 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def _model_builder_deploy_wrapper(
)

if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True
kwargs["endpoint_logging"] = False
predictor = self._original_deploy(
*args,
instance_type=instance_type,
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def _model_builder_optimize_wrapper(
# TRTLLM is used by Neo if the following are provided:
# 1) a GPU instance type
# 2) compilation config
gpu_instance_families = ["g4", "g5", "p4d"]
gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"]
is_gpu_instance = optimization_instance_type and any(
gpu_instance_family in optimization_instance_type
for gpu_instance_family in gpu_instance_families
Expand All @@ -1296,8 +1296,16 @@ def _model_builder_optimize_wrapper(
keyword in self.model.lower() for keyword in llama_3_1_keywords
)

if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled:
raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")
if is_gpu_instance and self.model and self.is_compiled:
if is_llama_3_1:
raise ValueError(
"Compilation is not supported for Llama-3.1 with a GPU instance."
)
if speculative_decoding_config:
raise ValueError(
"Compilation is not supported with speculative decoding with "
"a GPU instance."
)

self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
Expand Down
23 changes: 11 additions & 12 deletions src/sagemaker/serve/utils/optimize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -
return False
deployment_args = deployment_config.get("DeploymentArgs", {})
additional_data_sources = deployment_args.get("AdditionalDataSources")
if not additional_data_sources:
return False
return additional_data_sources.get("speculative_decoding", False)

return "speculative_decoding" in additional_data_sources if additional_data_sources else False


def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool:
Expand Down Expand Up @@ -207,15 +206,15 @@ def _extract_speculative_draft_model_provider(
if speculative_decoding_config is None:
return None

if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart":
model_provider = speculative_decoding_config.get("ModelProvider", "").lower()

if model_provider == "jumpstart":
return "jumpstart"

if speculative_decoding_config.get(
"ModelProvider"
).lower() == "custom" or speculative_decoding_config.get("ModelSource"):
if model_provider == "custom" or speculative_decoding_config.get("ModelSource"):
return "custom"

if speculative_decoding_config.get("ModelProvider").lower() == "sagemaker":
if model_provider == "sagemaker":
return "sagemaker"

return "auto"
Expand All @@ -238,7 +237,7 @@ def _extract_additional_model_data_source_s3_uri(
):
return None

return additional_model_data_source.get("S3DataSource").get("S3Uri", None)
return additional_model_data_source.get("S3DataSource").get("S3Uri")


def _extract_deployment_config_additional_model_data_source_s3_uri(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate of _extract_additional_model_data_source_s3_uri ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deployment config uses Pascal case while the PySDK model will use snake case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack.

ToDo: We need to find a different way of closing these differences.

Expand Down Expand Up @@ -272,7 +271,7 @@ def _is_draft_model_gated(
Returns:
bool: Whether the draft model is gated or not.
"""
return draft_model_config.get("hosting_eula_key", None)
return "hosting_eula_key" in draft_model_config if draft_model_config else False


def _extracts_and_validates_speculative_model_source(
Expand Down Expand Up @@ -371,7 +370,7 @@ def _extract_optimization_config_and_env(
compilation_config (Optional[Dict]): The compilation config.

Returns:
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
The optimization config and environment variables.
"""
optimization_config = {}
Expand All @@ -388,7 +387,7 @@ def _extract_optimization_config_and_env(
if compilation_config is not None:
optimization_config["ModelCompilationConfig"] = compilation_config

# Return both dicts and environment variable if either is present
# Return optimization config dict and environment variables if either is present
if optimization_config:
return optimization_config, quantization_override_env, compilation_override_env

Expand Down

This file was deleted.

Loading
Loading