Feat: Optimize() validations across TRT, VLLM, Neuron container optimizations #4927

Merged · 3 commits · Nov 19, 2024
14 changes: 14 additions & 0 deletions src/sagemaker/model.py
@@ -372,6 +372,7 @@ def __init__(
self.endpoint_name = None
self.inference_component_name = None
self._is_compiled_model = False
self._is_sharded_model = False
self._compilation_job_name = None
self._is_edge_packaged_model = False
self.inference_recommender_job_results = None
@@ -1599,6 +1600,19 @@ def deploy(
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._is_sharded_model and self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if endpoint_name:
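These two deploy-time checks are easy to miss in context. Below is a minimal, self-contained sketch of the behavior they add, assuming only a boolean sharded-model flag; the `EndpointType` enum and `resolve_endpoint_type` helper here are illustrative stand-ins, not the SDK's actual classes.

```python
from enum import Enum


class EndpointType(Enum):
    """Illustrative stand-in for the SDK's endpoint type enum."""

    MODEL_BASED = "ModelBased"
    INFERENCE_COMPONENT_BASED = "InferenceComponentBased"


def resolve_endpoint_type(
    is_sharded: bool, network_isolation: bool, endpoint_type: EndpointType
) -> EndpointType:
    # Sharded models are forced onto inference-component-based endpoints.
    if is_sharded and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
        print("WARNING: forcing INFERENCE_COMPONENT_BASED endpoint for sharded model")
        endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
    # Fast Model Loading needs network access, so isolation is rejected outright.
    if is_sharded and network_isolation:
        raise ValueError("EnableNetworkIsolation cannot be True for a sharded model")
    return endpoint_type


# A sharded model requested on a model-based endpoint is silently upgraded.
assert (
    resolve_endpoint_type(True, False, EndpointType.MODEL_BASED)
    is EndpointType.INFERENCE_COMPONENT_BASED
)
```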
23 changes: 20 additions & 3 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -684,6 +684,7 @@ def _optimize_for_jumpstart(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -705,6 +706,8 @@
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -730,8 +733,13 @@
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

# optimization_config can contain configs for both quantization and compilation
optimization_config, quantization_override_env, compilation_override_env = (
_extract_optimization_config_and_env(quantization_config, compilation_config)
(
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
) = _extract_optimization_config_and_env(
quantization_config, compilation_config, sharding_config
)

if not optimization_config:
@@ -807,11 +815,20 @@ def _optimize_for_jumpstart(
{
**(quantization_override_env or {}),
**(compilation_override_env or {}),
**(sharding_override_env or {}),
},
)
if optimization_env_vars:
self.pysdk_model.env.update(optimization_env_vars)
if quantization_config or is_compilation:

if sharding_config and self.pysdk_model._enable_network_isolation:
logger.warning(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access. Setting it to False."
)
self.pysdk_model._enable_network_isolation = False

if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
return None

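One consequence of the dict-unpacking merge in the hunk above: when the same key appears in more than one override environment, the last dictionary unpacked (sharding) wins. In practice the sharding config is mutually exclusive with the other optimizations (see the model_builder.py validation below), so at most one of these dicts is non-empty; the values here are made up purely to show the merge order.

```python
# Made-up override environments; the variable names mirror the diff.
quantization_override_env = None
compilation_override_env = {"OPTION_TENSOR_PARALLEL_DEGREE": "2", "OPTION_DTYPE": "fp8"}
sharding_override_env = {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}

merged = {
    **(quantization_override_env or {}),
    **(compilation_override_env or {}),
    **(sharding_override_env or {}),
}

# Later unpacks win on key collisions, so the sharding value survives.
print(merged)  # {'OPTION_TENSOR_PARALLEL_DEGREE': '8', 'OPTION_DTYPE': 'fp8'}
```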
79 changes: 76 additions & 3 deletions src/sagemaker/serve/builder/model_builder.py
@@ -105,6 +105,7 @@
get_huggingface_model_metadata,
download_huggingface_model_metadata,
)
from sagemaker.serve.validations.optimization import _validate_optimization_configuration

logger = logging.getLogger(__name__)

@@ -1120,6 +1121,7 @@ def optimize(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1143,6 +1145,8 @@
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1213,6 +1219,8 @@
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,27 @@
Returns:
Model: A deployable ``Model`` object.
"""
if (
hasattr(self, "enable_network_isolation")
and self.enable_network_isolation
and sharding_config
):
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

# TODO: ideally these dictionaries need to be sagemaker_core shapes
# TODO: for organization, abstract all validation behind this fn
_validate_optimization_configuration(
is_jumpstart=self._is_jumpstart_model_id(),
instance_type=instance_type,
quantization_config=quantization_config,
compilation_config=compilation_config,
sharding_config=sharding_config,
speculative_decoding_config=speculative_decoding_config,
)

self.is_compiled = compilation_config is not None
self.is_quantized = quantization_config is not None
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
@@ -1236,6 +1265,36 @@
if self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")

if sharding_config and (
quantization_config or compilation_config or speculative_decoding_config
):
raise ValueError(
(
"Sharding config is mutually exclusive "
"and cannot be combined with any other optimization."
)
)

if sharding_config:
has_tensor_parallel_degree_in_env_vars = (
env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
)
has_tensor_parallel_degree_in_overrides = (
sharding_config
and sharding_config.get("OverrideEnvironment")
and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
)
if (
not has_tensor_parallel_degree_in_env_vars
and not has_tensor_parallel_degree_in_overrides
):
raise ValueError(
(
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
"environment variable with sharding config."
)
)

self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
self.instance_type = instance_type or self.instance_type
self.role_arn = role_arn or self.role_arn
@@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
@@ -1270,12 +1330,16 @@
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
max_runtime_in_sec=max_runtime_in_sec,
)

if sharding_config:
self.pysdk_model._is_sharded_model = True

if input_args:
optimization_instance_type = input_args["DeploymentInstanceType"]

@@ -1325,6 +1389,7 @@ def _optimize_for_hf(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
@@ -1340,6 +1405,8 @@
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1430,7 @@
self.pysdk_model, speculative_decoding_config, False
)

if quantization_config or compilation_config:
if quantization_config or compilation_config or sharding_config:
create_optimization_job_args = {
"OptimizationJobName": job_name,
"DeploymentInstanceType": self.instance_type,
@@ -1378,8 +1445,13 @@
model_source = _generate_model_source(self.pysdk_model.model_data, False)
create_optimization_job_args["ModelSource"] = model_source

optimization_config, quantization_override_env, compilation_override_env = (
_extract_optimization_config_and_env(quantization_config, compilation_config)
(
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
) = _extract_optimization_config_and_env(
quantization_config, compilation_config, sharding_config
)
create_optimization_job_args["OptimizationConfigs"] = [
{k: v} for k, v in optimization_config.items()
@@ -1388,6 +1460,7 @@
{
**(quantization_override_env or {}),
**(compilation_override_env or {}),
**(sharding_override_env or {}),
}
)

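Taken together, the new validations in `_model_builder_optimize_wrapper` shape how a caller requests sharding. A rough usage sketch, under stated assumptions: the model id, instance type, and tensor-parallel degree are placeholders, the `ModelBuilder` constructor typically needs more arguments than shown, and `optimize()` is called with only the arguments relevant to this diff.

```python
from sagemaker.serve.builder.model_builder import ModelBuilder

# Hypothetical builder; real construction needs more arguments than this diff shows.
model_builder = ModelBuilder(model="meta-textgeneration-llama-3-1-8b-instruct")

# sharding_config must be the only optimization requested, and
# OPTION_TENSOR_PARALLEL_DEGREE must come from either env_vars or
# sharding_config["OverrideEnvironment"], otherwise optimize() raises ValueError.
optimized_model = model_builder.optimize(
    instance_type="ml.g5.12xlarge",
    sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "4"}},
)

# The returned model carries _is_sharded_model=True, so deploy() will force an
# inference-component-based endpoint (see the model.py change above).
```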
22 changes: 17 additions & 5 deletions src/sagemaker/serve/utils/optimize_utils.py
@@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


def _extract_optimization_config_and_env(
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.

Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.

Returns:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
The optimization config and environment variables.
"""
optimization_config = {}
@@ -380,18 +383,27 @@ def _extract_optimization_config_and_env(
compilation_override_env = (
compilation_config.get("OverrideEnvironment") if compilation_config else None
)
sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None

if quantization_config is not None:
optimization_config["ModelQuantizationConfig"] = quantization_config

if compilation_config is not None:
optimization_config["ModelCompilationConfig"] = compilation_config

if sharding_config is not None:
optimization_config["ModelShardingConfig"] = sharding_config

# Return optimization config dict and environment variables if either is present
if optimization_config:
return optimization_config, quantization_override_env, compilation_override_env
return (
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
)

return None, None, None
return None, None, None, None


def _custom_speculative_decoding(
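To close the loop on the utility change, here is how the widened return contract of `_extract_optimization_config_and_env` behaves when only a sharding config is supplied. The input values are made up, but the unpacking follows directly from the code above: callers now receive a four-element tuple even when only one optimization is requested.

```python
from sagemaker.serve.utils.optimize_utils import _extract_optimization_config_and_env

sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}

(
    optimization_config,
    quantization_override_env,
    compilation_override_env,
    sharding_override_env,
) = _extract_optimization_config_and_env(sharding_config=sharding_config)

print(optimization_config)    # {'ModelShardingConfig': {'OverrideEnvironment': {...}}}
print(sharding_override_env)  # {'OPTION_TENSOR_PARALLEL_DEGREE': '2'}
print(quantization_override_env, compilation_override_env)  # None None
```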