Skip to content

Commit f5b568a

Browse files
Ashish Gupta authored and gwang111 committed
changes for blackbird - model sharding
changes for blackbird - model sharding
- add more tests
- fix sharded model flag
- add optimization validations
- fix formatting and msging
- fixing validation bugs
- add UTs
- simplify logic
- update messaging
- formatting
- fix UTs
- add more UTs
- fix validations
- update ruleset
- update formatting
- update validation logic
- update
- bug fixes
- Disable network isolation if using sharded models.
- check sharding + network iso pre optimization
- add more UTs for sharding
- add more UTs
1 parent 64e138b commit f5b568a

File tree

8 files changed

+1138
-22
lines changed

8 files changed

+1138
-22
lines changed

src/sagemaker/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def __init__(
372372
self.endpoint_name = None
373373
self.inference_component_name = None
374374
self._is_compiled_model = False
375+
self._is_sharded_model = False
375376
self._compilation_job_name = None
376377
self._is_edge_packaged_model = False
377378
self.inference_recommender_job_results = None
@@ -1599,6 +1600,19 @@ def deploy(
15991600
if self._base_name is not None:
16001601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16011602

1603+
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604+
logging.warning(
1605+
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1606+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1607+
)
1608+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
1609+
1610+
if self._is_sharded_model and self._enable_network_isolation:
1611+
raise ValueError(
1612+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1613+
"Loading of model requires network access."
1614+
)
1615+
16021616
# Support multiple models on same endpoint
16031617
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16041618
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def _optimize_for_jumpstart(
684684
quantization_config: Optional[Dict] = None,
685685
compilation_config: Optional[Dict] = None,
686686
speculative_decoding_config: Optional[Dict] = None,
687+
sharding_config: Optional[Dict] = None,
687688
env_vars: Optional[Dict] = None,
688689
vpc_config: Optional[Dict] = None,
689690
kms_key: Optional[str] = None,
@@ -705,6 +706,8 @@ def _optimize_for_jumpstart(
705706
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
706707
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
707708
Defaults to ``None``
709+
sharding_config (Optional[Dict]): Model sharding configuration.
710+
Defaults to ``None``
708711
env_vars (Optional[Dict]): Additional environment variables to run the optimization
709712
container. Defaults to ``None``.
710713
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -730,8 +733,13 @@ def _optimize_for_jumpstart(
730733
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
731734

732735
# optimization_config can contain configs for both quantization and compilation
733-
optimization_config, quantization_override_env, compilation_override_env = (
734-
_extract_optimization_config_and_env(quantization_config, compilation_config)
736+
(
737+
optimization_config,
738+
quantization_override_env,
739+
compilation_override_env,
740+
sharding_override_env,
741+
) = _extract_optimization_config_and_env(
742+
quantization_config, compilation_config, sharding_config
735743
)
736744

737745
if not optimization_config:
@@ -807,11 +815,20 @@ def _optimize_for_jumpstart(
807815
{
808816
**(quantization_override_env or {}),
809817
**(compilation_override_env or {}),
818+
**(sharding_override_env or {}),
810819
},
811820
)
812821
if optimization_env_vars:
813822
self.pysdk_model.env.update(optimization_env_vars)
814-
if quantization_config or is_compilation:
823+
824+
if sharding_config and self.pysdk_model._enable_network_isolation:
825+
logger.warning(
826+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
827+
"Loading of model requires network access. Setting it to False."
828+
)
829+
self.pysdk_model._enable_network_isolation = False
830+
831+
if quantization_config or sharding_config or is_compilation:
815832
return create_optimization_job_args
816833
return None
817834

src/sagemaker/serve/builder/model_builder.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
get_huggingface_model_metadata,
106106
download_huggingface_model_metadata,
107107
)
108+
from sagemaker.serve.validations.optimization import _validate_optimization_configuration
108109

109110
logger = logging.getLogger(__name__)
110111

@@ -1120,6 +1121,7 @@ def optimize(
11201121
quantization_config: Optional[Dict] = None,
11211122
compilation_config: Optional[Dict] = None,
11221123
speculative_decoding_config: Optional[Dict] = None,
1124+
sharding_config: Optional[Dict] = None,
11231125
env_vars: Optional[Dict] = None,
11241126
vpc_config: Optional[Dict] = None,
11251127
kms_key: Optional[str] = None,
@@ -1143,6 +1145,8 @@ def optimize(
11431145
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11441146
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11451147
Defaults to ``None``
1148+
sharding_config (Optional[Dict]): Model sharding configuration.
1149+
Defaults to ``None``
11461150
env_vars (Optional[Dict]): Additional environment variables to run the optimization
11471151
container. Defaults to ``None``.
11481152
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
11711175
quantization_config=quantization_config,
11721176
compilation_config=compilation_config,
11731177
speculative_decoding_config=speculative_decoding_config,
1178+
sharding_config=sharding_config,
11741179
env_vars=env_vars,
11751180
vpc_config=vpc_config,
11761181
kms_key=kms_key,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
11901195
quantization_config: Optional[Dict] = None,
11911196
compilation_config: Optional[Dict] = None,
11921197
speculative_decoding_config: Optional[Dict] = None,
1198+
sharding_config: Optional[Dict] = None,
11931199
env_vars: Optional[Dict] = None,
11941200
vpc_config: Optional[Dict] = None,
11951201
kms_key: Optional[str] = None,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
12131219
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12141220
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12151221
Defaults to ``None``
1222+
sharding_config (Optional[Dict]): Model sharding configuration.
1223+
Defaults to ``None``
12161224
env_vars (Optional[Dict]): Additional environment variables to run the optimization
12171225
container. Defaults to ``None``.
12181226
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,26 @@ def _model_builder_optimize_wrapper(
12271235
Returns:
12281236
Model: A deployable ``Model`` object.
12291237
"""
1238+
if (
1239+
hasattr(self, "enable_network_isolation")
1240+
and self.enable_network_isolation
1241+
and sharding_config
1242+
):
1243+
raise ValueError(
1244+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245+
"Loading of model requires network access."
1246+
)
1247+
1248+
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1249+
# TODO: for organization, abstract all validation behind this fn
1250+
_validate_optimization_configuration(
1251+
instance_type=instance_type,
1252+
quantization_config=quantization_config,
1253+
compilation_config=compilation_config,
1254+
sharding_config=sharding_config,
1255+
speculative_decoding_config=speculative_decoding_config,
1256+
)
1257+
12301258
self.is_compiled = compilation_config is not None
12311259
self.is_quantized = quantization_config is not None
12321260
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
@@ -1236,6 +1264,43 @@ def _model_builder_optimize_wrapper(
12361264
if self.mode != Mode.SAGEMAKER_ENDPOINT:
12371265
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
12381266

1267+
if sharding_config and (
1268+
quantization_config or compilation_config or speculative_decoding_config
1269+
):
1270+
raise ValueError(
1271+
"Sharding config is mutually exclusive and cannot be combined with any other optimization."
1272+
)
1273+
1274+
if sharding_config and (
1275+
quantization_config or compilation_config or speculative_decoding_config
1276+
):
1277+
raise ValueError(
1278+
(
1279+
"Sharding config is mutually exclusive "
1280+
"and cannot be combined with any other optimization."
1281+
)
1282+
)
1283+
1284+
if sharding_config:
1285+
has_tensor_parallel_degree_in_env_vars = (
1286+
env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1287+
)
1288+
has_tensor_parallel_degree_in_overrides = (
1289+
sharding_config
1290+
and sharding_config.get("OverrideEnvironment")
1291+
and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
1292+
)
1293+
if (
1294+
not has_tensor_parallel_degree_in_env_vars
1295+
and not has_tensor_parallel_degree_in_overrides
1296+
):
1297+
raise ValueError(
1298+
(
1299+
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
1300+
"environment variable with sharding config."
1301+
)
1302+
)
1303+
12391304
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12401305
self.instance_type = instance_type or self.instance_type
12411306
self.role_arn = role_arn or self.role_arn
@@ -1252,6 +1317,7 @@ def _model_builder_optimize_wrapper(
12521317
quantization_config=quantization_config,
12531318
compilation_config=compilation_config,
12541319
speculative_decoding_config=speculative_decoding_config,
1320+
sharding_config=sharding_config,
12551321
env_vars=env_vars,
12561322
vpc_config=vpc_config,
12571323
kms_key=kms_key,
@@ -1270,12 +1336,16 @@ def _model_builder_optimize_wrapper(
12701336
quantization_config=quantization_config,
12711337
compilation_config=compilation_config,
12721338
speculative_decoding_config=speculative_decoding_config,
1339+
sharding_config=sharding_config,
12731340
env_vars=env_vars,
12741341
vpc_config=vpc_config,
12751342
kms_key=kms_key,
12761343
max_runtime_in_sec=max_runtime_in_sec,
12771344
)
12781345

1346+
if sharding_config:
1347+
self.pysdk_model._is_sharded_model = True
1348+
12791349
if input_args:
12801350
optimization_instance_type = input_args["DeploymentInstanceType"]
12811351

@@ -1325,6 +1395,7 @@ def _optimize_for_hf(
13251395
quantization_config: Optional[Dict] = None,
13261396
compilation_config: Optional[Dict] = None,
13271397
speculative_decoding_config: Optional[Dict] = None,
1398+
sharding_config: Optional[Dict] = None,
13281399
env_vars: Optional[Dict] = None,
13291400
vpc_config: Optional[Dict] = None,
13301401
kms_key: Optional[str] = None,
@@ -1340,6 +1411,8 @@ def _optimize_for_hf(
13401411
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13411412
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13421413
Defaults to ``None``
1414+
sharding_config (Optional[Dict]): Model sharding configuration.
1415+
Defaults to ``None``
13431416
env_vars (Optional[Dict]): Additional environment variables to run the optimization
13441417
container. Defaults to ``None``.
13451418
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1436,7 @@ def _optimize_for_hf(
13631436
self.pysdk_model, speculative_decoding_config, False
13641437
)
13651438

1366-
if quantization_config or compilation_config:
1439+
if quantization_config or compilation_config or sharding_config:
13671440
create_optimization_job_args = {
13681441
"OptimizationJobName": job_name,
13691442
"DeploymentInstanceType": self.instance_type,
@@ -1378,16 +1451,20 @@ def _optimize_for_hf(
13781451
model_source = _generate_model_source(self.pysdk_model.model_data, False)
13791452
create_optimization_job_args["ModelSource"] = model_source
13801453

1381-
optimization_config, quantization_override_env, compilation_override_env = (
1382-
_extract_optimization_config_and_env(quantization_config, compilation_config)
1383-
)
1454+
(
1455+
optimization_config,
1456+
quantization_override_env,
1457+
compilation_override_env,
1458+
sharding_override_env,
1459+
) = _extract_optimization_config_and_env(quantization_config, compilation_config)
13841460
create_optimization_job_args["OptimizationConfigs"] = [
13851461
{k: v} for k, v in optimization_config.items()
13861462
]
13871463
self.pysdk_model.env.update(
13881464
{
13891465
**(quantization_override_env or {}),
13901466
**(compilation_override_env or {}),
1467+
**(sharding_override_env or {}),
13911468
}
13921469
)
13931470

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
361361

362362

363363
def _extract_optimization_config_and_env(
364-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
365-
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
364+
quantization_config: Optional[Dict] = None,
365+
compilation_config: Optional[Dict] = None,
366+
sharding_config: Optional[Dict] = None,
367+
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
366368
"""Extracts optimization config and environment variables.
367369
368370
Args:
369371
quantization_config (Optional[Dict]): The quantization config.
370372
compilation_config (Optional[Dict]): The compilation config.
373+
sharding_config (Optional[Dict]): The sharding config.
371374
372375
Returns:
373-
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
376+
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
374377
The optimization config and environment variables.
375378
"""
376379
optimization_config = {}
@@ -380,19 +383,27 @@ def _extract_optimization_config_and_env(
380383
compilation_override_env = (
381384
compilation_config.get("OverrideEnvironment") if compilation_config else None
382385
)
386+
sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None
383387

384388
if quantization_config is not None:
385389
optimization_config["ModelQuantizationConfig"] = quantization_config
386390

387391
if compilation_config is not None:
388392
optimization_config["ModelCompilationConfig"] = compilation_config
389393

394+
if sharding_config is not None:
395+
optimization_config["ModelShardingConfig"] = sharding_config
396+
390397
# Return optimization config dict and environment variables if either is present
391398
if optimization_config:
392-
return optimization_config, quantization_override_env, compilation_override_env
393-
394-
return None, None, None
399+
return (
400+
optimization_config,
401+
quantization_override_env,
402+
compilation_override_env,
403+
sharding_override_env,
404+
)
395405

406+
return None, None, None, None
396407

397408
def _custom_speculative_decoding(
398409
model: Model,

0 commit comments

Comments
 (0)