105105 get_huggingface_model_metadata ,
106106 download_huggingface_model_metadata ,
107107)
108+ from sagemaker .serve .validations .optimization import _validate_optimization_configuration
108109
109110logger = logging .getLogger (__name__ )
110111
@@ -1120,6 +1121,7 @@ def optimize(
11201121 quantization_config : Optional [Dict ] = None ,
11211122 compilation_config : Optional [Dict ] = None ,
11221123 speculative_decoding_config : Optional [Dict ] = None ,
1124+ sharding_config : Optional [Dict ] = None ,
11231125 env_vars : Optional [Dict ] = None ,
11241126 vpc_config : Optional [Dict ] = None ,
11251127 kms_key : Optional [str ] = None ,
@@ -1143,6 +1145,8 @@ def optimize(
11431145 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11441146 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11451147 Defaults to ``None``.
1148+ sharding_config (Optional[Dict]): Model sharding configuration.
1149+ Defaults to ``None``.
11461150 env_vars (Optional[Dict]): Additional environment variables to run the optimization
11471151 container. Defaults to ``None``.
11481152 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
11711175 quantization_config = quantization_config ,
11721176 compilation_config = compilation_config ,
11731177 speculative_decoding_config = speculative_decoding_config ,
1178+ sharding_config = sharding_config ,
11741179 env_vars = env_vars ,
11751180 vpc_config = vpc_config ,
11761181 kms_key = kms_key ,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
11901195 quantization_config : Optional [Dict ] = None ,
11911196 compilation_config : Optional [Dict ] = None ,
11921197 speculative_decoding_config : Optional [Dict ] = None ,
1198+ sharding_config : Optional [Dict ] = None ,
11931199 env_vars : Optional [Dict ] = None ,
11941200 vpc_config : Optional [Dict ] = None ,
11951201 kms_key : Optional [str ] = None ,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
12131219 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12141220 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12151221 Defaults to ``None``.
1222+ sharding_config (Optional[Dict]): Model sharding configuration.
1223+ Defaults to ``None``.
12161224 env_vars (Optional[Dict]): Additional environment variables to run the optimization
12171225 container. Defaults to ``None``.
12181226 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,27 @@ def _model_builder_optimize_wrapper(
12271235 Returns:
12281236 Model: A deployable ``Model`` object.
12291237 """
1238+ if (
1239+ hasattr (self , "enable_network_isolation" )
1240+ and self .enable_network_isolation
1241+ and sharding_config
1242+ ):
1243+ raise ValueError (
1244+ "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245+ "Loading of model requires network access."
1246+ )
1247+
1248+ # TODO: ideally these config dictionaries should be sagemaker_core shapes rather than plain dicts
1249+ # TODO: for organization, move all optimization-argument validation behind this function
1250+ _validate_optimization_configuration (
1251+ is_jumpstart = self ._is_jumpstart_model_id (),
1252+ instance_type = instance_type ,
1253+ quantization_config = quantization_config ,
1254+ compilation_config = compilation_config ,
1255+ sharding_config = sharding_config ,
1256+ speculative_decoding_config = speculative_decoding_config ,
1257+ )
1258+
12301259 self .is_compiled = compilation_config is not None
12311260 self .is_quantized = quantization_config is not None
12321261 self .speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider (
@@ -1236,6 +1265,36 @@ def _model_builder_optimize_wrapper(
12361265 if self .mode != Mode .SAGEMAKER_ENDPOINT :
12371266 raise ValueError ("Model optimization is only supported in Sagemaker Endpoint Mode." )
12381267
1268+ if sharding_config and (
1269+ quantization_config or compilation_config or speculative_decoding_config
1270+ ):
1271+ raise ValueError (
1272+ (
1273+ "Sharding config is mutually exclusive "
1274+ "and cannot be combined with any other optimization."
1275+ )
1276+ )
1277+
1278+ if sharding_config :
1279+ has_tensor_parallel_degree_in_env_vars = (
1280+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1281+ )
1282+ has_tensor_parallel_degree_in_overrides = (
1283+ sharding_config
1284+ and sharding_config .get ("OverrideEnvironment" )
1285+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
1286+ )
1287+ if (
1288+ not has_tensor_parallel_degree_in_env_vars
1289+ and not has_tensor_parallel_degree_in_overrides
1290+ ):
1291+ raise ValueError (
1292+ (
1293+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1294+ "environment variable with sharding config."
1295+ )
1296+ )
1297+
12391298 self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
12401299 self .instance_type = instance_type or self .instance_type
12411300 self .role_arn = role_arn or self .role_arn
@@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper(
12521311 quantization_config = quantization_config ,
12531312 compilation_config = compilation_config ,
12541313 speculative_decoding_config = speculative_decoding_config ,
1314+ sharding_config = sharding_config ,
12551315 env_vars = env_vars ,
12561316 vpc_config = vpc_config ,
12571317 kms_key = kms_key ,
@@ -1270,12 +1330,16 @@ def _model_builder_optimize_wrapper(
12701330 quantization_config = quantization_config ,
12711331 compilation_config = compilation_config ,
12721332 speculative_decoding_config = speculative_decoding_config ,
1333+ sharding_config = sharding_config ,
12731334 env_vars = env_vars ,
12741335 vpc_config = vpc_config ,
12751336 kms_key = kms_key ,
12761337 max_runtime_in_sec = max_runtime_in_sec ,
12771338 )
12781339
1340+ if sharding_config :
1341+ self .pysdk_model ._is_sharded_model = True
1342+
12791343 if input_args :
12801344 optimization_instance_type = input_args ["DeploymentInstanceType" ]
12811345
@@ -1325,6 +1389,7 @@ def _optimize_for_hf(
13251389 quantization_config : Optional [Dict ] = None ,
13261390 compilation_config : Optional [Dict ] = None ,
13271391 speculative_decoding_config : Optional [Dict ] = None ,
1392+ sharding_config : Optional [Dict ] = None ,
13281393 env_vars : Optional [Dict ] = None ,
13291394 vpc_config : Optional [Dict ] = None ,
13301395 kms_key : Optional [str ] = None ,
@@ -1340,6 +1405,8 @@ def _optimize_for_hf(
13401405 compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13411406 speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13421407 Defaults to ``None``.
1408+ sharding_config (Optional[Dict]): Model sharding configuration.
1409+ Defaults to ``None``.
13431410 env_vars (Optional[Dict]): Additional environment variables to run the optimization
13441411 container. Defaults to ``None``.
13451412 vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1430,7 @@ def _optimize_for_hf(
13631430 self .pysdk_model , speculative_decoding_config , False
13641431 )
13651432
1366- if quantization_config or compilation_config :
1433+ if quantization_config or compilation_config or sharding_config :
13671434 create_optimization_job_args = {
13681435 "OptimizationJobName" : job_name ,
13691436 "DeploymentInstanceType" : self .instance_type ,
@@ -1378,8 +1445,13 @@ def _optimize_for_hf(
13781445 model_source = _generate_model_source (self .pysdk_model .model_data , False )
13791446 create_optimization_job_args ["ModelSource" ] = model_source
13801447
1381- optimization_config , quantization_override_env , compilation_override_env = (
1382- _extract_optimization_config_and_env (quantization_config , compilation_config )
1448+ (
1449+ optimization_config ,
1450+ quantization_override_env ,
1451+ compilation_override_env ,
1452+ sharding_override_env ,
1453+ ) = _extract_optimization_config_and_env (
1454+ quantization_config , compilation_config , sharding_config
13831455 )
13841456 create_optimization_job_args ["OptimizationConfigs" ] = [
13851457 {k : v } for k , v in optimization_config .items ()
@@ -1388,6 +1460,7 @@ def _optimize_for_hf(
13881460 {
13891461 ** (quantization_override_env or {}),
13901462 ** (compilation_override_env or {}),
1463+ ** (sharding_override_env or {}),
13911464 }
13921465 )
13931466
0 commit comments