Skip to content

Commit 0275a2c

Browse files
author
Ashish Gupta
committed
preliminary changed for blackbird - model sharding
1 parent 292a00d commit 0275a2c

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

src/sagemaker/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def __init__(
371371
self.endpoint_name = None
372372
self.inference_component_name = None
373373
self._is_compiled_model = False
374+
self._is_sharded_model = False
374375
self._compilation_job_name = None
375376
self._is_edge_packaged_model = False
376377
self.inference_recommender_job_results = None
@@ -1595,6 +1596,11 @@ def deploy(
15951596
if self._base_name is not None:
15961597
self._base_name = "-".join((self._base_name, compiled_model_suffix))
15971598

1599+
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1600+
logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1601+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
1602+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
1603+
15981604
# Support multiple models on same endpoint
15991605
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16001606
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
681681
quantization_config: Optional[Dict] = None,
682682
compilation_config: Optional[Dict] = None,
683683
speculative_decoding_config: Optional[Dict] = None,
684+
sharding_config: Optional[Dict] = None,
684685
env_vars: Optional[Dict] = None,
685686
vpc_config: Optional[Dict] = None,
686687
kms_key: Optional[str] = None,
@@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
702703
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
703704
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
704705
Defaults to ``None``
706+
sharding_config (Optional[Dict]): Model sharding configuration.
707+
Defaults to ``None``
705708
env_vars (Optional[Dict]): Additional environment variables to run the optimization
706709
container. Defaults to ``None``.
707710
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.

src/sagemaker/serve/builder/model_builder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,7 @@ def optimize(
11191119
quantization_config: Optional[Dict] = None,
11201120
compilation_config: Optional[Dict] = None,
11211121
speculative_decoding_config: Optional[Dict] = None,
1122+
sharding_config: Optional[Dict] = None,
11221123
env_vars: Optional[Dict] = None,
11231124
vpc_config: Optional[Dict] = None,
11241125
kms_key: Optional[str] = None,
@@ -1142,6 +1143,8 @@ def optimize(
11421143
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11431144
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11441145
Defaults to ``None``
1146+
sharding_config (Optional[Dict]): Model sharding configuration.
1147+
Defaults to ``None``
11451148
env_vars (Optional[Dict]): Additional environment variables to run the optimization
11461149
container. Defaults to ``None``.
11471150
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
11701173
quantization_config=quantization_config,
11711174
compilation_config=compilation_config,
11721175
speculative_decoding_config=speculative_decoding_config,
1176+
sharding_config=sharding_config,
11731177
env_vars=env_vars,
11741178
vpc_config=vpc_config,
11751179
kms_key=kms_key,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
11891193
quantization_config: Optional[Dict] = None,
11901194
compilation_config: Optional[Dict] = None,
11911195
speculative_decoding_config: Optional[Dict] = None,
1196+
sharding_config: Optional[Dict] = None,
11921197
env_vars: Optional[Dict] = None,
11931198
vpc_config: Optional[Dict] = None,
11941199
kms_key: Optional[str] = None,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
12121217
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12131218
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12141219
Defaults to ``None``
1220+
sharding_config (Optional[Dict]): Model sharding configuration.
1221+
Defaults to ``None``
12151222
env_vars (Optional[Dict]): Additional environment variables to run the optimization
12161223
container. Defaults to ``None``.
12171224
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
12381245
if quantization_config and compilation_config:
12391246
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12401247

1248+
if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
1249+
raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
1250+
1251+
if sharding_config and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars:
1252+
raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")
1253+
12411254
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12421255
self.instance_type = instance_type or self.instance_type
12431256
self.role_arn = role_arn or self.role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
12541267
quantization_config=quantization_config,
12551268
compilation_config=compilation_config,
12561269
speculative_decoding_config=speculative_decoding_config,
1270+
sharding_config=sharding_config,
12571271
env_vars=env_vars,
12581272
vpc_config=vpc_config,
12591273
kms_key=kms_key,
@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
12721286
quantization_config=quantization_config,
12731287
compilation_config=compilation_config,
12741288
speculative_decoding_config=speculative_decoding_config,
1289+
sharding_config=sharding_config,
12751290
env_vars=env_vars,
12761291
vpc_config=vpc_config,
12771292
kms_key=kms_key,
@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
12871302
if not speculative_decoding_config:
12881303
self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)
12891304

1305+
if sharding_config:
1306+
self.pysdk_model._is_sharded_model = True
1307+
12901308
return self.pysdk_model
12911309

12921310
def _optimize_for_hf(
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
12971315
quantization_config: Optional[Dict] = None,
12981316
compilation_config: Optional[Dict] = None,
12991317
speculative_decoding_config: Optional[Dict] = None,
1318+
sharding_config: Optional[Dict] = None,
13001319
env_vars: Optional[Dict] = None,
13011320
vpc_config: Optional[Dict] = None,
13021321
kms_key: Optional[str] = None,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
13121331
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13131332
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13141333
Defaults to ``None``
1334+
sharding_config (Optional[Dict]): Model sharding configuration.
1335+
Defaults to ``None``
13151336
env_vars (Optional[Dict]): Additional environment variables to run the optimization
13161337
container. Defaults to ``None``.
13171338
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
13271348
self.pysdk_model, speculative_decoding_config, False
13281349
)
13291350

1330-
if quantization_config or compilation_config:
1351+
if quantization_config or compilation_config or sharding_config:
13311352
create_optimization_job_args = {
13321353
"OptimizationJobName": job_name,
13331354
"DeploymentInstanceType": self.instance_type,

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
259259

260260

261261
def _extract_optimization_config_and_env(
262-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
262+
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
263+
sharding_config: Optional[Dict] = None
263264
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
264265
"""Extracts optimization config and environment variables.
265266
266267
Args:
267268
quantization_config (Optional[Dict]): The quantization config.
268269
compilation_config (Optional[Dict]): The compilation config.
270+
sharding_config (Optional[Dict]): The sharding config.
269271
270272
Returns:
271273
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
279281
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
280282
"OverrideEnvironment"
281283
)
284+
if sharding_config:
285+
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
286+
"OverrideEnvironment"
287+
)
282288
return None, None
283289

284290

0 commit comments

Comments
 (0)