105105 get_huggingface_model_metadata ,
106106 download_huggingface_model_metadata ,
107107)
108+ from sagemaker .serve .validations .optimization import _validate_optimization_configuration
108109
109110logger = logging .getLogger (__name__ )
110111
@@ -1120,6 +1121,7 @@ def optimize(
11201121 quantization_config : Optional [Dict ] = None ,
11211122 compilation_config : Optional [Dict ] = None ,
11221123 speculative_decoding_config : Optional [Dict ] = None ,
1124+ sharding_config : Optional [Dict ] = None ,
11231125 env_vars : Optional [Dict ] = None ,
11241126 vpc_config : Optional [Dict ] = None ,
11251127 kms_key : Optional [str ] = None ,
@@ -1143,6 +1145,8 @@ def optimize(
11431145 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11441146 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11451147 Defaults to ``None``.
1148+ sharding_config (Optional[Dict]): Model sharding configuration.
1149+ Defaults to ``None``.
11461150 env_vars (Optional[Dict]): Additional environment variables to run the optimization
11471151 container. Defaults to ``None``.
11481152 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
11711175 quantization_config = quantization_config ,
11721176 compilation_config = compilation_config ,
11731177 speculative_decoding_config = speculative_decoding_config ,
1178+ sharding_config = sharding_config ,
11741179 env_vars = env_vars ,
11751180 vpc_config = vpc_config ,
11761181 kms_key = kms_key ,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
11901195 quantization_config : Optional [Dict ] = None ,
11911196 compilation_config : Optional [Dict ] = None ,
11921197 speculative_decoding_config : Optional [Dict ] = None ,
1198+ sharding_config : Optional [Dict ] = None ,
11931199 env_vars : Optional [Dict ] = None ,
11941200 vpc_config : Optional [Dict ] = None ,
11951201 kms_key : Optional [str ] = None ,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
12131219 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12141220 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12151221 Defaults to ``None``.
1222+ sharding_config (Optional[Dict]): Model sharding configuration.
1223+ Defaults to ``None``.
12161224 env_vars (Optional[Dict]): Additional environment variables to run the optimization
12171225 container. Defaults to ``None``.
12181226 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,26 @@ def _model_builder_optimize_wrapper(
12271235 Returns:
12281236 Model: A deployable ``Model`` object.
12291237 """
1238+ if (
1239+ hasattr (self , "enable_network_isolation" )
1240+ and self .enable_network_isolation
1241+ and sharding_config
1242+ ):
1243+ raise ValueError (
1244+ "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245+ "Loading of model requires network access."
1246+ )
1247+
1248+ # TODO: ideally these dictionaries need to be sagemaker_core shapes
1249+ # TODO: for organization, abstract all validation behind this fn
1250+ _validate_optimization_configuration (
1251+ instance_type = instance_type ,
1252+ quantization_config = quantization_config ,
1253+ compilation_config = compilation_config ,
1254+ sharding_config = sharding_config ,
1255+ speculative_decoding_config = speculative_decoding_config ,
1256+ )
1257+
12301258 self .is_compiled = compilation_config is not None
12311259 self .is_quantized = quantization_config is not None
12321260 self .speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider (
@@ -1236,6 +1264,43 @@ def _model_builder_optimize_wrapper(
12361264 if self .mode != Mode .SAGEMAKER_ENDPOINT :
12371265 raise ValueError ("Model optimization is only supported in Sagemaker Endpoint Mode." )
12381266
1267+ if sharding_config and (
1268+ quantization_config or compilation_config or speculative_decoding_config
1269+ ):
1270+ raise ValueError (
1271+ "Sharding config is mutually exclusive and cannot be combined with any other optimization."
1272+ )
1273+
1274+ if sharding_config and (
1275+ quantization_config or compilation_config or speculative_decoding_config
1276+ ):
1277+ raise ValueError (
1278+ (
1279+ "Sharding config is mutually exclusive "
1280+ "and cannot be combined with any other optimization."
1281+ )
1282+ )
1283+
1284+ if sharding_config :
1285+ has_tensor_parallel_degree_in_env_vars = (
1286+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1287+ )
1288+ has_tensor_parallel_degree_in_overrides = (
1289+ sharding_config
1290+ and sharding_config .get ("OverrideEnvironment" )
1291+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
1292+ )
1293+ if (
1294+ not has_tensor_parallel_degree_in_env_vars
1295+ and not has_tensor_parallel_degree_in_overrides
1296+ ):
1297+ raise ValueError (
1298+ (
1299+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1300+ "environment variable with sharding config."
1301+ )
1302+ )
1303+
12391304 self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
12401305 self .instance_type = instance_type or self .instance_type
12411306 self .role_arn = role_arn or self .role_arn
@@ -1252,6 +1317,7 @@ def _model_builder_optimize_wrapper(
12521317 quantization_config = quantization_config ,
12531318 compilation_config = compilation_config ,
12541319 speculative_decoding_config = speculative_decoding_config ,
1320+ sharding_config = sharding_config ,
12551321 env_vars = env_vars ,
12561322 vpc_config = vpc_config ,
12571323 kms_key = kms_key ,
@@ -1270,12 +1336,16 @@ def _model_builder_optimize_wrapper(
12701336 quantization_config = quantization_config ,
12711337 compilation_config = compilation_config ,
12721338 speculative_decoding_config = speculative_decoding_config ,
1339+ sharding_config = sharding_config ,
12731340 env_vars = env_vars ,
12741341 vpc_config = vpc_config ,
12751342 kms_key = kms_key ,
12761343 max_runtime_in_sec = max_runtime_in_sec ,
12771344 )
12781345
1346+ if sharding_config :
1347+ self .pysdk_model ._is_sharded_model = True
1348+
12791349 if input_args :
12801350 optimization_instance_type = input_args ["DeploymentInstanceType" ]
12811351
@@ -1325,6 +1395,7 @@ def _optimize_for_hf(
13251395 quantization_config : Optional [Dict ] = None ,
13261396 compilation_config : Optional [Dict ] = None ,
13271397 speculative_decoding_config : Optional [Dict ] = None ,
1398+ sharding_config : Optional [Dict ] = None ,
13281399 env_vars : Optional [Dict ] = None ,
13291400 vpc_config : Optional [Dict ] = None ,
13301401 kms_key : Optional [str ] = None ,
@@ -1340,6 +1411,8 @@ def _optimize_for_hf(
13401411 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13411412 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13421413 Defaults to ``None``.
1414+ sharding_config (Optional[Dict]): Model sharding configuration.
1415+ Defaults to ``None``.
13431416 env_vars (Optional[Dict]): Additional environment variables to run the optimization
13441417 container. Defaults to ``None``.
13451418 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1436,7 @@ def _optimize_for_hf(
13631436 self .pysdk_model , speculative_decoding_config , False
13641437 )
13651438
1366- if quantization_config or compilation_config :
1439+ if quantization_config or compilation_config or sharding_config :
13671440 create_optimization_job_args = {
13681441 "OptimizationJobName" : job_name ,
13691442 "DeploymentInstanceType" : self .instance_type ,
@@ -1378,16 +1451,20 @@ def _optimize_for_hf(
13781451 model_source = _generate_model_source (self .pysdk_model .model_data , False )
13791452 create_optimization_job_args ["ModelSource" ] = model_source
13801453
1381- optimization_config , quantization_override_env , compilation_override_env = (
1382- _extract_optimization_config_and_env (quantization_config , compilation_config )
1383- )
1454+ (
1455+ optimization_config ,
1456+ quantization_override_env ,
1457+ compilation_override_env ,
1458+ sharding_override_env ,
1459+ ) = _extract_optimization_config_and_env (quantization_config , compilation_config )
13841460 create_optimization_job_args ["OptimizationConfigs" ] = [
13851461 {k : v } for k , v in optimization_config .items ()
13861462 ]
13871463 self .pysdk_model .env .update (
13881464 {
13891465 ** (quantization_override_env or {}),
13901466 ** (compilation_override_env or {}),
1467+ ** (sharding_override_env or {}),
13911468 }
13921469 )
13931470
0 commit comments