6 changes: 6 additions & 0 deletions src/sagemaker/model.py

@@ -372,6 +372,7 @@ def __init__(
         self.endpoint_name = None
         self.inference_component_name = None
         self._is_compiled_model = False
+        self._is_sharded_model = False
         self._compilation_job_name = None
         self._is_edge_packaged_model = False
         self.inference_recommender_job_results = None

@@ -1599,6 +1600,11 @@ def deploy(
         if self._base_name is not None:
             self._base_name = "-".join((self._base_name, compiled_model_suffix))

+        if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
+            logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
+                            "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
+            endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
+
         # Support multiple models on same endpoint
         if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
             if endpoint_name:
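For orientation, a minimal sketch of what this model.py change does at deploy time. `Model`, `EndpointType`, and `deploy()` are real SDK names; the image URI, model data path, role, and instance type are placeholders, and `_is_sharded_model` is normally set by the optimization flow rather than by hand (real inference-component deployments typically also pass compute `resources`):

    from sagemaker.model import Model
    from sagemaker.enums import EndpointType

    # Placeholder values; substitute real artifacts and an execution role.
    model = Model(
        image_uri="<inference-image-uri>",
        model_data="s3://<bucket>/<sharded-model-artifacts>/",
        role="<execution-role-arn>",
    )
    model._is_sharded_model = True  # set by ModelBuilder.optimize(sharding_config=...) in practice

    # With the new guard, a MODEL_BASED request is overridden: deploy() logs a
    # warning and creates an INFERENCE_COMPONENT_BASED endpoint instead.
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.48xlarge",
        endpoint_type=EndpointType.MODEL_BASED,
    )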
7 changes: 5 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py

@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,

@@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.

@@ -727,7 +730,7 @@ def _optimize_for_jumpstart(
             pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

         optimization_config, override_env = _extract_optimization_config_and_env(
-            quantization_config, compilation_config
+            quantization_config, compilation_config, sharding_config
         )
         if not optimization_config and is_compilation:
             override_env = override_env or pysdk_model_env_vars

@@ -792,7 +795,7 @@ def _optimize_for_jumpstart(
         optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
         if optimization_env_vars:
             self.pysdk_model.env.update(optimization_env_vars)
-        if quantization_config or is_compilation:
+        if quantization_config or sharding_config or is_compilation:
             return create_optimization_job_args
         return None
23 changes: 22 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py

@@ -1119,6 +1119,7 @@ def optimize(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,

@@ -1142,6 +1143,8 @@ def optimize(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.

@@ -1170,6 +1173,7 @@ def optimize(
             quantization_config=quantization_config,
             compilation_config=compilation_config,
             speculative_decoding_config=speculative_decoding_config,
+            sharding_config=sharding_config,
             env_vars=env_vars,
             vpc_config=vpc_config,
             kms_key=kms_key,

@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,

@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
         if quantization_config and compilation_config:
             raise ValueError("Quantization config and compilation config are mutually exclusive.")

+        if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
+            raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
+
+        if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):
+            raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.")
+
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
         self.instance_type = instance_type or self.instance_type
         self.role_arn = role_arn or self.role_arn

Review comment on the OPTION_TENSOR_PARALLEL_DEGREE check above: "side note, the same validation is also performed in NeoLambda"
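To make the two validations concrete, a brief sketch mirroring the unit tests added later in this PR (the `model_builder` instance and values are hypothetical):

    # Raises ValueError: sharding is mutually exclusive with other optimizations.
    model_builder.optimize(
        quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
        sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}},
    )

    # Raises ValueError: OverrideEnvironment is present but lacks
    # OPTION_TENSOR_PARALLEL_DEGREE.
    model_builder.optimize(
        sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
    )

Note that, as written, the check appears not to raise when neither `env_vars` nor `OverrideEnvironment` is supplied at all; per the review comment above, the same validation is also performed downstream in NeoLambda.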
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
                 quantization_config=quantization_config,
                 compilation_config=compilation_config,
                 speculative_decoding_config=speculative_decoding_config,
+                sharding_config=sharding_config,
                 env_vars=env_vars,
                 vpc_config=vpc_config,
                 kms_key=kms_key,

@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
                 quantization_config=quantization_config,
                 compilation_config=compilation_config,
                 speculative_decoding_config=speculative_decoding_config,
+                sharding_config=sharding_config,
                 env_vars=env_vars,
                 vpc_config=vpc_config,
                 kms_key=kms_key,

@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
            if not speculative_decoding_config:
                self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)

+        if sharding_config:
+            self.pysdk_model._is_sharded_model = True
+
        return self.pysdk_model

    def _optimize_for_hf(

@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
        quantization_config: Optional[Dict] = None,
        compilation_config: Optional[Dict] = None,
        speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
        env_vars: Optional[Dict] = None,
        vpc_config: Optional[Dict] = None,
        kms_key: Optional[str] = None,

@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
            compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
            env_vars (Optional[Dict]): Additional environment variables to run the optimization
                container. Defaults to ``None``.
            vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.

@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
            self.pysdk_model, speculative_decoding_config, False
        )

-        if quantization_config or compilation_config:
+        if quantization_config or compilation_config or sharding_config:
            create_optimization_job_args = {
                "OptimizationJobName": job_name,
                "DeploymentInstanceType": self.instance_type,
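End to end, the new argument is used roughly as follows. This is a hedged sketch: `ModelBuilder` and `optimize()` are real SDK names and the model ID matches the tests in this PR, but the instance type, output path, and parallel degree are placeholders, and other arguments a real call may require (role ARN, EULA acceptance) are omitted for brevity:

    from sagemaker.serve.builder.model_builder import ModelBuilder

    model_builder = ModelBuilder(model="meta-textgeneration-llama-3-70b")

    optimized_model = model_builder.optimize(
        instance_type="ml.g5.48xlarge",                  # placeholder
        output_path="s3://<bucket>/sharded-artifacts/",  # placeholder
        sharding_config={
            "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}
        },
    )

    # The wrapper flags the returned model so that deploy() later forces an
    # inference-component-based endpoint (see the model.py change above).
    assert optimized_model._is_sharded_model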
8 changes: 7 additions & 1 deletion src/sagemaker/serve/utils/optimize_utils.py

@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


 def _extract_optimization_config_and_env(
-    quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
+    quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
+    sharding_config: Optional[Dict] = None
 ) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
     """Extracts optimization config and environment variables.

     Args:
         quantization_config (Optional[Dict]): The quantization config.
         compilation_config (Optional[Dict]): The compilation config.
+        sharding_config (Optional[Dict]): The sharding config.

     Returns:
         Optional[Tuple[Optional[Dict], Optional[Dict]]]:

@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
        return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
            "OverrideEnvironment"
        )
+    if sharding_config:
+        return {"ModelShardingConfig": sharding_config}, sharding_config.get(
+            "OverrideEnvironment"
+        )
    return None, None
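A quick sketch of the helper's new branch; this mirrors the parametrized test case added below, so only the import path is assumed from the diff:

    from sagemaker.serve.utils.optimize_utils import _extract_optimization_config_and_env

    config, env = _extract_optimization_config_and_env(
        sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}
    )
    # config == {"ModelShardingConfig": {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}}
    # env == {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}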
34 changes: 34 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py

@@ -2667,6 +2667,40 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
            ),
        )

+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Sharding config is mutually exclusive and cannot be combined with any other optimization.",
+            lambda: model_builder.optimize(
+                quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+                compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
+            lambda: model_builder.optimize(
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
    @patch.object(ModelBuilder, "_prepare_for_mode")
    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    def test_optimize_for_hf_with_custom_s3_path(
25 changes: 23 additions & 2 deletions tests/unit/sagemaker/serve/utils/test_optimize_utils.py

@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):


 @pytest.mark.parametrize(
-    "quantization_config, compilation_config, expected_config, expected_env",
+    "quantization_config, compilation_config, sharding_config, expected_config, expected_env",
     [
         (
             None,

@@ -270,6 +270,7 @@ def test_is_s3_uri(s3_uri, expected):
                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                }
            },
+            None,
            {
                "ModelCompilationConfig": {
                    "OverrideEnvironment": {

@@ -288,6 +289,7 @@ def test_is_s3_uri(s3_uri, expected):
                }
            },
            None,
+            None,
            {
                "ModelQuantizationConfig": {
                    "OverrideEnvironment": {

@@ -299,7 +301,26 @@ def test_is_s3_uri(s3_uri, expected):
                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
            },
        ),
-        (None, None, None, None),
+        (
+            None,
+            None,
+            {
+                "OverrideEnvironment": {
+                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+                }
+            },
+            {
+                "ModelShardingConfig": {
+                    "OverrideEnvironment": {
+                        "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+                    }
+                },
+            },
+            {
+                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+            },
+        ),
+        (None, None, None, None, None),
    ],
)
def test_extract_optimization_config_and_env(